StreamOwned: a Stream that owns its dependencies

This commit is contained in:
Joseph Birr-Pixton 2018-07-08 13:05:49 +01:00
parent d6f5c329fb
commit c6af00a225
3 changed files with 163 additions and 4 deletions

View File

@ -280,7 +280,7 @@ pub use msgs::enums::SignatureScheme;
pub use msgs::enums::CipherSuite;
pub use error::TLSError;
pub use session::Session;
pub use stream::Stream;
pub use stream::{Stream, StreamOwned};
pub use anchors::{DistinguishedNames, RootCertStore};
pub use client::StoresClientSessions;
pub use client::handy::{NoClientSessionStorage, ClientSessionMemoryCache};

View File

@ -75,14 +75,62 @@ impl<'a, S, T> Write for Stream<'a, S, T> where S: 'a + Session, T: 'a + Read +
}
}
/// This type implements `io::Read` and `io::Write`, encapsulating
/// and owning a Session `S` and an underlying blocking transport
/// `T`, such as a socket.
///
/// This allows you to use a rustls Session like a normal stream.
pub struct StreamOwned<S: Session + Sized, T: Read + Write + Sized> {
/// Our session
pub sess: S,
/// The underlying transport, like a socket
pub sock: T,
}
impl<S, T> StreamOwned<S, T> where S: Session, T: Read + Write {
/// Make a new StreamOwned taking the Session `sess` and socket-like
/// object `sock`. This does not fail and does no IO.
///
/// This is the same as `Stream::new` except `sess` and `sock` are
/// moved into the StreamOwned.
pub fn new(sess: S, sock: T) -> StreamOwned<S, T> {
StreamOwned { sess, sock }
}
}
impl<'a, S, T> StreamOwned<S, T> where S: Session, T: Read + Write {
fn as_stream(&'a mut self) -> Stream<'a, S, T> {
Stream { sess: &mut self.sess, sock: &mut self.sock }
}
}
impl<S, T> Read for StreamOwned<S, T> where S: Session, T: Read + Write {
fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
self.as_stream().read(buf)
}
}
impl<S, T> Write for StreamOwned<S, T> where S: Session, T: Read + Write {
fn write(&mut self, buf: &[u8]) -> Result<usize> {
self.as_stream().write(buf)
}
fn flush(&mut self) -> Result<()> {
self.as_stream().flush()
}
}
#[cfg(test)]
mod tests {
use super::Stream;
use super::{Stream, StreamOwned};
use session::Session;
use client::ClientSession;
use server::ServerSession;
use std::net::TcpStream;
#[test]
fn session_can_be_instantiated_with() {
fn stream_can_be_created_for_session_and_tcpstream() {
fn _foo<'a>(sess: &'a mut Session, sock: &'a mut TcpStream) -> Stream<'a, Session, TcpStream> {
Stream {
sess,
@ -90,4 +138,24 @@ mod tests {
}
}
}
#[test]
fn streamowned_can_be_created_for_client_and_tcpstream() {
fn _foo(sess: ClientSession, sock: TcpStream) -> StreamOwned<ClientSession, TcpStream> {
StreamOwned {
sess,
sock,
}
}
}
#[test]
fn streamowned_can_be_created_for_server_and_tcpstream() {
fn _foo(sess: ServerSession, sock: TcpStream) -> StreamOwned<ServerSession, TcpStream> {
StreamOwned {
sess,
sock,
}
}
}
}

View File

@ -11,7 +11,7 @@ extern crate rustls;
use rustls::{ClientConfig, ClientSession, ResolvesClientCert};
use rustls::{ServerConfig, ServerSession, ResolvesServerCert};
use rustls::Session;
use rustls::Stream;
use rustls::{Stream, StreamOwned};
use rustls::{ProtocolVersion, SignatureScheme, CipherSuite};
use rustls::TLSError;
use rustls::sign;
@ -980,6 +980,20 @@ fn client_stream_write() {
}
}
#[test]
fn client_streamowned_write() {
for kt in ALL_KEY_TYPES.iter() {
let (mut client, mut server) = make_pair(*kt);
{
let pipe = OtherSession::new(&mut server);
let mut stream = StreamOwned::new(client, pipe);
assert_eq!(stream.write(b"hello").unwrap(), 5);
}
check_read(&mut server, b"hello");
}
}
#[test]
fn client_stream_read() {
for kt in ALL_KEY_TYPES.iter() {
@ -995,6 +1009,21 @@ fn client_stream_read() {
}
}
#[test]
fn client_streamowned_read() {
for kt in ALL_KEY_TYPES.iter() {
let (client, mut server) = make_pair(*kt);
server.write(b"world").unwrap();
{
let pipe = OtherSession::new(&mut server);
let mut stream = StreamOwned::new(client, pipe);
check_read(&mut stream, b"world");
}
}
}
#[test]
fn server_stream_write() {
for kt in ALL_KEY_TYPES.iter() {
@ -1009,6 +1038,20 @@ fn server_stream_write() {
}
}
#[test]
fn server_streamowned_write() {
for kt in ALL_KEY_TYPES.iter() {
let (mut client, server) = make_pair(*kt);
{
let pipe = OtherSession::new(&mut client);
let mut stream = StreamOwned::new(server, pipe);
assert_eq!(stream.write(b"hello").unwrap(), 5);
}
check_read(&mut client, b"hello");
}
}
#[test]
fn server_stream_read() {
for kt in ALL_KEY_TYPES.iter() {
@ -1024,6 +1067,21 @@ fn server_stream_read() {
}
}
#[test]
fn server_streamowned_read() {
for kt in ALL_KEY_TYPES.iter() {
let (mut client, server) = make_pair(*kt);
client.write(b"world").unwrap();
{
let pipe = OtherSession::new(&mut client);
let mut stream = StreamOwned::new(server, pipe);
check_read(&mut stream, b"world");
}
}
}
fn make_disjoint_suite_configs() -> (ClientConfig, ServerConfig) {
let kt = KeyType::RSA;
let mut server_config = make_server_config(kt);
@ -1052,6 +1110,23 @@ fn client_stream_handshake_error() {
}
}
#[test]
fn client_streamowned_handshake_error() {
let (client_config, server_config) = make_disjoint_suite_configs();
let (client, mut server) = make_pair_for_configs(client_config, server_config);
let pipe = OtherSession::new_fails(&mut server);
let mut client_stream = StreamOwned::new(client, pipe);
let rc = client_stream.write(b"hello");
assert!(rc.is_err());
assert_eq!(format!("{:?}", rc),
"Err(Custom { kind: InvalidData, error: AlertReceived(HandshakeFailure) })");
let rc = client_stream.write(b"hello");
assert!(rc.is_err());
assert_eq!(format!("{:?}", rc),
"Err(Custom { kind: InvalidData, error: AlertReceived(HandshakeFailure) })");
}
#[test]
fn server_stream_handshake_error() {
let (client_config, server_config) = make_disjoint_suite_configs();
@ -1070,6 +1145,22 @@ fn server_stream_handshake_error() {
}
}
#[test]
fn server_streamowned_handshake_error() {
let (client_config, server_config) = make_disjoint_suite_configs();
let (mut client, server) = make_pair_for_configs(client_config, server_config);
client.write(b"world").unwrap();
let pipe = OtherSession::new_fails(&mut client);
let mut server_stream = StreamOwned::new(server, pipe);
let mut bytes = [0u8; 5];
let rc = server_stream.read(&mut bytes);
assert!(rc.is_err());
assert_eq!(format!("{:?}", rc),
"Err(Custom { kind: InvalidData, error: PeerIncompatibleError(\"no ciphersuites in common\") })");
}
#[test]
fn server_config_is_clone() {
let _ = make_server_config(KeyType::RSA).clone();