Implement ALPN

Testing this is relatively annoying because OpenSSL
don't do error handling as specified by the RFC,
and because the RFC contradicts itself.  Quality
all round.
This commit is contained in:
Joseph Birr-Pixton 2016-06-05 21:18:46 +01:00
parent ffd183c202
commit 17e1670593
7 changed files with 174 additions and 57 deletions

View File

@ -90,8 +90,9 @@ Connection closed
```
# TODO list
- [ ] Improve testing.
- [ ] ALPN.
- [x] Improve testing.
- [ ] Improve testing some more.
- [x] ALPN.
- [ ] Tickets.
- [ ] Resumption.
- [ ] chacha20poly1305 bulk encryption support.

View File

@ -16,6 +16,7 @@ extern crate docopt;
use docopt::Docopt;
extern crate rustls;
use rustls::client::{ClientConfig, ClientSession};
const CLIENT: mio::Token = mio::Token(0);
@ -23,7 +24,7 @@ struct TlsClient {
socket: TcpStream,
closing: bool,
clean_closure: bool,
tls_session: rustls::client::ClientSession
tls_session: ClientSession
}
impl mio::Handler for TlsClient {
@ -76,48 +77,8 @@ impl io::Read for TlsClient {
}
}
fn find_suite(name: &str) -> Option<&'static rustls::suites::SupportedCipherSuite> {
for suite in rustls::suites::DEFAULT_CIPHERSUITES.iter() {
let sname = format!("{:?}", suite.suite).to_lowercase();
if sname == name.to_string().to_lowercase() {
return Some(suite);
}
}
None
}
fn lookup_suites(suites: &Vec<String>) -> Vec<&'static rustls::suites::SupportedCipherSuite> {
let mut out = Vec::new();
for csname in suites {
let scs = find_suite(csname);
match scs {
Some(s) => out.push(s),
None => panic!("cannot look up ciphersuite '{}'", csname)
}
}
out
}
impl TlsClient {
fn new(sock: TcpStream, hostname: &str, cafile: &str, suites: &Vec<String>) -> TlsClient {
let mut config = rustls::client::ClientConfig::default();
if suites.len() != 0 {
config.ciphersuites = lookup_suites(suites);
}
let certfile = std::fs::File::open(cafile)
.unwrap();
let mut reader = BufReader::new(certfile);
config.root_store.add_pem_file(&mut reader)
.unwrap();
let cfg = Arc::new(config);
fn new(sock: TcpStream, hostname: &str, cfg: Arc<rustls::client::ClientConfig>) -> TlsClient {
TlsClient {
socket: sock,
closing: false,
@ -216,7 +177,7 @@ before making the connection. --http replaces this with a
basic HTTP GET request for /.
Usage:
tlsclient [--verbose] [-p PORT] [--http] [--cafile CAFILE] [--suite SUITE...] <hostname>
tlsclient [--verbose] [-p PORT] [--http] [--cafile CAFILE] [--suite SUITE...] [--proto PROTOCOL...] <hostname>
tlsclient --version
tlsclient --help
@ -226,6 +187,7 @@ Options:
--cafile CAFILE Read root certificates from CAFILE.
--suite SUITE Disable default cipher suite list, and use
SUITE instead.
--proto PROTOCOL Send ALPN extension containing PROTOCOL.
--verbose Emit log output.
--version Show tool version.
--help Show this screen.
@ -237,6 +199,7 @@ struct Args {
flag_http: bool,
flag_verbose: bool,
flag_suite: Vec<String>,
flag_proto: Vec<String>,
flag_cafile: Option<String>,
arg_hostname: String
}
@ -254,6 +217,54 @@ fn lookup_ipv4(host: &str, port: u16) -> SocketAddr {
unreachable!("Cannot lookup address");
}
fn find_suite(name: &str) -> Option<&'static rustls::suites::SupportedCipherSuite> {
for suite in rustls::suites::DEFAULT_CIPHERSUITES.iter() {
let sname = format!("{:?}", suite.suite).to_lowercase();
if sname == name.to_string().to_lowercase() {
return Some(suite);
}
}
None
}
fn lookup_suites(suites: &Vec<String>) -> Vec<&'static rustls::suites::SupportedCipherSuite> {
let mut out = Vec::new();
for csname in suites {
let scs = find_suite(csname);
match scs {
Some(s) => out.push(s),
None => panic!("cannot look up ciphersuite '{}'", csname)
}
}
out
}
fn make_config(args: &Args) -> Arc<ClientConfig> {
let mut config = ClientConfig::default();
if args.flag_suite.len() != 0 {
config.ciphersuites = lookup_suites(&args.flag_suite);
}
let cafile = match args.flag_cafile {
Some(ref cafile) => cafile.clone(),
None => "/etc/ssl/certs/ca-certificates.crt".to_string()
};
let certfile = std::fs::File::open(cafile)
.unwrap();
let mut reader = BufReader::new(certfile);
config.root_store.add_pem_file(&mut reader)
.unwrap();
config.set_protocols(&args.flag_proto);
Arc::new(config)
}
fn main() {
let version = env!("CARGO_PKG_NAME").to_string() + ", version: " + env!("CARGO_PKG_VERSION");
@ -271,11 +282,11 @@ fn main() {
let port = args.flag_port.unwrap_or(443);
let addr = lookup_ipv4(args.arg_hostname.as_str(), port);
let cafile = args.flag_cafile.unwrap_or("/etc/ssl/certs/ca-certificates.crt".to_string());
let config = make_config(&args);
let sock = TcpStream::connect(&addr).unwrap();
let mut tlsclient = TlsClient::new(sock, &args.arg_hostname, &cafile, &args.flag_suite);
let mut tlsclient = TlsClient::new(sock, &args.arg_hostname, config);
if args.flag_http {
let httpreq = format!("GET / HTTP/1.1\r\nHost: {}\r\nConnection: close\r\nAccept-Encoding: identity\r\n\r\n", args.arg_hostname);
@ -284,7 +295,7 @@ fn main() {
let mut stdin = io::stdin();
tlsclient.read_source_to_end(&mut stdin).unwrap();
}
let mut event_loop = mio::EventLoop::new().unwrap();
tlsclient.register(&mut event_loop);
event_loop.run(&mut tlsclient).unwrap();

View File

@ -20,21 +20,42 @@ use std::io;
use std::collections::VecDeque;
use std::mem;
/// Common configuration for all connections made by
/// a program.
///
/// Making one of these can be expensive, and should be
/// once per process rather than once per connection.
pub struct ClientConfig {
/* List of ciphersuites, in preference order. */
/// List of ciphersuites, in preference order.
pub ciphersuites: Vec<&'static SupportedCipherSuite>,
/* Collection of root certificates. */
pub root_store: verify::RootCertStore
/// Collection of root certificates.
pub root_store: verify::RootCertStore,
/// Which ALPN protocols we include in our client hello.
/// If empty, no ALPN extension is sent.
pub alpn_protocols: Vec<String>
}
impl ClientConfig {
/// Make a `ClientConfig` with a default set of ciphersuites,
/// no root certificates, and no ALPN protocols.
pub fn default() -> ClientConfig {
ClientConfig {
ciphersuites: DEFAULT_CIPHERSUITES.to_vec(),
root_store: verify::RootCertStore::empty()
root_store: verify::RootCertStore::empty(),
alpn_protocols: Vec::new()
}
}
/// Set the ALPN protocol list to the given protocol names.
/// Overwrites any existing configured protocols.
/// The first element in the `protocols` list is the most
/// preferred, the last is the least preferred.
pub fn set_protocols(&mut self, protocols: &[String]) {
self.alpn_protocols.clear();
self.alpn_protocols.extend_from_slice(protocols);
}
}
pub struct ClientHandshakeData {
@ -104,6 +125,7 @@ pub struct ClientSession {
write_seq: u64,
read_seq: u64,
peer_eof: bool,
pub alpn_protocol: Option<String>,
pub message_deframer: MessageDeframer,
pub handshake_joiner: HandshakeJoiner,
pub message_fragmenter: MessageFragmenter,
@ -124,6 +146,7 @@ impl ClientSession {
write_seq: 0,
read_seq: 0,
peer_eof: false,
alpn_protocol: None,
message_deframer: MessageDeframer::new(),
handshake_joiner: HandshakeJoiner::new(),
message_fragmenter: MessageFragmenter::new(MAX_FRAGMENT_LEN),

View File

@ -8,6 +8,7 @@ use msgs::handshake::ClientExtension;
use msgs::handshake::{SupportedSignatureAlgorithms, SupportedMandatedSignatureAlgorithms};
use msgs::handshake::{EllipticCurveList, SupportedCurves};
use msgs::handshake::{ECPointFormatList, SupportedPointFormats};
use msgs::handshake::{ProtocolNameList, ConvertProtocolNameList};
use msgs::handshake::ServerKeyExchangePayload;
use msgs::ccs::ChangeCipherSpecPayload;
use client::{ClientSession, ConnState};
@ -47,6 +48,10 @@ pub fn emit_client_hello(sess: &mut ClientSession) {
exts.push(ClientExtension::EllipticCurves(EllipticCurveList::supported()));
exts.push(ClientExtension::SignatureAlgorithms(SupportedSignatureAlgorithms::supported()));
if sess.config.alpn_protocols.len() > 0 {
exts.push(ClientExtension::Protocols(ProtocolNameList::convert(&sess.config.alpn_protocols)));
}
let sh = Message {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
@ -67,6 +72,8 @@ pub fn emit_client_hello(sess: &mut ClientSession) {
)
};
debug!("Sending ClientHello {:?}", sh);
sh.payload.encode(&mut sess.handshake_data.client_hello);
sess.tls_queue.push_back(sh);
}
@ -80,6 +87,7 @@ fn expect_server_hello() -> Expectation {
fn handle_server_hello(sess: &mut ClientSession, m: &Message) -> Result<ConnState, HandshakeError> {
let server_hello = extract_handshake!(m, HandshakePayload::ServerHello).unwrap();
debug!("We got ServerHello {:?}", server_hello);
if server_hello.server_version != ProtocolVersion::TLSv1_2 {
return Err(HandshakeError::General("server does not support TLSv1_2".to_string()));
@ -89,6 +97,10 @@ fn handle_server_hello(sess: &mut ClientSession, m: &Message) -> Result<ConnStat
return Err(HandshakeError::General("server chose non-Null compression".to_string()));
}
/* Extract ALPN protocol */
sess.alpn_protocol = server_hello.get_alpn_protocol();
info!("ALPN protocol is {:?}", sess.alpn_protocol);
let scs = sess.find_cipher_suite(&server_hello.cipher_suite);
if scs.is_none() {

View File

@ -548,6 +548,7 @@ pub enum ExtensionType {
SignatureAlgorithms,
UseSRTP,
Heartbeat,
ALProtocolNegotiation,
Padding,
SessionTicket,
NextProtocolNegotiation,
@ -585,6 +586,7 @@ impl Codec for ExtensionType {
0x000d => ExtensionType::SignatureAlgorithms,
0x000e => ExtensionType::UseSRTP,
0x000f => ExtensionType::Heartbeat,
0x0010 => ExtensionType::ALProtocolNegotiation,
0x0015 => ExtensionType::Padding,
0x0023 => ExtensionType::SessionTicket,
0x3374 => ExtensionType::NextProtocolNegotiation,
@ -614,6 +616,7 @@ impl ExtensionType {
ExtensionType::SignatureAlgorithms => 0x000d,
ExtensionType::UseSRTP => 0x000e,
ExtensionType::Heartbeat => 0x000f,
ExtensionType::ALProtocolNegotiation => 0x0010,
ExtensionType::Padding => 0x0015,
ExtensionType::SessionTicket => 0x0023,
ExtensionType::NextProtocolNegotiation => 0x3374,

View File

@ -267,6 +267,35 @@ impl Codec for ServerNameRequest {
}
}
pub type ProtocolName = PayloadU8;
pub type ProtocolNameList = Vec<ProtocolName>;
impl Codec for ProtocolNameList {
fn encode(&self, bytes: &mut Vec<u8>) {
codec::encode_vec_u16(bytes, self);
}
fn read(r: &mut Reader) -> Option<ProtocolNameList> {
codec::read_vec_u16::<ProtocolName>(r)
}
}
pub trait ConvertProtocolNameList {
fn convert(names: &[String]) -> Self;
}
impl ConvertProtocolNameList for ProtocolNameList {
fn convert(names: &[String]) -> ProtocolNameList {
let mut ret = Vec::new();
for name in names {
ret.push(PayloadU8 { body: name.as_bytes().to_vec().into_boxed_slice() });
}
ret
}
}
#[derive(Debug)]
pub enum ClientExtension {
ECPointFormats(ECPointFormatList),
@ -276,6 +305,7 @@ pub enum ClientExtension {
ServerName(ServerNameRequest),
SessionTicketRequest,
SessionTicketOffer(Payload),
Protocols(ProtocolNameList),
Unknown(UnknownExtension)
}
@ -289,6 +319,7 @@ impl ClientExtension {
ClientExtension::ServerName(_) => ExtensionType::ServerName,
ClientExtension::SessionTicketRequest => ExtensionType::SessionTicket,
ClientExtension::SessionTicketOffer(_) => ExtensionType::SessionTicket,
ClientExtension::Protocols(_) => ExtensionType::ALProtocolNegotiation,
ClientExtension::Unknown(ref r) => r.typ.clone()
}
}
@ -307,6 +338,7 @@ impl Codec for ClientExtension {
ClientExtension::ServerName(ref r) => r.encode(&mut sub),
ClientExtension::SessionTicketRequest => (),
ClientExtension::SessionTicketOffer(ref r) => r.encode(&mut sub),
ClientExtension::Protocols(ref r) => r.encode(&mut sub),
ClientExtension::Unknown(ref r) => r.encode(&mut sub)
}
@ -336,6 +368,8 @@ impl Codec for ClientExtension {
} else {
ClientExtension::SessionTicketRequest
},
ExtensionType::ALProtocolNegotiation =>
ClientExtension::Protocols(try_ret!(ProtocolNameList::read(&mut sub))),
_ =>
ClientExtension::Unknown(try_ret!(UnknownExtension::read(typ, &mut sub)))
})
@ -363,6 +397,7 @@ pub enum ServerExtension {
ServerNameAcknowledgement,
SessionTicketAcknowledgement,
RenegotiationInfo(PayloadU8),
Protocols(ProtocolNameList),
Unknown(UnknownExtension)
}
@ -374,6 +409,7 @@ impl ServerExtension {
ServerExtension::ServerNameAcknowledgement => ExtensionType::ServerName,
ServerExtension::SessionTicketAcknowledgement => ExtensionType::SessionTicket,
ServerExtension::RenegotiationInfo(_) => ExtensionType::RenegotiationInfo,
ServerExtension::Protocols(_) => ExtensionType::ALProtocolNegotiation,
ServerExtension::Unknown(ref r) => r.typ.clone()
}
}
@ -390,6 +426,7 @@ impl Codec for ServerExtension {
ServerExtension::ServerNameAcknowledgement => (),
ServerExtension::SessionTicketAcknowledgement => (),
ServerExtension::RenegotiationInfo(ref r) => r.encode(&mut sub),
ServerExtension::Protocols(ref r) => r.encode(&mut sub),
ServerExtension::Unknown(ref r) => r.encode(&mut sub)
}
@ -413,6 +450,8 @@ impl Codec for ServerExtension {
ServerExtension::SessionTicketAcknowledgement,
ExtensionType::RenegotiationInfo =>
ServerExtension::RenegotiationInfo(try_ret!(PayloadU8::read(&mut sub))),
ExtensionType::ALProtocolNegotiation =>
ServerExtension::Protocols(try_ret!(ProtocolNameList::read(&mut sub))),
_ =>
ServerExtension::Unknown(try_ret!(UnknownExtension::read(typ, &mut sub)))
})
@ -536,6 +575,22 @@ impl Codec for ServerHelloPayload {
}
}
impl ServerHelloPayload {
pub fn get_alpn_protocol(&self) -> Option<String> {
let ext = try_ret!(self.extensions.iter().find(|x| x.get_type() == ExtensionType::ALProtocolNegotiation));
match *ext {
ServerExtension::Protocols(ref protos) => {
if protos.len() == 1 {
String::from_utf8(protos[0].body.to_vec()).ok()
} else {
None
}
},
_ => None
}
}
}
pub type ASN1Cert = PayloadU24;
pub type CertificatePayload = Vec<ASN1Cert>;

View File

@ -15,6 +15,7 @@ pub struct TlsClient {
pub http: bool,
pub cafile: Option<String>,
pub suites: Vec<String>,
pub protos: Vec<String>,
pub verbose: bool,
pub expect_fails: bool,
pub expect_output: Option<String>,
@ -30,6 +31,7 @@ impl TlsClient {
cafile: None,
verbose: false,
suites: Vec::new(),
protos: Vec::new(),
expect_fails: false,
expect_output: None,
expect_log: None
@ -55,7 +57,7 @@ impl TlsClient {
self.expect_output = Some(expect.to_string());
self
}
pub fn expect_log(&mut self, expect: &str) -> &mut TlsClient {
self.expect_log = Some(expect.to_string());
self
@ -66,6 +68,11 @@ impl TlsClient {
self
}
pub fn proto(&mut self, proto: &str) -> &mut TlsClient {
self.protos.push(proto.to_string());
self
}
pub fn fails(&mut self) -> &mut TlsClient {
self.expect_fails = true;
self
@ -75,7 +82,7 @@ impl TlsClient {
let portstring = self.port.to_string();
let mut args = Vec::<&str>::new();
args.push(&self.hostname);
args.push("--port");
args.push(&portstring);
@ -93,6 +100,11 @@ impl TlsClient {
args.push(suite.as_ref());
}
for proto in &self.protos {
args.push("--proto");
args.push(proto.as_ref());
}
if self.verbose {
args.push("--verbose");
}
@ -110,13 +122,13 @@ impl TlsClient {
println!("{:?}", output);
panic!("Test failed");
}
if self.expect_log.is_some() && stderr_str.find(self.expect_log.as_ref().unwrap()).is_none() {
println!("We expected to find '{}' in the following output:", self.expect_log.as_ref().unwrap());
println!("{:?}", output);
panic!("Test failed");
}
if self.expect_fails {
assert!(output.status.code().unwrap() != 0);
} else {