Upgrade to rustls 0.20

This commit is contained in:
Dominik Nakamura 2021-10-18 11:27:03 +09:00
parent 89697449ff
commit f4bb653aa0
No known key found for this signature in database
GPG Key ID: E4C6A749B2491910
6 changed files with 56 additions and 27 deletions

View File

@ -44,19 +44,21 @@ version = "0.2.3"
[dependencies.rustls] [dependencies.rustls]
optional = true optional = true
version = "0.19.0" version = "0.20.0"
[dependencies.rustls-native-certs] [dependencies.rustls-native-certs]
optional = true optional = true
version = "0.5.0" version = "0.6.0"
git = "https://github.com/rustls/rustls-native-certs.git"
rev = "87b84b51bcf38eb9d377e0f5606c444ced43cc60"
[dependencies.webpki] [dependencies.webpki]
optional = true optional = true
version = "0.21" version = "0.22"
[dependencies.webpki-roots] [dependencies.webpki-roots]
optional = true optional = true
version = "0.21" version = "0.22"
[dev-dependencies] [dev-dependencies]
criterion = "0.3.4" criterion = "0.3.4"

View File

@ -56,7 +56,7 @@ pub fn connect_with_config<Req: IntoClientRequest>(
Mode::Tls => 443, Mode::Tls => 443,
}); });
let addrs = (host, port).to_socket_addrs()?; let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri())?; let mut stream = connect_to_some(addrs.as_slice(), request.uri())?;
NoDelay::set_nodelay(&mut stream, true)?; NoDelay::set_nodelay(&mut stream, true)?;
#[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))] #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]

View File

@ -255,9 +255,13 @@ pub enum TlsError {
/// Rustls error. /// Rustls error.
#[cfg(feature = "__rustls-tls")] #[cfg(feature = "__rustls-tls")]
#[error("rustls error: {0}")] #[error("rustls error: {0}")]
Rustls(#[from] rustls::TLSError), Rustls(#[from] rustls::Error),
/// Webpki error.
#[cfg(feature = "__rustls-tls")]
#[error("webpki error: {0}")]
Webpki(#[from] webpki::Error),
/// DNS name resolution error. /// DNS name resolution error.
#[cfg(feature = "__rustls-tls")] #[cfg(feature = "__rustls-tls")]
#[error("Invalid DNS name: {0}")] #[error("Invalid DNS name")]
Dns(#[from] webpki::InvalidDNSNameError), InvalidDnsName,
} }

View File

@ -82,7 +82,7 @@ fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
/// Create a response for the request. /// Create a response for the request.
pub fn create_response(request: &Request) -> Result<Response> { pub fn create_response(request: &Request) -> Result<Response> {
Ok(create_parts(&request)?.body(())?) Ok(create_parts(request)?.body(())?)
} }
/// Create a response for the request with a custom body. /// Create a response for the request with a custom body.
@ -90,7 +90,7 @@ pub fn create_response_with_body<T>(
request: &HttpRequest<T>, request: &HttpRequest<T>,
generate_body: impl FnOnce() -> T, generate_body: impl FnOnce() -> T,
) -> Result<HttpResponse<T>> { ) -> Result<HttpResponse<T>> {
Ok(create_parts(&request)?.body(generate_body())?) Ok(create_parts(request)?.body(generate_body())?)
} }
// Assumes that this is a valid response // Assumes that this is a valid response
@ -263,7 +263,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
let resp = self.error_response.as_ref().unwrap(); let resp = self.error_response.as_ref().unwrap();
let mut output = vec![]; let mut output = vec![];
write_response(&mut output, &resp)?; write_response(&mut output, resp)?;
if let Some(body) = resp.body() { if let Some(body) = resp.body() {
output.extend_from_slice(body.as_bytes()); output.extend_from_slice(body.as_bytes());

View File

@ -4,6 +4,8 @@
//! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard //! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard
//! `Read + Write` traits. //! `Read + Write` traits.
#[cfg(feature = "__rustls-tls")]
use std::ops::Deref;
use std::{ use std::{
fmt::{self, Debug}, fmt::{self, Debug},
io::{Read, Result as IoResult, Write}, io::{Read, Result as IoResult, Write},
@ -45,7 +47,12 @@ impl<S: Read + Write + NoDelay> NoDelay for TlsStream<S> {
} }
#[cfg(feature = "__rustls-tls")] #[cfg(feature = "__rustls-tls")]
impl<S: rustls::Session, T: Read + Write + NoDelay> NoDelay for StreamOwned<S, T> { impl<S, SD, T> NoDelay for StreamOwned<S, T>
where
S: Deref<Target = rustls::ConnectionCommon<SD>>,
SD: rustls::SideData,
T: Read + Write + NoDelay,
{
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
self.sock.set_nodelay(nodelay) self.sock.set_nodelay(nodelay)
} }
@ -61,7 +68,7 @@ pub enum MaybeTlsStream<S: Read + Write> {
NativeTls(native_tls_crate::TlsStream<S>), NativeTls(native_tls_crate::TlsStream<S>),
#[cfg(feature = "__rustls-tls")] #[cfg(feature = "__rustls-tls")]
/// Encrypted socket stream using `rustls`. /// Encrypted socket stream using `rustls`.
Rustls(rustls::StreamOwned<rustls::ClientSession, S>), Rustls(rustls::StreamOwned<rustls::ClientConnection, S>),
} }
impl<S: Read + Write + Debug> Debug for MaybeTlsStream<S> { impl<S: Read + Write + Debug> Debug for MaybeTlsStream<S> {
@ -73,13 +80,13 @@ impl<S: Read + Write + Debug> Debug for MaybeTlsStream<S> {
#[cfg(feature = "__rustls-tls")] #[cfg(feature = "__rustls-tls")]
Self::Rustls(s) => { Self::Rustls(s) => {
struct RustlsStreamDebug<'a, S: Read + Write>( struct RustlsStreamDebug<'a, S: Read + Write>(
&'a rustls::StreamOwned<rustls::ClientSession, S>, &'a rustls::StreamOwned<rustls::ClientConnection, S>,
); );
impl<'a, S: Read + Write + Debug> Debug for RustlsStreamDebug<'a, S> { impl<'a, S: Read + Write + Debug> Debug for RustlsStreamDebug<'a, S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("StreamOwned") f.debug_struct("StreamOwned")
.field("sess", &self.0.sess) .field("conn", &self.0.conn)
.field("sock", &self.0.sock) .field("sock", &self.0.sock)
.finish() .finish()
} }

View File

@ -70,10 +70,10 @@ mod encryption {
#[cfg(feature = "__rustls-tls")] #[cfg(feature = "__rustls-tls")]
pub mod rustls { pub mod rustls {
use rustls::{ClientConfig, ClientSession, StreamOwned}; use rustls::{ClientConfig, ClientConnection, RootCertStore, ServerName, StreamOwned};
use webpki::DNSNameRef;
use std::{ use std::{
convert::TryFrom,
io::{Read, Write}, io::{Read, Write},
sync::Arc, sync::Arc,
}; };
@ -100,24 +100,40 @@ mod encryption {
Some(config) => config, Some(config) => config,
None => { None => {
#[allow(unused_mut)] #[allow(unused_mut)]
let mut config = ClientConfig::new(); let mut root_store = RootCertStore::empty();
#[cfg(feature = "rustls-tls-native-roots")] #[cfg(feature = "rustls-tls-native-roots")]
{ {
config.root_store = rustls_native_certs::load_native_certs() for cert in rustls_native_certs::load_native_certs()? {
.map_err(|(_, err)| err)?; root_store
.add(&rustls::Certificate(cert.0))
.map_err(TlsError::Webpki)?;
}
} }
#[cfg(feature = "rustls-tls-webpki-roots")] #[cfg(feature = "rustls-tls-webpki-roots")]
{ {
config root_store.add_server_trust_anchors(
.root_store webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
})
);
} }
Arc::new(config) Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth(),
)
} }
}; };
let domain = DNSNameRef::try_from_ascii_str(domain).map_err(TlsError::Dns)?; let domain =
let client = ClientSession::new(&config, domain); ServerName::try_from(domain).map_err(|_| TlsError::InvalidDnsName)?;
let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?;
let stream = StreamOwned::new(client, socket); let stream = StreamOwned::new(client, socket);
Ok(MaybeTlsStream::Rustls(stream)) Ok(MaybeTlsStream::Rustls(stream))
@ -185,7 +201,7 @@ where
None => Err(Error::Url(UrlError::NoHostName)), None => Err(Error::Url(UrlError::NoHostName)),
}?; }?;
let mode = uri_mode(&request.uri())?; let mode = uri_mode(request.uri())?;
let stream = match connector { let stream = match connector {
Some(conn) => match conn { Some(conn) => match conn {