Merge pull request #226 from dnaka91/connector
Add a connector to configure TLS config
This commit is contained in:
commit
239f8e293f
|
@ -19,8 +19,9 @@ all-features = true
|
|||
default = []
|
||||
native-tls = ["native-tls-crate"]
|
||||
native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
|
||||
rustls-tls-native-roots = ["rustls", "webpki", "rustls-native-certs"]
|
||||
rustls-tls-webpki-roots = ["rustls", "webpki", "webpki-roots"]
|
||||
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
|
||||
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
|
||||
__rustls-tls = ["rustls", "webpki"]
|
||||
|
||||
[dependencies]
|
||||
base64 = "0.13.0"
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use std::{net::TcpListener, thread::spawn};
|
||||
use tungstenite::{
|
||||
accept_hdr_with_config,
|
||||
handshake::server::{Request, Response},
|
||||
protocol::WebSocketConfig,
|
||||
server::accept_hdr_with_config,
|
||||
};
|
||||
|
||||
fn main() {
|
||||
|
|
138
src/client.rs
138
src/client.rs
|
@ -14,118 +14,9 @@ use url::Url;
|
|||
use crate::{
|
||||
handshake::client::{Request, Response},
|
||||
protocol::WebSocketConfig,
|
||||
stream::MaybeTlsStream,
|
||||
};
|
||||
|
||||
#[cfg(feature = "native-tls")]
|
||||
mod encryption {
|
||||
pub use native_tls_crate::TlsStream;
|
||||
use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector};
|
||||
use std::net::TcpStream;
|
||||
|
||||
pub use crate::stream::Stream as StreamSwitcher;
|
||||
/// TCP stream switcher (plain/TLS).
|
||||
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>;
|
||||
|
||||
use crate::{
|
||||
error::{Result, TlsError},
|
||||
stream::Mode,
|
||||
};
|
||||
|
||||
pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
|
||||
match mode {
|
||||
Mode::Plain => Ok(StreamSwitcher::Plain(stream)),
|
||||
Mode::Tls => {
|
||||
let connector = TlsConnector::builder().build().map_err(TlsError::Native)?;
|
||||
connector
|
||||
.connect(domain, stream)
|
||||
.map_err(|e| match e {
|
||||
TlsHandshakeError::Failure(f) => TlsError::Native(f).into(),
|
||||
TlsHandshakeError::WouldBlock(_) => {
|
||||
panic!("Bug: TLS handshake not blocked")
|
||||
}
|
||||
})
|
||||
.map(StreamSwitcher::Tls)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(
|
||||
any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"),
|
||||
not(feature = "native-tls")
|
||||
))]
|
||||
mod encryption {
|
||||
use rustls::ClientConfig;
|
||||
pub use rustls::{ClientSession, StreamOwned};
|
||||
use std::{net::TcpStream, sync::Arc};
|
||||
use webpki::DNSNameRef;
|
||||
|
||||
pub use crate::stream::Stream as StreamSwitcher;
|
||||
/// TCP stream switcher (plain/TLS).
|
||||
pub type AutoStream = StreamSwitcher<TcpStream, StreamOwned<ClientSession, TcpStream>>;
|
||||
|
||||
use crate::{
|
||||
error::{Result, TlsError},
|
||||
stream::Mode,
|
||||
};
|
||||
|
||||
pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
|
||||
match mode {
|
||||
Mode::Plain => Ok(StreamSwitcher::Plain(stream)),
|
||||
Mode::Tls => {
|
||||
let config = {
|
||||
#[allow(unused_mut)]
|
||||
let mut config = ClientConfig::new();
|
||||
#[cfg(feature = "rustls-tls-native-roots")]
|
||||
{
|
||||
config.root_store =
|
||||
rustls_native_certs::load_native_certs().map_err(|(_, err)| err)?;
|
||||
}
|
||||
#[cfg(feature = "rustls-tls-webpki-roots")]
|
||||
{
|
||||
config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
|
||||
}
|
||||
|
||||
Arc::new(config)
|
||||
};
|
||||
|
||||
let domain = DNSNameRef::try_from_ascii_str(domain).map_err(TlsError::Dns)?;
|
||||
let client = ClientSession::new(&config, domain);
|
||||
let stream = StreamOwned::new(client, stream);
|
||||
|
||||
Ok(StreamSwitcher::Tls(stream))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(any(
|
||||
feature = "native-tls",
|
||||
feature = "rustls-tls-native-roots",
|
||||
feature = "rustls-tls-webpki-roots"
|
||||
)))]
|
||||
mod encryption {
|
||||
use std::net::TcpStream;
|
||||
|
||||
use crate::{
|
||||
error::{Error, Result, UrlError},
|
||||
stream::Mode,
|
||||
};
|
||||
|
||||
/// TLS support is not compiled in, this is just standard `TcpStream`.
|
||||
pub type AutoStream = TcpStream;
|
||||
|
||||
pub fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStream> {
|
||||
match mode {
|
||||
Mode::Plain => Ok(stream),
|
||||
Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use self::encryption::wrap_stream;
|
||||
pub use self::encryption::AutoStream;
|
||||
|
||||
use crate::{
|
||||
error::{Error, Result, UrlError},
|
||||
handshake::{client::ClientHandshake, HandshakeError},
|
||||
|
@ -152,11 +43,11 @@ pub fn connect_with_config<Req: IntoClientRequest>(
|
|||
request: Req,
|
||||
config: Option<WebSocketConfig>,
|
||||
max_redirects: u8,
|
||||
) -> Result<(WebSocket<AutoStream>, Response)> {
|
||||
) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
|
||||
fn try_client_handshake(
|
||||
request: Request,
|
||||
config: Option<WebSocketConfig>,
|
||||
) -> Result<(WebSocket<AutoStream>, Response)> {
|
||||
) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
|
||||
let uri = request.uri();
|
||||
let mode = uri_mode(uri)?;
|
||||
let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
|
||||
|
@ -165,9 +56,15 @@ pub fn connect_with_config<Req: IntoClientRequest>(
|
|||
Mode::Tls => 443,
|
||||
});
|
||||
let addrs = (host, port).to_socket_addrs()?;
|
||||
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
|
||||
let mut stream = connect_to_some(addrs.as_slice(), &request.uri())?;
|
||||
NoDelay::set_nodelay(&mut stream, true)?;
|
||||
client_with_config(request, stream, config).map_err(|e| match e {
|
||||
|
||||
#[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
|
||||
let client = client_with_config(request, MaybeTlsStream::Plain(stream), config);
|
||||
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
|
||||
let client = crate::tls::client_tls_with_config(request, stream, config, None);
|
||||
|
||||
client.map_err(|e| match e {
|
||||
HandshakeError::Failure(f) => f,
|
||||
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
|
||||
})
|
||||
|
@ -216,18 +113,17 @@ pub fn connect_with_config<Req: IntoClientRequest>(
|
|||
/// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
|
||||
/// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
|
||||
/// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
|
||||
pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> {
|
||||
pub fn connect<Req: IntoClientRequest>(
|
||||
request: Req,
|
||||
) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
|
||||
connect_with_config(request, None, 3)
|
||||
}
|
||||
|
||||
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {
|
||||
let domain = uri.host().ok_or(Error::Url(UrlError::NoHostName))?;
|
||||
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result<TcpStream> {
|
||||
for addr in addrs {
|
||||
debug!("Trying to contact {} at {}...", uri, addr);
|
||||
if let Ok(raw_stream) = TcpStream::connect(addr) {
|
||||
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
|
||||
return Ok(stream);
|
||||
}
|
||||
if let Ok(stream) = TcpStream::connect(addr) {
|
||||
return Ok(stream);
|
||||
}
|
||||
}
|
||||
Err(Error::Url(UrlError::UnableToConnect(uri.to_string())))
|
||||
|
|
|
@ -7,7 +7,7 @@ use http::Response;
|
|||
use thiserror::Error;
|
||||
|
||||
/// Result type of all Tungstenite library calls.
|
||||
pub type Result<T> = result::Result<T, Error>;
|
||||
pub type Result<T, E = Error> = result::Result<T, E>;
|
||||
|
||||
/// Possible WebSocket errors.
|
||||
#[derive(Error, Debug)]
|
||||
|
@ -253,11 +253,11 @@ pub enum TlsError {
|
|||
#[error("native-tls error: {0}")]
|
||||
Native(#[from] native_tls_crate::Error),
|
||||
/// Rustls error.
|
||||
#[cfg(any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"))]
|
||||
#[cfg(feature = "__rustls-tls")]
|
||||
#[error("rustls error: {0}")]
|
||||
Rustls(#[from] rustls::TLSError),
|
||||
/// DNS name resolution error.
|
||||
#[cfg(any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"))]
|
||||
#[cfg(feature = "__rustls-tls")]
|
||||
#[error("Invalid DNS name: {0}")]
|
||||
Dns(#[from] webpki::InvalidDNSNameError),
|
||||
}
|
||||
|
|
|
@ -19,8 +19,10 @@ pub mod client;
|
|||
pub mod error;
|
||||
pub mod handshake;
|
||||
pub mod protocol;
|
||||
pub mod server;
|
||||
mod server;
|
||||
pub mod stream;
|
||||
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
|
||||
mod tls;
|
||||
pub mod util;
|
||||
|
||||
const READ_BUFFER_CHUNK_SIZE: usize = 4096;
|
||||
|
@ -31,5 +33,8 @@ pub use crate::{
|
|||
error::{Error, Result},
|
||||
handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError},
|
||||
protocol::{Message, WebSocket},
|
||||
server::{accept, accept_hdr},
|
||||
server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config},
|
||||
};
|
||||
|
||||
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
|
||||
pub use tls::{client_tls, client_tls_with_config, Connector};
|
||||
|
|
|
@ -4,13 +4,16 @@
|
|||
//! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard
|
||||
//! `Read + Write` traits.
|
||||
|
||||
use std::io::{Read, Result as IoResult, Write};
|
||||
use std::{
|
||||
fmt::{self, Debug},
|
||||
io::{Read, Result as IoResult, Write},
|
||||
};
|
||||
|
||||
use std::net::TcpStream;
|
||||
|
||||
#[cfg(feature = "native-tls")]
|
||||
use native_tls_crate::TlsStream;
|
||||
#[cfg(any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"))]
|
||||
#[cfg(feature = "__rustls-tls")]
|
||||
use rustls::StreamOwned;
|
||||
|
||||
/// Stream mode, either plain TCP or TLS.
|
||||
|
@ -41,51 +44,95 @@ impl<S: Read + Write + NoDelay> NoDelay for TlsStream<S> {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"))]
|
||||
#[cfg(feature = "__rustls-tls")]
|
||||
impl<S: rustls::Session, T: Read + Write + NoDelay> NoDelay for StreamOwned<S, T> {
|
||||
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
|
||||
self.sock.set_nodelay(nodelay)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream, either plain TCP or TLS.
|
||||
#[derive(Debug)]
|
||||
pub enum Stream<S, T> {
|
||||
/// A stream that might be protected with TLS.
|
||||
#[non_exhaustive]
|
||||
pub enum MaybeTlsStream<S: Read + Write> {
|
||||
/// Unencrypted socket stream.
|
||||
Plain(S),
|
||||
/// Encrypted socket stream.
|
||||
Tls(T),
|
||||
#[cfg(feature = "native-tls")]
|
||||
/// Encrypted socket stream using `native-tls`.
|
||||
NativeTls(native_tls_crate::TlsStream<S>),
|
||||
#[cfg(feature = "__rustls-tls")]
|
||||
/// Encrypted socket stream using `rustls`.
|
||||
Rustls(rustls::StreamOwned<rustls::ClientSession, S>),
|
||||
}
|
||||
|
||||
impl<S: Read, T: Read> Read for Stream<S, T> {
|
||||
impl<S: Read + Write + Debug> Debug for MaybeTlsStream<S> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Plain(s) => f.debug_tuple("MaybeTlsStream::Plain").field(s).finish(),
|
||||
#[cfg(feature = "native-tls")]
|
||||
Self::NativeTls(s) => f.debug_tuple("MaybeTlsStream::NativeTls").field(s).finish(),
|
||||
#[cfg(feature = "__rustls-tls")]
|
||||
Self::Rustls(s) => {
|
||||
struct RustlsStreamDebug<'a, S: Read + Write>(
|
||||
&'a rustls::StreamOwned<rustls::ClientSession, S>,
|
||||
);
|
||||
|
||||
impl<'a, S: Read + Write + Debug> Debug for RustlsStreamDebug<'a, S> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("StreamOwned")
|
||||
.field("sess", &self.0.sess)
|
||||
.field("sock", &self.0.sock)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
f.debug_tuple("MaybeTlsStream::Rustls").field(&RustlsStreamDebug(s)).finish()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Read + Write> Read for MaybeTlsStream<S> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
|
||||
match *self {
|
||||
Stream::Plain(ref mut s) => s.read(buf),
|
||||
Stream::Tls(ref mut s) => s.read(buf),
|
||||
MaybeTlsStream::Plain(ref mut s) => s.read(buf),
|
||||
#[cfg(feature = "native-tls")]
|
||||
MaybeTlsStream::NativeTls(ref mut s) => s.read(buf),
|
||||
#[cfg(feature = "__rustls-tls")]
|
||||
MaybeTlsStream::Rustls(ref mut s) => s.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Write, T: Write> Write for Stream<S, T> {
|
||||
impl<S: Read + Write> Write for MaybeTlsStream<S> {
|
||||
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
|
||||
match *self {
|
||||
Stream::Plain(ref mut s) => s.write(buf),
|
||||
Stream::Tls(ref mut s) => s.write(buf),
|
||||
MaybeTlsStream::Plain(ref mut s) => s.write(buf),
|
||||
#[cfg(feature = "native-tls")]
|
||||
MaybeTlsStream::NativeTls(ref mut s) => s.write(buf),
|
||||
#[cfg(feature = "__rustls-tls")]
|
||||
MaybeTlsStream::Rustls(ref mut s) => s.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> IoResult<()> {
|
||||
match *self {
|
||||
Stream::Plain(ref mut s) => s.flush(),
|
||||
Stream::Tls(ref mut s) => s.flush(),
|
||||
MaybeTlsStream::Plain(ref mut s) => s.flush(),
|
||||
#[cfg(feature = "native-tls")]
|
||||
MaybeTlsStream::NativeTls(ref mut s) => s.flush(),
|
||||
#[cfg(feature = "__rustls-tls")]
|
||||
MaybeTlsStream::Rustls(ref mut s) => s.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: NoDelay, T: NoDelay> NoDelay for Stream<S, T> {
|
||||
impl<S: Read + Write + NoDelay> NoDelay for MaybeTlsStream<S> {
|
||||
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
|
||||
match *self {
|
||||
Stream::Plain(ref mut s) => s.set_nodelay(nodelay),
|
||||
Stream::Tls(ref mut s) => s.set_nodelay(nodelay),
|
||||
MaybeTlsStream::Plain(ref mut s) => s.set_nodelay(nodelay),
|
||||
#[cfg(feature = "native-tls")]
|
||||
MaybeTlsStream::NativeTls(ref mut s) => s.set_nodelay(nodelay),
|
||||
#[cfg(feature = "__rustls-tls")]
|
||||
MaybeTlsStream::Rustls(ref mut s) => s.set_nodelay(nodelay),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,219 @@
|
|||
//! Connection helper.
|
||||
use std::io::{Read, Write};
|
||||
|
||||
use crate::{
|
||||
client::{client_with_config, uri_mode, IntoClientRequest},
|
||||
error::UrlError,
|
||||
handshake::client::Response,
|
||||
protocol::WebSocketConfig,
|
||||
stream::MaybeTlsStream,
|
||||
ClientHandshake, Error, HandshakeError, Result, WebSocket,
|
||||
};
|
||||
|
||||
/// A connector that can be used when establishing connections, allowing to control whether
|
||||
/// `native-tls` or `rustls` is used to create a TLS connection. Or TLS can be disabled with the
|
||||
/// `Plain` variant.
|
||||
#[non_exhaustive]
|
||||
#[allow(missing_debug_implementations)]
|
||||
pub enum Connector {
|
||||
/// Plain (non-TLS) connector.
|
||||
Plain,
|
||||
/// `native-tls` TLS connector.
|
||||
#[cfg(feature = "native-tls")]
|
||||
NativeTls(native_tls_crate::TlsConnector),
|
||||
/// `rustls` TLS connector.
|
||||
#[cfg(feature = "__rustls-tls")]
|
||||
Rustls(std::sync::Arc<rustls::ClientConfig>),
|
||||
}
|
||||
|
||||
mod encryption {
|
||||
#[cfg(feature = "native-tls")]
|
||||
pub mod native_tls {
|
||||
use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector};
|
||||
|
||||
use std::io::{Read, Write};
|
||||
|
||||
use crate::{
|
||||
error::TlsError,
|
||||
stream::{MaybeTlsStream, Mode},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
pub fn wrap_stream<S>(
|
||||
socket: S,
|
||||
domain: &str,
|
||||
mode: Mode,
|
||||
tls_connector: Option<TlsConnector>,
|
||||
) -> Result<MaybeTlsStream<S>>
|
||||
where
|
||||
S: Read + Write,
|
||||
{
|
||||
match mode {
|
||||
Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
|
||||
Mode::Tls => {
|
||||
let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok);
|
||||
let connector = try_connector.map_err(TlsError::Native)?;
|
||||
let connected = connector.connect(domain, socket);
|
||||
match connected {
|
||||
Err(e) => match e {
|
||||
TlsHandshakeError::Failure(f) => Err(Error::Tls(f.into())),
|
||||
TlsHandshakeError::WouldBlock(_) => {
|
||||
panic!("Bug: TLS handshake not blocked")
|
||||
}
|
||||
},
|
||||
Ok(s) => Ok(MaybeTlsStream::NativeTls(s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "__rustls-tls")]
|
||||
pub mod rustls {
|
||||
use rustls::{ClientConfig, ClientSession, StreamOwned};
|
||||
use webpki::DNSNameRef;
|
||||
|
||||
use std::{
|
||||
io::{Read, Write},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
error::TlsError,
|
||||
stream::{MaybeTlsStream, Mode},
|
||||
Result,
|
||||
};
|
||||
|
||||
pub fn wrap_stream<S>(
|
||||
socket: S,
|
||||
domain: &str,
|
||||
mode: Mode,
|
||||
tls_connector: Option<Arc<ClientConfig>>,
|
||||
) -> Result<MaybeTlsStream<S>>
|
||||
where
|
||||
S: Read + Write,
|
||||
{
|
||||
match mode {
|
||||
Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
|
||||
Mode::Tls => {
|
||||
let config = match tls_connector {
|
||||
Some(config) => config,
|
||||
None => {
|
||||
#[allow(unused_mut)]
|
||||
let mut config = ClientConfig::new();
|
||||
#[cfg(feature = "rustls-tls-native-roots")]
|
||||
{
|
||||
config.root_store = rustls_native_certs::load_native_certs()
|
||||
.map_err(|(_, err)| err)?;
|
||||
}
|
||||
#[cfg(feature = "rustls-tls-webpki-roots")]
|
||||
{
|
||||
config
|
||||
.root_store
|
||||
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
|
||||
}
|
||||
|
||||
Arc::new(config)
|
||||
}
|
||||
};
|
||||
let domain = DNSNameRef::try_from_ascii_str(domain).map_err(TlsError::Dns)?;
|
||||
let client = ClientSession::new(&config, domain);
|
||||
let stream = StreamOwned::new(client, socket);
|
||||
|
||||
Ok(MaybeTlsStream::Rustls(stream))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub mod plain {
|
||||
use std::io::{Read, Write};
|
||||
|
||||
use crate::{
|
||||
error::UrlError,
|
||||
stream::{MaybeTlsStream, Mode},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
pub fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>>
|
||||
where
|
||||
S: Read + Write,
|
||||
{
|
||||
match mode {
|
||||
Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
|
||||
Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type TlsHandshakeError<S> = HandshakeError<ClientHandshake<MaybeTlsStream<S>>>;
|
||||
|
||||
/// Creates a WebSocket handshake from a request and a stream,
|
||||
/// upgrading the stream to TLS if required.
|
||||
pub fn client_tls<R, S>(
|
||||
request: R,
|
||||
stream: S,
|
||||
) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
|
||||
where
|
||||
R: IntoClientRequest,
|
||||
S: Read + Write,
|
||||
{
|
||||
client_tls_with_config(request, stream, None, None)
|
||||
}
|
||||
|
||||
/// The same as [`client_tls()`] but one can specify a websocket configuration,
|
||||
/// and an optional connector. If no connector is specified, a default one will
|
||||
/// be created.
|
||||
///
|
||||
/// Please refer to [`client_tls()`] for more details.
|
||||
pub fn client_tls_with_config<R, S>(
|
||||
request: R,
|
||||
stream: S,
|
||||
config: Option<WebSocketConfig>,
|
||||
connector: Option<Connector>,
|
||||
) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
|
||||
where
|
||||
R: IntoClientRequest,
|
||||
S: Read + Write,
|
||||
{
|
||||
let request = request.into_client_request()?;
|
||||
|
||||
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
|
||||
let domain = match request.uri().host() {
|
||||
Some(d) => Ok(d.to_string()),
|
||||
None => Err(Error::Url(UrlError::NoHostName)),
|
||||
}?;
|
||||
|
||||
let mode = uri_mode(&request.uri())?;
|
||||
|
||||
let stream = match connector {
|
||||
Some(conn) => match conn {
|
||||
#[cfg(feature = "native-tls")]
|
||||
Connector::NativeTls(conn) => {
|
||||
self::encryption::native_tls::wrap_stream(stream, &domain, mode, Some(conn))
|
||||
}
|
||||
#[cfg(feature = "__rustls-tls")]
|
||||
Connector::Rustls(conn) => {
|
||||
self::encryption::rustls::wrap_stream(stream, &domain, mode, Some(conn))
|
||||
}
|
||||
Connector::Plain => self::encryption::plain::wrap_stream(stream, mode),
|
||||
},
|
||||
None => {
|
||||
#[cfg(feature = "native-tls")]
|
||||
{
|
||||
self::encryption::native_tls::wrap_stream(stream, &domain, mode, None)
|
||||
}
|
||||
#[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
|
||||
{
|
||||
self::encryption::rustls::wrap_stream(stream, &domain, mode, None)
|
||||
}
|
||||
#[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
|
||||
{
|
||||
self::encryption::plain::wrap_stream(stream, mode)
|
||||
}
|
||||
}
|
||||
}?;
|
||||
|
||||
client_with_config(request, stream, config)
|
||||
}
|
|
@ -1,10 +1,6 @@
|
|||
//! Verifies that the server returns a `ConnectionClosed` error when the connection
|
||||
//! is closed from the server's point of view and drop the underlying tcp socket.
|
||||
#![cfg(any(
|
||||
feature = "native-tls",
|
||||
feature = "rustls-tls-native-roots",
|
||||
feature = "rustls-tls-webpki-roots"
|
||||
))]
|
||||
#![cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
|
||||
|
||||
use std::{
|
||||
net::{TcpListener, TcpStream},
|
||||
|
@ -14,16 +10,10 @@ use std::{
|
|||
};
|
||||
|
||||
use net2::TcpStreamExt;
|
||||
use tungstenite::{accept, connect, stream::Stream, Error, Message, WebSocket};
|
||||
use tungstenite::{accept, connect, stream::MaybeTlsStream, Error, Message, WebSocket};
|
||||
use url::Url;
|
||||
|
||||
#[cfg(feature = "native-tls")]
|
||||
type Sock = WebSocket<Stream<TcpStream, native_tls_crate::TlsStream<TcpStream>>>;
|
||||
#[cfg(all(
|
||||
any(feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots"),
|
||||
not(feature = "native-tls")
|
||||
))]
|
||||
type Sock = WebSocket<Stream<TcpStream, rustls::StreamOwned<rustls::ClientSession, TcpStream>>>;
|
||||
type Sock = WebSocket<MaybeTlsStream<TcpStream>>;
|
||||
|
||||
fn do_test<CT, ST>(port: u16, client_task: CT, server_task: ST)
|
||||
where
|
||||
|
|
Loading…
Reference in New Issue