mirror of https://github.com/ctz/rustls
293 lines
8.5 KiB
Rust
293 lines
8.5 KiB
Rust
use std::fs::{self, File};
|
|
use std::str;
|
|
use tempfile;
|
|
|
|
use std::sync::Arc;
|
|
use std::io::{self, Write};
|
|
|
|
use rustls;
|
|
|
|
use rustls::{ClientConfig, ClientSession};
|
|
use rustls::{ServerConfig, ServerSession};
|
|
use rustls::Session;
|
|
use rustls::ProtocolVersion;
|
|
use rustls::TLSError;
|
|
use rustls::{Certificate, PrivateKey};
|
|
use rustls::internal::pemfile;
|
|
use rustls::{RootCertStore, NoClientAuth, AllowAnyAuthenticatedClient};
|
|
|
|
use webpki;
|
|
|
|
// Embeds a list of files from the `test-ca` directory into the test binary.
//
// Each entry has the form `(CONST_NAME, "keytype", "file");` and is read at
// compile time from `../../test-ca/<keytype>/<file>` (relative to this source
// file).  The expansion defines:
//
// * one `const CONST_NAME: &'static [u8]` per entry,
// * `bytes_for(keytype, path)`, which returns the embedded bytes for an entry
//   and panics for any pair that was not listed,
// * `new_test_ca()`, which writes every embedded file into a fresh temporary
//   directory using the same `<keytype>/<file>` layout.
macro_rules! embed_files {
    (
        $(
            ($name:ident, $keytype:expr, $path:expr);
        )+
    ) => {
        $(
            const $name: &'static [u8] = include_bytes!(
                concat!("../../test-ca/", $keytype, "/", $path));
        )+

        /// Returns the embedded contents of `path` for the given `keytype`.
        ///
        /// Panics if the `(keytype, path)` pair was not listed in the
        /// `embed_files!` invocation.
        pub fn bytes_for(keytype: &str, path: &str) -> &'static [u8] {
            match (keytype, path) {
                $(
                    ($keytype, $path) => $name,
                )+
                _ => panic!("unknown keytype {} with path {}", keytype, path),
            }
        }

        /// Writes every embedded file into a new temporary directory and
        /// returns the directory handle.
        ///
        /// Panics if the directory, its `ecdsa`/`rsa` subdirectories, or any
        /// of the files cannot be created or written.
        pub fn new_test_ca() -> tempfile::TempDir {
            let dir = tempfile::TempDir::new().unwrap();

            fs::create_dir(dir.path().join("ecdsa")).unwrap();
            fs::create_dir(dir.path().join("rsa")).unwrap();

            $(
                let mut f = File::create(dir.path().join($keytype).join($path)).unwrap();
                // `write` is allowed to perform a short write; `write_all`
                // guarantees the whole fixture lands on disk or we panic.
                f.write_all($name).unwrap();
            )+

            dir
        }
    }
}
|
|
|
|
// Every test-ca fixture that is compiled into the test binary.  Entries are
// `(CONST_NAME, "keytype", "file");` as expected by `embed_files!`.
embed_files! {
    // --- ECDSA: certificate authority ---
    (ECDSA_CA_CERT, "ecdsa", "ca.cert");
    (ECDSA_CA_DER, "ecdsa", "ca.der");
    (ECDSA_CA_KEY, "ecdsa", "ca.key");
    // --- ECDSA: client ---
    (ECDSA_CLIENT_CERT, "ecdsa", "client.cert");
    (ECDSA_CLIENT_CHAIN, "ecdsa", "client.chain");
    (ECDSA_CLIENT_FULLCHAIN, "ecdsa", "client.fullchain");
    (ECDSA_CLIENT_KEY, "ecdsa", "client.key");
    (ECDSA_CLIENT_REQ, "ecdsa", "client.req");
    // --- ECDSA: end entity ---
    (ECDSA_END_CERT, "ecdsa", "end.cert");
    (ECDSA_END_CHAIN, "ecdsa", "end.chain");
    (ECDSA_END_FULLCHAIN, "ecdsa", "end.fullchain");
    (ECDSA_END_KEY, "ecdsa", "end.key");
    (ECDSA_END_REQ, "ecdsa", "end.req");
    // --- ECDSA: intermediate ---
    (ECDSA_INTER_CERT, "ecdsa", "inter.cert");
    (ECDSA_INTER_KEY, "ecdsa", "inter.key");
    (ECDSA_INTER_REQ, "ecdsa", "inter.req");
    // --- ECDSA: curve-specific PEM files ---
    (ECDSA_NISTP256_PEM, "ecdsa", "nistp256.pem");
    (ECDSA_NISTP384_PEM, "ecdsa", "nistp384.pem");

    // --- RSA: certificate authority ---
    (RSA_CA_CERT, "rsa", "ca.cert");
    (RSA_CA_DER, "rsa", "ca.der");
    (RSA_CA_KEY, "rsa", "ca.key");
    // --- RSA: client ---
    (RSA_CLIENT_CERT, "rsa", "client.cert");
    (RSA_CLIENT_CHAIN, "rsa", "client.chain");
    (RSA_CLIENT_FULLCHAIN, "rsa", "client.fullchain");
    (RSA_CLIENT_KEY, "rsa", "client.key");
    (RSA_CLIENT_REQ, "rsa", "client.req");
    (RSA_CLIENT_RSA, "rsa", "client.rsa");
    // --- RSA: end entity ---
    (RSA_END_CERT, "rsa", "end.cert");
    (RSA_END_CHAIN, "rsa", "end.chain");
    (RSA_END_FULLCHAIN, "rsa", "end.fullchain");
    (RSA_END_KEY, "rsa", "end.key");
    (RSA_END_REQ, "rsa", "end.req");
    (RSA_END_RSA, "rsa", "end.rsa");
    // --- RSA: intermediate ---
    (RSA_INTER_CERT, "rsa", "inter.cert");
    (RSA_INTER_KEY, "rsa", "inter.key");
    (RSA_INTER_REQ, "rsa", "inter.req");
}
|
|
|
|
/// Moves all pending TLS bytes from `left` to `right` through an in-memory
/// buffer and returns how many bytes `left` produced.
///
/// Repeats while `left.wants_write()` is true; stops early if `left` writes
/// zero bytes.  Panics if `write_tls` or `read_tls` returns an error.
pub fn transfer(left: &mut dyn Session, right: &mut dyn Session) -> usize {
    let mut scratch = [0u8; 262144];
    let mut moved = 0;

    while left.wants_write() {
        // Drain one batch of output from `left` into the scratch buffer.
        let produced = {
            let mut sink: &mut [u8] = &mut scratch[..];
            left.write_tls(&mut sink).unwrap()
        };
        moved += produced;
        if produced == 0 {
            return moved;
        }

        // Feed that batch to `right`, which may accept it in several pieces.
        // NOTE(review): this keeps retrying until every byte is consumed, so
        // it would spin if `read_tls` kept returning 0 — confirm that cannot
        // happen for the sessions used here.
        let mut consumed = 0;
        while consumed != produced {
            let mut source: &[u8] = &scratch[consumed..produced];
            consumed += right.read_tls(&mut source).unwrap();
        }
    }

    moved
}
|
|
|
|
/// The key algorithm families for which test-CA fixtures are embedded.
///
/// `Debug`, `PartialEq` and `Eq` are derived so values can be compared and
/// printed in test assertions and panic messages.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KeyType {
    /// Fixtures under the `rsa` directory.
    RSA,
    /// Fixtures under the `ecdsa` directory.
    ECDSA,
}

/// Every [`KeyType`] variant, for tests that loop over all key types.
pub static ALL_KEY_TYPES: [KeyType; 2] = [KeyType::RSA, KeyType::ECDSA];
|
|
|
|
impl KeyType {
|
|
fn bytes_for(&self, part: &str) -> &'static [u8] {
|
|
match self {
|
|
KeyType::RSA => bytes_for("rsa", part),
|
|
KeyType::ECDSA => bytes_for("ecdsa", part),
|
|
}
|
|
}
|
|
|
|
pub fn get_chain(&self) -> Vec<Certificate> {
|
|
pemfile::certs(&mut io::BufReader::new(self.bytes_for("end.fullchain")))
|
|
.unwrap()
|
|
}
|
|
|
|
pub fn get_key(&self) -> PrivateKey {
|
|
pemfile::pkcs8_private_keys(&mut io::BufReader::new(self.bytes_for("end.key")))
|
|
.unwrap()[0]
|
|
.clone()
|
|
}
|
|
|
|
fn get_client_chain(&self) -> Vec<Certificate> {
|
|
pemfile::certs(&mut io::BufReader::new(self.bytes_for("client.fullchain")))
|
|
.unwrap()
|
|
}
|
|
|
|
fn get_client_key(&self) -> PrivateKey {
|
|
pemfile::pkcs8_private_keys(&mut io::BufReader::new(self.bytes_for("client.key")))
|
|
.unwrap()[0]
|
|
.clone()
|
|
}
|
|
}
|
|
|
|
pub fn make_server_config(kt: KeyType) -> ServerConfig {
|
|
let mut cfg = ServerConfig::new(NoClientAuth::new());
|
|
cfg.set_single_cert(kt.get_chain(), kt.get_key()).unwrap();
|
|
|
|
cfg
|
|
}
|
|
|
|
pub fn make_server_config_with_mandatory_client_auth(kt: KeyType) -> ServerConfig {
|
|
let roots = kt.get_chain();
|
|
let mut client_auth_roots = RootCertStore::empty();
|
|
for root in roots {
|
|
client_auth_roots.add(&root).unwrap();
|
|
}
|
|
|
|
let client_auth = AllowAnyAuthenticatedClient::new(client_auth_roots);
|
|
let mut cfg = ServerConfig::new(client_auth);
|
|
cfg.set_single_cert(kt.get_chain(), kt.get_key()).unwrap();
|
|
|
|
cfg
|
|
}
|
|
|
|
pub fn make_client_config(kt: KeyType) -> ClientConfig {
|
|
let mut cfg = ClientConfig::new();
|
|
let mut rootbuf = io::BufReader::new(kt.bytes_for("ca.cert"));
|
|
cfg.root_store.add_pem_file(&mut rootbuf).unwrap();
|
|
|
|
cfg
|
|
}
|
|
|
|
pub fn make_client_config_with_auth(kt: KeyType) -> ClientConfig {
|
|
let mut cfg = make_client_config(kt);
|
|
cfg.set_single_client_cert(kt.get_client_chain(), kt.get_client_key());
|
|
cfg
|
|
}
|
|
|
|
pub fn make_pair(kt: KeyType) -> (ClientSession, ServerSession) {
|
|
make_pair_for_configs(make_client_config(kt),
|
|
make_server_config(kt))
|
|
}
|
|
|
|
pub fn make_pair_for_configs(client_config: ClientConfig,
|
|
server_config: ServerConfig) -> (ClientSession, ServerSession) {
|
|
make_pair_for_arc_configs(&Arc::new(client_config),
|
|
&Arc::new(server_config))
|
|
}
|
|
|
|
pub fn make_pair_for_arc_configs(client_config: &Arc<ClientConfig>,
|
|
server_config: &Arc<ServerConfig>) -> (ClientSession, ServerSession) {
|
|
(
|
|
ClientSession::new(client_config, dns_name("localhost")),
|
|
ServerSession::new(server_config)
|
|
)
|
|
}
|
|
|
|
pub fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) -> (usize, usize) {
|
|
let (mut to_client, mut to_server) = (0, 0);
|
|
while server.is_handshaking() || client.is_handshaking() {
|
|
to_server += transfer(client, server);
|
|
server.process_new_packets().unwrap();
|
|
to_client += transfer(server, client);
|
|
client.process_new_packets().unwrap();
|
|
}
|
|
(to_server, to_client)
|
|
}
|
|
|
|
pub struct AllClientVersions {
|
|
client_config: ClientConfig,
|
|
index: usize,
|
|
}
|
|
|
|
impl AllClientVersions {
|
|
pub fn new(client_config: ClientConfig) -> AllClientVersions {
|
|
AllClientVersions { client_config, index: 0 }
|
|
}
|
|
}
|
|
|
|
impl Iterator for AllClientVersions {
|
|
type Item = ClientConfig;
|
|
|
|
fn next(&mut self) -> Option<ClientConfig> {
|
|
let mut config = self.client_config.clone();
|
|
self.index += 1;
|
|
|
|
match self.index {
|
|
1 => {
|
|
config.versions = vec![ProtocolVersion::TLSv1_2];
|
|
Some(config)
|
|
},
|
|
2 => {
|
|
config.versions = vec![ProtocolVersion::TLSv1_3];
|
|
Some(config)
|
|
},
|
|
_ => None
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(PartialEq, Debug)]
|
|
pub enum TLSErrorFromPeer { Client(TLSError), Server(TLSError) }
|
|
|
|
pub fn do_handshake_until_error(client: &mut ClientSession,
|
|
server: &mut ServerSession)
|
|
-> Result<(), TLSErrorFromPeer> {
|
|
while server.is_handshaking() || client.is_handshaking() {
|
|
transfer(client, server);
|
|
server.process_new_packets()
|
|
.map_err(|err| TLSErrorFromPeer::Server(err))?;
|
|
transfer(server, client);
|
|
client.process_new_packets()
|
|
.map_err(|err| TLSErrorFromPeer::Client(err))?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Converts a static ASCII string into a `webpki::DNSNameRef`.
///
/// Panics if `try_from_ascii_str` rejects `name`.
pub fn dns_name(name: &'static str) -> webpki::DNSNameRef<'_> {
    let parsed = webpki::DNSNameRef::try_from_ascii_str(name);
    parsed.unwrap()
}
|
|
|
|
/// An `io::Read` implementation whose every read fails with a fixed error
/// kind.
pub struct FailsReads {
    /// Kind of the error returned from each `read` call.
    errkind: io::ErrorKind,
}

impl FailsReads {
    /// Creates a reader that always fails with `errkind`.
    pub fn new(errkind: io::ErrorKind) -> FailsReads {
        FailsReads { errkind: errkind }
    }
}

impl io::Read for FailsReads {
    /// Never touches the buffer; always returns an error of the configured
    /// kind.
    fn read(&mut self, _b: &mut [u8]) -> io::Result<usize> {
        let failure: io::Error = self.errkind.into();
        Err(failure)
    }
}
|