Add a connector to configure TLS config

This commit is contained in:
Dominik Nakamura 2021-07-23 21:26:44 +09:00
parent bc1b88b820
commit 32450ae5af
No known key found for this signature in database
GPG Key ID: E4C6A749B2491910
8 changed files with 319 additions and 161 deletions

View File

@ -19,8 +19,9 @@ all-features = true
default = []
native-tls = ["native-tls-crate"]
native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
rustls-tls-native-roots = ["rustls", "webpki", "rustls-native-certs"]
rustls-tls-webpki-roots = ["rustls", "webpki", "webpki-roots"]
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
__rustls-tls = ["rustls", "webpki"]
[dependencies]
base64 = "0.13.0"

View File

@ -1,8 +1,8 @@
use std::{net::TcpListener, thread::spawn};
use tungstenite::{
accept_hdr_with_config,
handshake::server::{Request, Response},
protocol::WebSocketConfig,
server::accept_hdr_with_config,
};
fn main() {

View File

@ -14,118 +14,9 @@ use url::Url;
use crate::{
handshake::client::{Request, Response},
protocol::WebSocketConfig,
stream::MaybeTlsStream,
};
#[cfg(feature = "native-tls")]
mod encryption {
pub use native_tls_crate::TlsStream;
use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector};
use std::net::TcpStream;
pub use crate::stream::Stream as StreamSwitcher;
/// TCP stream switcher (plain/TLS).
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>;
use crate::{
error::{Result, TlsError},
stream::Mode,
};
pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
match mode {
Mode::Plain => Ok(StreamSwitcher::Plain(stream)),
Mode::Tls => {
let connector = TlsConnector::builder().build().map_err(TlsError::Native)?;
connector
.connect(domain, stream)
.map_err(|e| match e {
TlsHandshakeError::Failure(f) => TlsError::Native(f).into(),
TlsHandshakeError::WouldBlock(_) => {
panic!("Bug: TLS handshake not blocked")
}
})
.map(StreamSwitcher::Tls)
}
}
}
}
#[cfg(all(
any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"),
not(feature = "native-tls")
))]
mod encryption {
use rustls::ClientConfig;
pub use rustls::{ClientSession, StreamOwned};
use std::{net::TcpStream, sync::Arc};
use webpki::DNSNameRef;
pub use crate::stream::Stream as StreamSwitcher;
/// TCP stream switcher (plain/TLS).
pub type AutoStream = StreamSwitcher<TcpStream, StreamOwned<ClientSession, TcpStream>>;
use crate::{
error::{Result, TlsError},
stream::Mode,
};
pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
match mode {
Mode::Plain => Ok(StreamSwitcher::Plain(stream)),
Mode::Tls => {
let config = {
#[allow(unused_mut)]
let mut config = ClientConfig::new();
#[cfg(feature = "rustls-tls-native-roots")]
{
config.root_store =
rustls_native_certs::load_native_certs().map_err(|(_, err)| err)?;
}
#[cfg(feature = "rustls-tls-webpki-roots")]
{
config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
}
Arc::new(config)
};
let domain = DNSNameRef::try_from_ascii_str(domain).map_err(TlsError::Dns)?;
let client = ClientSession::new(&config, domain);
let stream = StreamOwned::new(client, stream);
Ok(StreamSwitcher::Tls(stream))
}
}
}
}
#[cfg(not(any(
feature = "native-tls",
feature = "rustls-tls-native-roots",
feature = "rustls-tls-webpki-roots"
)))]
mod encryption {
use std::net::TcpStream;
use crate::{
error::{Error, Result, UrlError},
stream::Mode,
};
/// TLS support is not compiled in, this is just standard `TcpStream`.
pub type AutoStream = TcpStream;
pub fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStream> {
match mode {
Mode::Plain => Ok(stream),
Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
}
}
}
use self::encryption::wrap_stream;
pub use self::encryption::AutoStream;
use crate::{
error::{Error, Result, UrlError},
handshake::{client::ClientHandshake, HandshakeError},
@ -152,11 +43,11 @@ pub fn connect_with_config<Req: IntoClientRequest>(
request: Req,
config: Option<WebSocketConfig>,
max_redirects: u8,
) -> Result<(WebSocket<AutoStream>, Response)> {
) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
fn try_client_handshake(
request: Request,
config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> {
) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
@ -165,9 +56,15 @@ pub fn connect_with_config<Req: IntoClientRequest>(
Mode::Tls => 443,
});
let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri())?;
NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config).map_err(|e| match e {
#[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
let client = client_with_config(request, MaybeTlsStream::Plain(stream), config);
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
let client = crate::tls::client_tls_with_config(request, stream, config, None);
client.map_err(|e| match e {
HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
})
@ -216,18 +113,17 @@ pub fn connect_with_config<Req: IntoClientRequest>(
/// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
/// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
/// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> {
pub fn connect<Req: IntoClientRequest>(
request: Req,
) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
connect_with_config(request, None, 3)
}
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {
let domain = uri.host().ok_or(Error::Url(UrlError::NoHostName))?;
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result<TcpStream> {
for addr in addrs {
debug!("Trying to contact {} at {}...", uri, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) {
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
return Ok(stream);
}
if let Ok(stream) = TcpStream::connect(addr) {
return Ok(stream);
}
}
Err(Error::Url(UrlError::UnableToConnect(uri.to_string())))

View File

@ -7,7 +7,7 @@ use http::Response;
use thiserror::Error;
/// Result type of all Tungstenite library calls.
pub type Result<T> = result::Result<T, Error>;
pub type Result<T, E = Error> = result::Result<T, E>;
/// Possible WebSocket errors.
#[derive(Error, Debug)]
@ -253,11 +253,11 @@ pub enum TlsError {
#[error("native-tls error: {0}")]
Native(#[from] native_tls_crate::Error),
/// Rustls error.
#[cfg(any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"))]
#[cfg(feature = "__rustls-tls")]
#[error("rustls error: {0}")]
Rustls(#[from] rustls::TLSError),
/// DNS name resolution error.
#[cfg(any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"))]
#[cfg(feature = "__rustls-tls")]
#[error("Invalid DNS name: {0}")]
Dns(#[from] webpki::InvalidDNSNameError),
}

View File

@ -19,8 +19,10 @@ pub mod client;
pub mod error;
pub mod handshake;
pub mod protocol;
pub mod server;
mod server;
pub mod stream;
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
mod tls;
pub mod util;
const READ_BUFFER_CHUNK_SIZE: usize = 4096;
@ -31,5 +33,8 @@ pub use crate::{
error::{Error, Result},
handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError},
protocol::{Message, WebSocket},
server::{accept, accept_hdr},
server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config},
};
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
pub use tls::{client_tls, client_tls_with_config, Connector};

View File

@ -4,13 +4,16 @@
//! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard
//! `Read + Write` traits.
use std::io::{Read, Result as IoResult, Write};
use std::{
fmt::{self, Debug},
io::{Read, Result as IoResult, Write},
};
use std::net::TcpStream;
#[cfg(feature = "native-tls")]
use native_tls_crate::TlsStream;
#[cfg(any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"))]
#[cfg(feature = "__rustls-tls")]
use rustls::StreamOwned;
/// Stream mode, either plain TCP or TLS.
@ -41,51 +44,95 @@ impl<S: Read + Write + NoDelay> NoDelay for TlsStream<S> {
}
}
#[cfg(any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"))]
#[cfg(feature = "__rustls-tls")]
impl<S: rustls::Session, T: Read + Write + NoDelay> NoDelay for StreamOwned<S, T> {
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
self.sock.set_nodelay(nodelay)
}
}
/// Stream, either plain TCP or TLS.
#[derive(Debug)]
pub enum Stream<S, T> {
/// A stream that might be protected with TLS.
#[non_exhaustive]
pub enum MaybeTlsStream<S: Read + Write> {
/// Unencrypted socket stream.
Plain(S),
/// Encrypted socket stream.
Tls(T),
#[cfg(feature = "native-tls")]
/// Encrypted socket stream using `native-tls`.
NativeTls(native_tls_crate::TlsStream<S>),
#[cfg(feature = "__rustls-tls")]
/// Encrypted socket stream using `rustls`.
Rustls(rustls::StreamOwned<rustls::ClientSession, S>),
}
impl<S: Read, T: Read> Read for Stream<S, T> {
impl<S: Read + Write + Debug> Debug for MaybeTlsStream<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Plain(s) => f.debug_tuple("MaybeTlsStream::Plain").field(s).finish(),
#[cfg(feature = "native-tls")]
Self::NativeTls(s) => f.debug_tuple("MaybeTlsStream::NativeTls").field(s).finish(),
#[cfg(feature = "__rustls-tls")]
Self::Rustls(s) => {
struct RustlsStreamDebug<'a, S: Read + Write>(
&'a rustls::StreamOwned<rustls::ClientSession, S>,
);
impl<'a, S: Read + Write + Debug> Debug for RustlsStreamDebug<'a, S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("StreamOwned")
.field("sess", &self.0.sess)
.field("sock", &self.0.sock)
.finish()
}
}
f.debug_tuple("MaybeTlsStream::Rustls").field(&RustlsStreamDebug(s)).finish()
}
}
}
}
impl<S: Read + Write> Read for MaybeTlsStream<S> {
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
match *self {
Stream::Plain(ref mut s) => s.read(buf),
Stream::Tls(ref mut s) => s.read(buf),
MaybeTlsStream::Plain(ref mut s) => s.read(buf),
#[cfg(feature = "native-tls")]
MaybeTlsStream::NativeTls(ref mut s) => s.read(buf),
#[cfg(feature = "__rustls-tls")]
MaybeTlsStream::Rustls(ref mut s) => s.read(buf),
}
}
}
impl<S: Write, T: Write> Write for Stream<S, T> {
impl<S: Read + Write> Write for MaybeTlsStream<S> {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
match *self {
Stream::Plain(ref mut s) => s.write(buf),
Stream::Tls(ref mut s) => s.write(buf),
MaybeTlsStream::Plain(ref mut s) => s.write(buf),
#[cfg(feature = "native-tls")]
MaybeTlsStream::NativeTls(ref mut s) => s.write(buf),
#[cfg(feature = "__rustls-tls")]
MaybeTlsStream::Rustls(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> IoResult<()> {
match *self {
Stream::Plain(ref mut s) => s.flush(),
Stream::Tls(ref mut s) => s.flush(),
MaybeTlsStream::Plain(ref mut s) => s.flush(),
#[cfg(feature = "native-tls")]
MaybeTlsStream::NativeTls(ref mut s) => s.flush(),
#[cfg(feature = "__rustls-tls")]
MaybeTlsStream::Rustls(ref mut s) => s.flush(),
}
}
}
impl<S: NoDelay, T: NoDelay> NoDelay for Stream<S, T> {
impl<S: Read + Write + NoDelay> NoDelay for MaybeTlsStream<S> {
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
match *self {
Stream::Plain(ref mut s) => s.set_nodelay(nodelay),
Stream::Tls(ref mut s) => s.set_nodelay(nodelay),
MaybeTlsStream::Plain(ref mut s) => s.set_nodelay(nodelay),
#[cfg(feature = "native-tls")]
MaybeTlsStream::NativeTls(ref mut s) => s.set_nodelay(nodelay),
#[cfg(feature = "__rustls-tls")]
MaybeTlsStream::Rustls(ref mut s) => s.set_nodelay(nodelay),
}
}
}

219
src/tls.rs Normal file
View File

@ -0,0 +1,219 @@
//! Connection helper.
use std::io::{Read, Write};
use crate::{
client::{client_with_config, uri_mode, IntoClientRequest},
error::UrlError,
handshake::client::Response,
protocol::WebSocketConfig,
stream::MaybeTlsStream,
ClientHandshake, Error, HandshakeError, Result, WebSocket,
};
/// A connector that can be used when establishing connections, allowing to control whether
/// `native-tls` or `rustls` is used to create a TLS connection. Or TLS can be disabled with the
/// `Plain` variant.
#[non_exhaustive]
#[allow(missing_debug_implementations)]
pub enum Connector {
/// Plain (non-TLS) connector.
Plain,
/// `native-tls` TLS connector.
#[cfg(feature = "native-tls")]
NativeTls(native_tls_crate::TlsConnector),
/// `rustls` TLS connector.
#[cfg(feature = "__rustls-tls")]
Rustls(std::sync::Arc<rustls::ClientConfig>),
}
mod encryption {
#[cfg(feature = "native-tls")]
pub mod native_tls {
use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector};
use std::io::{Read, Write};
use crate::{
error::TlsError,
stream::{MaybeTlsStream, Mode},
Error, Result,
};
pub fn wrap_stream<S>(
socket: S,
domain: &str,
mode: Mode,
tls_connector: Option<TlsConnector>,
) -> Result<MaybeTlsStream<S>>
where
S: Read + Write,
{
match mode {
Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
Mode::Tls => {
let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok);
let connector = try_connector.map_err(TlsError::Native)?;
let connected = connector.connect(domain, socket);
match connected {
Err(e) => match e {
TlsHandshakeError::Failure(f) => Err(Error::Tls(f.into())),
TlsHandshakeError::WouldBlock(_) => {
panic!("Bug: TLS handshake not blocked")
}
},
Ok(s) => Ok(MaybeTlsStream::NativeTls(s)),
}
}
}
}
}
#[cfg(feature = "__rustls-tls")]
pub mod rustls {
use rustls::{ClientConfig, ClientSession, StreamOwned};
use webpki::DNSNameRef;
use std::{
io::{Read, Write},
sync::Arc,
};
use crate::{
error::TlsError,
stream::{MaybeTlsStream, Mode},
Result,
};
pub fn wrap_stream<S>(
socket: S,
domain: &str,
mode: Mode,
tls_connector: Option<Arc<ClientConfig>>,
) -> Result<MaybeTlsStream<S>>
where
S: Read + Write,
{
match mode {
Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
Mode::Tls => {
let config = match tls_connector {
Some(config) => config,
None => {
#[allow(unused_mut)]
let mut config = ClientConfig::new();
#[cfg(feature = "rustls-tls-native-roots")]
{
config.root_store = rustls_native_certs::load_native_certs()
.map_err(|(_, err)| err)?;
}
#[cfg(feature = "rustls-tls-webpki-roots")]
{
config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
}
Arc::new(config)
}
};
let domain = DNSNameRef::try_from_ascii_str(domain).map_err(TlsError::Dns)?;
let client = ClientSession::new(&config, domain);
let stream = StreamOwned::new(client, socket);
Ok(MaybeTlsStream::Rustls(stream))
}
}
}
}
pub mod plain {
use std::io::{Read, Write};
use crate::{
error::UrlError,
stream::{MaybeTlsStream, Mode},
Error, Result,
};
pub fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>>
where
S: Read + Write,
{
match mode {
Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
}
}
}
}
type TlsHandshakeError<S> = HandshakeError<ClientHandshake<MaybeTlsStream<S>>>;
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required.
pub fn client_tls<R, S>(
request: R,
stream: S,
) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
where
R: IntoClientRequest,
S: Read + Write,
{
client_tls_with_config(request, stream, None, None)
}
/// The same as [`client_tls()`] but one can specify a websocket configuration,
/// and an optional connector. If no connector is specified, a default one will
/// be created.
///
/// Please refer to [`client_tls()`] for more details.
pub fn client_tls_with_config<R, S>(
request: R,
stream: S,
config: Option<WebSocketConfig>,
connector: Option<Connector>,
) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
where
R: IntoClientRequest,
S: Read + Write,
{
let request = request.into_client_request()?;
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
let domain = match request.uri().host() {
Some(d) => Ok(d.to_string()),
None => Err(Error::Url(UrlError::NoHostName)),
}?;
let mode = uri_mode(&request.uri())?;
let stream = match connector {
Some(conn) => match conn {
#[cfg(feature = "native-tls")]
Connector::NativeTls(conn) => {
self::encryption::native_tls::wrap_stream(stream, &domain, mode, Some(conn))
}
#[cfg(feature = "__rustls-tls")]
Connector::Rustls(conn) => {
self::encryption::rustls::wrap_stream(stream, &domain, mode, Some(conn))
}
Connector::Plain => self::encryption::plain::wrap_stream(stream, mode),
},
None => {
#[cfg(feature = "native-tls")]
{
self::encryption::native_tls::wrap_stream(stream, &domain, mode, None)
}
#[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
{
self::encryption::rustls::wrap_stream(stream, &domain, mode, None)
}
#[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
{
self::encryption::plain::wrap_stream(stream, mode)
}
}
}?;
client_with_config(request, stream, config)
}

View File

@ -1,10 +1,6 @@
//! Verifies that the server returns a `ConnectionClosed` error when the connection
//! is closed from the server's point of view and drop the underlying tcp socket.
#![cfg(any(
feature = "native-tls",
feature = "rustls-tls-native-roots",
feature = "rustls-tls-webpki-roots"
))]
#![cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
use std::{
net::{TcpListener, TcpStream},
@ -14,16 +10,10 @@ use std::{
};
use net2::TcpStreamExt;
use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket};
use tungstenite::{accept, connect, stream::MaybeTlsStream, Error, Message, WebSocket};
use url::Url;
#[cfg(feature = "native-tls")]
type Sock = WebSocket<Stream<TcpStream, native_tls_crate::TlsStream<TcpStream>>>;
#[cfg(all(
any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"),
not(feature = "native-tls")
))]
type Sock = WebSocket<Stream<TcpStream, rustls::StreamOwned<rustls::ClientSession, TcpStream>>>;
type Sock = WebSocket<MaybeTlsStream<TcpStream>>;
fn do_test<CT, ST>(port: u16, client_task: CT, server_task: ST)
where