From c101024c28c1d6d8fa5a93a5a55cad195950602e Mon Sep 17 00:00:00 2001 From: Dominik Nakamura Date: Mon, 8 Feb 2021 21:58:42 +0900 Subject: [PATCH] Add support for rustls as TLS backend (#166) * Add support for rustls as TLS backend * Use a "use-*" prefix for the TLS features * Only enable rustls if native-tls is not enabled * Allows several TLS components to coexist * Update docs for rustls mentions * Enable all features on docs.rs * Rename TLS feature flags from "use-*" to "*-tls" * Make native-tls the default * Move TLS related errors to a separate enum * Add changelog entry about rustls support * Fix wrong naming in main error enum * Simplify docs about tls feature flag usage --- CHANGELOG.md | 3 ++ Cargo.toml | 27 ++++++++++++--- README.md | 3 +- src/client.rs | 73 ++++++++++++++++++++++++++++++--------- src/error.rs | 34 +++++++++++++----- src/server.rs | 12 +++---- src/stream.rs | 15 ++++++-- tests/connection_reset.rs | 9 +++-- 8 files changed, 133 insertions(+), 43 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 084b1e9..122e45e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,9 @@ - Add `CapacityError`, `UrlError`, and `ProtocolError` types to represent the different types of capacity, URL, and protocol errors respectively. - Modify variants `Error::Capacity`, `Error::Url`, and `Error::Protocol` to hold the above errors types instead of string error messages. - Add `handshake::derive_accept_key` to facilitate external handshakes. +- Add support for `rustls` as TLS backend. The previous `tls` feature flag is now removed in favor + of `native-tls` and `rustls-tls`, which allows to pick the TLS backend. The error API surface had + to be changed to support the new error types coming from rustls related crates. # 0.12.0 diff --git a/Cargo.toml b/Cargo.toml index cbab804..705c5d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,10 +12,14 @@ repository = "https://github.com/snapview/tungstenite-rs" version = "0.13.0" edition = "2018" +[package.metadata.docs.rs] +all-features = true + [features] -default = ["tls"] -tls = ["native-tls"] -tls-vendored = ["native-tls", "native-tls/vendored"] +default = ["native-tls"] +native-tls = ["native-tls-crate"] +native-tls-vendored = ["native-tls", "native-tls-crate/vendored"] +rustls-tls = ["rustls", "webpki", "webpki-roots"] [dependencies] base64 = "0.13.0" @@ -27,14 +31,27 @@ input_buffer = "0.4.0" log = "0.4.8" rand = "0.8.0" sha-1 = "0.9" +thiserror = "1.0.23" url = "2.1.0" utf-8 = "0.7.5" -thiserror = "1.0.23" -[dependencies.native-tls] +[dependencies.native-tls-crate] optional = true +package = "native-tls" version = "0.2.3" +[dependencies.rustls] +optional = true +version = "0.19.0" + +[dependencies.webpki] +optional = true +version = "0.21.4" + +[dependencies.webpki-roots] +optional = true +version = "0.21.0" + [dev-dependencies] env_logger = "0.8.1" net2 = "0.2.33" diff --git a/README.md b/README.md index 7173582..430f3b7 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,8 @@ Features -------- Tungstenite provides a complete implementation of the WebSocket specification. -TLS is supported on all platforms using native-tls. +TLS is supported on all platforms using native-tls or rustls available through the `native-tls` +and `rustls-tls` feature flags. There is no support for permessage-deflate at the moment. It's planned. diff --git a/src/client.rs b/src/client.rs index 5ed89cf..f351cf2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -16,27 +16,30 @@ use crate::{ protocol::WebSocketConfig, }; -#[cfg(feature = "tls")] +#[cfg(feature = "native-tls")] mod encryption { - pub use native_tls::TlsStream; - use native_tls::{HandshakeError as TlsHandshakeError, TlsConnector}; + pub use native_tls_crate::TlsStream; + use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector}; use std::net::TcpStream; pub use crate::stream::Stream as StreamSwitcher; /// TCP stream switcher (plain/TLS). pub type AutoStream = StreamSwitcher>; - use crate::{error::Result, stream::Mode}; + use crate::{ + error::{Result, TlsError}, + stream::Mode, + }; pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result { match mode { Mode::Plain => Ok(StreamSwitcher::Plain(stream)), Mode::Tls => { - let connector = TlsConnector::builder().build()?; + let connector = TlsConnector::builder().build().map_err(TlsError::Native)?; connector .connect(domain, stream) .map_err(|e| match e { - TlsHandshakeError::Failure(f) => f.into(), + TlsHandshakeError::Failure(f) => TlsError::Native(f).into(), TlsHandshakeError::WouldBlock(_) => { panic!("Bug: TLS handshake not blocked") } @@ -47,7 +50,43 @@ mod encryption { } } -#[cfg(not(feature = "tls"))] +#[cfg(all(feature = "rustls-tls", not(feature = "native-tls")))] +mod encryption { + use rustls::ClientConfig; + pub use rustls::{ClientSession, StreamOwned}; + use std::{net::TcpStream, sync::Arc}; + use webpki::DNSNameRef; + + pub use crate::stream::Stream as StreamSwitcher; + /// TCP stream switcher (plain/TLS). + pub type AutoStream = StreamSwitcher>; + + use crate::{ + error::{Result, TlsError}, + stream::Mode, + }; + + pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result { + match mode { + Mode::Plain => Ok(StreamSwitcher::Plain(stream)), + Mode::Tls => { + let config = { + let mut config = ClientConfig::new(); + config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + + Arc::new(config) + }; + let domain = DNSNameRef::try_from_ascii_str(domain).map_err(TlsError::Dns)?; + let client = ClientSession::new(&config, domain); + let stream = StreamOwned::new(client, stream); + + Ok(StreamSwitcher::Tls(stream)) + } + } + } +} + +#[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))] mod encryption { use std::net::TcpStream; @@ -56,7 +95,7 @@ mod encryption { stream::Mode, }; - /// TLS support is nod compiled in, this is just standard `TcpStream`. + /// TLS support is not compiled in, this is just standard `TcpStream`. pub type AutoStream = TcpStream; pub fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result { @@ -83,15 +122,15 @@ use crate::{ /// equal to calling `connect()` function. /// /// The URL may be either ws:// or wss://. -/// To support wss:// URLs, feature "tls" must be turned on. +/// To support wss:// URLs, feature `native-tls` or `rustls-tls` must be turned on. /// /// This function "just works" for those who wants a simple blocking solution /// similar to `std::net::TcpStream`. If you want a non-blocking or other /// custom stream, call `client` instead. /// -/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, -/// use `client` instead. There is no need to enable the "tls" feature if you don't call -/// `connect` since it's the only function that uses native_tls. +/// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If +/// you want to use other TLS libraries, use `client` instead. There is no need to enable any of +/// the `*-tls` features if you don't call `connect` since it's the only function that uses them. pub fn connect_with_config( request: Req, config: Option, @@ -151,15 +190,15 @@ pub fn connect_with_config( /// Connect to the given WebSocket in blocking mode. /// /// The URL may be either ws:// or wss://. -/// To support wss:// URLs, feature "tls" must be turned on. +/// To support wss:// URLs, feature `native-tls` or `rustls-tls` must be turned on. /// /// This function "just works" for those who wants a simple blocking solution /// similar to `std::net::TcpStream`. If you want a non-blocking or other /// custom stream, call `client` instead. /// -/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, -/// use `client` instead. There is no need to enable the "tls" feature if you don't call -/// `connect` since it's the only function that uses native_tls. +/// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If +/// you want to use other TLS libraries, use `client` instead. There is no need to enable any of +/// the `*-tls` features if you don't call `connect` since it's the only function that uses them. pub fn connect(request: Req) -> Result<(WebSocket, Response)> { connect_with_config(request, None, 3) } @@ -180,7 +219,7 @@ fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result Result { match uri.scheme_str() { Some("ws") => Ok(Mode::Plain), diff --git a/src/error.rs b/src/error.rs index f4dfdf1..f8f4081 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,12 +6,6 @@ use crate::protocol::{frame::coding::Data, Message}; use http::Response; use thiserror::Error; -#[cfg(feature = "tls")] -pub mod tls { - //! TLS error wrapper module, feature-gated. - pub use native_tls::Error; -} - /// Result type of all Tungstenite library calls. pub type Result = result::Result; @@ -45,9 +39,11 @@ pub enum Error { #[error("IO error: {0}")] Io(#[from] io::Error), /// TLS error. - #[cfg(feature = "tls")] + /// + /// Note that this error variant is enabled unconditionally even if no TLS feature is enabled, + /// to provide a feature-agnostic API surface. #[error("TLS error: {0}")] - Tls(#[from] tls::Error), + Tls(#[from] TlsError), /// - When reading: buffer capacity exhausted. /// - When writing: your message is bigger than the configured max message size /// (64MB by default). @@ -248,3 +244,25 @@ pub enum UrlError { #[error("No path/query in URL")] NoPathOrQuery, } + +/// TLS errors. +/// +/// Note that even if you enable only the rustls-based TLS support, the error at runtime could still +/// be `Native`, as another crate in the dependency graph may enable native TLS support. +#[allow(missing_copy_implementations)] +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum TlsError { + /// Native TLS error. + #[cfg(feature = "native-tls")] + #[error("native-tls error: {0}")] + Native(#[from] native_tls_crate::Error), + /// Rustls error. + #[cfg(feature = "rustls-tls")] + #[error("rustls error: {0}")] + Rustls(#[from] rustls::TLSError), + /// DNS name resolution error. + #[cfg(feature = "rustls-tls")] + #[error("Invalid DNS name: {0}")] + Dns(#[from] webpki::InvalidDNSNameError), +} diff --git a/src/server.rs b/src/server.rs index 53303ee..e79bccb 100644 --- a/src/server.rs +++ b/src/server.rs @@ -17,9 +17,9 @@ use std::io::{Read, Write}; /// used by `accept()`. /// /// This function starts a server WebSocket handshake over the given stream. -/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` -/// for the stream here. Any `Read + Write` streams are supported, including -/// those from `Mio` and others. +/// If you want TLS support, use `native_tls::TlsStream`, `rustls::Stream` or +/// `openssl::ssl::SslStream` for the stream here. Any `Read + Write` streams are supported, +/// including those from `Mio` and others. pub fn accept_with_config( stream: S, config: Option, @@ -30,9 +30,9 @@ pub fn accept_with_config( /// Accept the given Stream as a WebSocket. /// /// This function starts a server WebSocket handshake over the given stream. -/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream` -/// for the stream here. Any `Read + Write` streams are supported, including -/// those from `Mio` and others. +/// If you want TLS support, use `native_tls::TlsStream`, `rustls::Stream` or +/// `openssl::ssl::SslStream` for the stream here. Any `Read + Write` streams are supported, +/// including those from `Mio` and others. pub fn accept( stream: S, ) -> Result, HandshakeError>> { diff --git a/src/stream.rs b/src/stream.rs index 96d26d2..4d60405 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -8,8 +8,10 @@ use std::io::{Read, Result as IoResult, Write}; use std::net::TcpStream; -#[cfg(feature = "tls")] -use native_tls::TlsStream; +#[cfg(feature = "native-tls")] +use native_tls_crate::TlsStream; +#[cfg(feature = "rustls-tls")] +use rustls::StreamOwned; /// Stream mode, either plain TCP or TLS. #[derive(Clone, Copy, Debug)] @@ -32,13 +34,20 @@ impl NoDelay for TcpStream { } } -#[cfg(feature = "tls")] +#[cfg(feature = "native-tls")] impl NoDelay for TlsStream { fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { self.get_mut().set_nodelay(nodelay) } } +#[cfg(feature = "rustls-tls")] +impl NoDelay for StreamOwned { + fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { + self.sock.set_nodelay(nodelay) + } +} + /// Stream, either plain TCP or TLS. #[derive(Debug)] pub enum Stream { diff --git a/tests/connection_reset.rs b/tests/connection_reset.rs index 7e3e33f..7f625be 100644 --- a/tests/connection_reset.rs +++ b/tests/connection_reset.rs @@ -1,5 +1,6 @@ //! Verifies that the server returns a `ConnectionClosed` error when the connection -//! is closedd from the server's point of view and drop the underlying tcp socket. +//! is closed from the server's point of view and drop the underlying tcp socket. +#![cfg(any(feature = "native-tls", feature = "rustls-tls"))] use std::{ net::{TcpListener, TcpStream}, @@ -8,12 +9,14 @@ use std::{ time::Duration, }; -use native_tls::TlsStream; use net2::TcpStreamExt; use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket}; use url::Url; -type Sock = WebSocket>>; +#[cfg(feature = "native-tls")] +type Sock = WebSocket>>; +#[cfg(all(feature = "rustls-tls", not(feature = "native-tls")))] +type Sock = WebSocket>>; fn do_test(port: u16, client_task: CT, server_task: ST) where