From f4bb653aa03665a83bd98068432602f07cd5634b Mon Sep 17 00:00:00 2001 From: Dominik Nakamura Date: Mon, 18 Oct 2021 11:27:03 +0900 Subject: [PATCH] Upgrade to rustls 0.20 --- Cargo.toml | 10 ++++++---- src/client.rs | 2 +- src/error.rs | 10 +++++++--- src/handshake/server.rs | 6 +++--- src/stream.rs | 15 +++++++++++---- src/tls.rs | 40 ++++++++++++++++++++++++++++------------ 6 files changed, 56 insertions(+), 27 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c17264c..29cee33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,19 +44,21 @@ version = "0.2.3" [dependencies.rustls] optional = true -version = "0.19.0" +version = "0.20.0" [dependencies.rustls-native-certs] optional = true -version = "0.5.0" +version = "0.6.0" +git = "https://github.com/rustls/rustls-native-certs.git" +rev = "87b84b51bcf38eb9d377e0f5606c444ced43cc60" [dependencies.webpki] optional = true -version = "0.21" +version = "0.22" [dependencies.webpki-roots] optional = true -version = "0.21" +version = "0.22" [dev-dependencies] criterion = "0.3.4" diff --git a/src/client.rs b/src/client.rs index 67a3c41..12dfe9b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -56,7 +56,7 @@ pub fn connect_with_config( Mode::Tls => 443, }); let addrs = (host, port).to_socket_addrs()?; - let mut stream = connect_to_some(addrs.as_slice(), &request.uri())?; + let mut stream = connect_to_some(addrs.as_slice(), request.uri())?; NoDelay::set_nodelay(&mut stream, true)?; #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))] diff --git a/src/error.rs b/src/error.rs index 510e3f4..e224da7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -255,9 +255,13 @@ pub enum TlsError { /// Rustls error. #[cfg(feature = "__rustls-tls")] #[error("rustls error: {0}")] - Rustls(#[from] rustls::TLSError), + Rustls(#[from] rustls::Error), + /// Webpki error. + #[cfg(feature = "__rustls-tls")] + #[error("webpki error: {0}")] + Webpki(#[from] webpki::Error), /// DNS name resolution error. #[cfg(feature = "__rustls-tls")] - #[error("Invalid DNS name: {0}")] - Dns(#[from] webpki::InvalidDNSNameError), + #[error("Invalid DNS name")] + InvalidDnsName, } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index fddf953..5f86c91 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -82,7 +82,7 @@ fn create_parts(request: &HttpRequest) -> Result { /// Create a response for the request. pub fn create_response(request: &Request) -> Result { - Ok(create_parts(&request)?.body(())?) + Ok(create_parts(request)?.body(())?) } /// Create a response for the request with a custom body. @@ -90,7 +90,7 @@ pub fn create_response_with_body( request: &HttpRequest, generate_body: impl FnOnce() -> T, ) -> Result> { - Ok(create_parts(&request)?.body(generate_body())?) + Ok(create_parts(request)?.body(generate_body())?) } // Assumes that this is a valid response @@ -263,7 +263,7 @@ impl HandshakeRole for ServerHandshake { let resp = self.error_response.as_ref().unwrap(); let mut output = vec![]; - write_response(&mut output, &resp)?; + write_response(&mut output, resp)?; if let Some(body) = resp.body() { output.extend_from_slice(body.as_bytes()); diff --git a/src/stream.rs b/src/stream.rs index b7fe0e4..4775230 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -4,6 +4,8 @@ //! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard //! `Read + Write` traits. +#[cfg(feature = "__rustls-tls")] +use std::ops::Deref; use std::{ fmt::{self, Debug}, io::{Read, Result as IoResult, Write}, @@ -45,7 +47,12 @@ impl NoDelay for TlsStream { } #[cfg(feature = "__rustls-tls")] -impl NoDelay for StreamOwned { +impl NoDelay for StreamOwned +where + S: Deref>, + SD: rustls::SideData, + T: Read + Write + NoDelay, +{ fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { self.sock.set_nodelay(nodelay) } @@ -61,7 +68,7 @@ pub enum MaybeTlsStream { NativeTls(native_tls_crate::TlsStream), #[cfg(feature = "__rustls-tls")] /// Encrypted socket stream using `rustls`. - Rustls(rustls::StreamOwned), + Rustls(rustls::StreamOwned), } impl Debug for MaybeTlsStream { @@ -73,13 +80,13 @@ impl Debug for MaybeTlsStream { #[cfg(feature = "__rustls-tls")] Self::Rustls(s) => { struct RustlsStreamDebug<'a, S: Read + Write>( - &'a rustls::StreamOwned, + &'a rustls::StreamOwned, ); impl<'a, S: Read + Write + Debug> Debug for RustlsStreamDebug<'a, S> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("StreamOwned") - .field("sess", &self.0.sess) + .field("conn", &self.0.conn) .field("sock", &self.0.sock) .finish() } diff --git a/src/tls.rs b/src/tls.rs index 4f07a54..ad54de3 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -70,10 +70,10 @@ mod encryption { #[cfg(feature = "__rustls-tls")] pub mod rustls { - use rustls::{ClientConfig, ClientSession, StreamOwned}; - use webpki::DNSNameRef; + use rustls::{ClientConfig, ClientConnection, RootCertStore, ServerName, StreamOwned}; use std::{ + convert::TryFrom, io::{Read, Write}, sync::Arc, }; @@ -100,24 +100,40 @@ mod encryption { Some(config) => config, None => { #[allow(unused_mut)] - let mut config = ClientConfig::new(); + let mut root_store = RootCertStore::empty(); + #[cfg(feature = "rustls-tls-native-roots")] { - config.root_store = rustls_native_certs::load_native_certs() - .map_err(|(_, err)| err)?; + for cert in rustls_native_certs::load_native_certs()? { + root_store + .add(&rustls::Certificate(cert.0)) + .map_err(TlsError::Webpki)?; + } } #[cfg(feature = "rustls-tls-webpki-roots")] { - config - .root_store - .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + root_store.add_server_trust_anchors( + webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }) + ); } - Arc::new(config) + Arc::new( + ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(), + ) } }; - let domain = DNSNameRef::try_from_ascii_str(domain).map_err(TlsError::Dns)?; - let client = ClientSession::new(&config, domain); + let domain = + ServerName::try_from(domain).map_err(|_| TlsError::InvalidDnsName)?; + let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?; let stream = StreamOwned::new(client, socket); Ok(MaybeTlsStream::Rustls(stream)) @@ -185,7 +201,7 @@ where None => Err(Error::Url(UrlError::NoHostName)), }?; - let mode = uri_mode(&request.uri())?; + let mode = uri_mode(request.uri())?; let stream = match connector { Some(conn) => match conn {