diff --git a/README.md b/README.md index 089d7041..fbd07bd5 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,8 @@ If you'd like to help out, please see [CONTRIBUTING.md](CONTRIBUTING.md). - Performance improvements. Thanks to @nviennot. - Fixed client authentication being unduly rejected by client when server uses the superseded certificate_types field of CertificateRequest. + - *Breaking API change*: The writev_tls API has been removed, in favour + of using vectored IO support now offered by std::io::Write. * 0.17.0 (2020-02-22): - *Breaking API change*: ALPN protocols offered by the client are passed to the server certificate resolution trait (`ResolvesServerCert`). diff --git a/rustls-mio/Cargo.toml b/rustls-mio/Cargo.toml index 02859b7a..d24e0f3b 100644 --- a/rustls-mio/Cargo.toml +++ b/rustls-mio/Cargo.toml @@ -28,7 +28,6 @@ regex = "1.0" serde = "1.0" serde_derive = "1.0" tempfile = "3.0" -vecio = "0.1" webpki-roots = "0.19" [[example]] diff --git a/rustls-mio/examples/tlsclient.rs b/rustls-mio/examples/tlsclient.rs index f91180f5..daeaca9c 100644 --- a/rustls-mio/examples/tlsclient.rs +++ b/rustls-mio/examples/tlsclient.rs @@ -23,9 +23,6 @@ use webpki; use webpki_roots; use ct_logs; - -mod util; - use rustls::Session; const CLIENT: mio::Token = mio::Token(0); @@ -146,17 +143,10 @@ impl TlsClient { } } - #[cfg(target_os = "windows")] fn do_write(&mut self) { self.tls_session.write_tls(&mut self.socket).unwrap(); } - #[cfg(not(target_os = "windows"))] - fn do_write(&mut self) { - use crate::util::WriteVAdapter; - self.tls_session.writev_tls(&mut WriteVAdapter::new(&mut self.socket)).unwrap(); - } - fn register(&mut self, registry: &mio::Registry) { let interest = self.ready_interest(); registry.register(&mut self.socket, CLIENT, interest).unwrap(); diff --git a/rustls-mio/examples/tlsserver.rs b/rustls-mio/examples/tlsserver.rs index 3617955f..5d45fa56 100644 --- a/rustls-mio/examples/tlsserver.rs +++ b/rustls-mio/examples/tlsserver.rs @@ -24,8 +24,6 @@ use rustls; use rustls::{RootCertStore, Session, NoClientAuth, AllowAnyAuthenticatedClient, AllowAnyAnonymousOrAuthenticatedClient}; -mod util; - // Token for our listening socket. const LISTENER: mio::Token = mio::Token(0); @@ -310,17 +308,10 @@ impl Connection { } } - #[cfg(target_os = "windows")] fn tls_write(&mut self) -> io::Result { self.tls_session.write_tls(&mut self.socket) } - #[cfg(not(target_os = "windows"))] - fn tls_write(&mut self) -> io::Result { - use crate::util::WriteVAdapter; - self.tls_session.writev_tls(&mut WriteVAdapter::new(&mut self.socket)) - } - fn do_tls_write_and_handle_error(&mut self) { let rc = self.tls_write(); if rc.is_err() { diff --git a/rustls-mio/examples/util/mod.rs b/rustls-mio/examples/util/mod.rs deleted file mode 100644 index c8a6263d..00000000 --- a/rustls-mio/examples/util/mod.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::io; -use vecio::Rawv; -use rustls; - -/// This glues our `rustls::WriteV` trait to `vecio::Rawv`. -pub struct WriteVAdapter<'a> { - rawv: &'a mut dyn Rawv -} - -impl<'a> WriteVAdapter<'a> { - pub fn new(rawv: &'a mut dyn Rawv) -> WriteVAdapter<'a> { - WriteVAdapter { rawv } - } -} - -impl<'a> rustls::WriteV for WriteVAdapter<'a> { - fn writev(&mut self, bytes: &[&[u8]]) -> io::Result { - self.rawv.writev(bytes) - } -} diff --git a/rustls/src/client/mod.rs b/rustls/src/client/mod.rs index 093513bf..6c086deb 100644 --- a/rustls/src/client/mod.rs +++ b/rustls/src/client/mod.rs @@ -13,7 +13,6 @@ use crate::anchors; use crate::sign; use crate::error::TLSError; use crate::key; -use crate::vecbuf::WriteV; #[cfg(feature = "logging")] use crate::log::trace; @@ -666,10 +665,6 @@ impl Session for ClientSession { self.imp.common.write_tls(wr) } - fn writev_tls(&mut self, wr: &mut dyn WriteV) -> io::Result { - self.imp.common.writev_tls(wr) - } - fn process_new_packets(&mut self) -> Result<(), TLSError> { self.imp.process_new_packets() } diff --git a/rustls/src/lib.rs b/rustls/src/lib.rs index bd2dfe6e..6168ca7f 100644 --- a/rustls/src/lib.rs +++ b/rustls/src/lib.rs @@ -280,7 +280,6 @@ pub use crate::verify::{NoClientAuth, AllowAnyAuthenticatedClient, pub use crate::suites::{ALL_CIPHERSUITES, BulkAlgorithm, SupportedCipherSuite}; pub use crate::key::{Certificate, PrivateKey}; pub use crate::keylog::{KeyLog, NoKeyLog, KeyLogFile}; -pub use crate::vecbuf::{WriteV, WriteVAdapter}; /// Message signing interfaces and implementations. pub mod sign; diff --git a/rustls/src/server/mod.rs b/rustls/src/server/mod.rs index ea88180a..8c6906e9 100644 --- a/rustls/src/server/mod.rs +++ b/rustls/src/server/mod.rs @@ -10,7 +10,6 @@ use crate::error::TLSError; use crate::sign; use crate::verify; use crate::key; -use crate::vecbuf::WriteV; #[cfg(feature = "logging")] use crate::log::trace; @@ -577,10 +576,6 @@ impl Session for ServerSession { self.imp.common.write_tls(wr) } - fn writev_tls(&mut self, wr: &mut dyn WriteV) -> io::Result { - self.imp.common.writev_tls(wr) - } - fn process_new_packets(&mut self) -> Result<(), TLSError> { self.imp.process_new_packets() } diff --git a/rustls/src/session.rs b/rustls/src/session.rs index a73cdfd1..34024b6f 100644 --- a/rustls/src/session.rs +++ b/rustls/src/session.rs @@ -10,7 +10,7 @@ use crate::msgs::enums::{ContentType, ProtocolVersion, AlertDescription, AlertLe use crate::error::TLSError; use crate::suites::SupportedCipherSuite; use crate::cipher; -use crate::vecbuf::{ChunkVecBuffer, WriteV}; +use crate::vecbuf::ChunkVecBuffer; use crate::key; use crate::prf; use crate::rand; @@ -52,11 +52,6 @@ pub trait Session: quic::QuicExt + Read + Write + Send + Sync { /// [`wants_write`]: #tymethod.wants_write fn write_tls(&mut self, wr: &mut dyn Write) -> Result; - /// Like `write_tls`, but writes potentially many records in one - /// go via `wr`; a `rustls::WriteV`. This function has the same semantics - /// as `write_tls` otherwise. - fn writev_tls(&mut self, wr: &mut dyn WriteV) -> Result; - /// Processes any new packets read by a previous call to `read_tls`. /// Errors from this function relate to TLS protocol errors, and /// are fatal to the session. Future calls after an error will do @@ -591,10 +586,6 @@ impl SessionCommon { self.sendable_tls.write_to(wr) } - pub fn writev_tls(&mut self, wr: &mut dyn WriteV) -> io::Result { - self.sendable_tls.writev_to(wr) - } - /// Send plaintext application data, fragmenting and /// encrypting it as it goes out. /// diff --git a/rustls/src/vecbuf.rs b/rustls/src/vecbuf.rs index b587d5f9..0989c949 100644 --- a/rustls/src/vecbuf.rs +++ b/rustls/src/vecbuf.rs @@ -2,26 +2,6 @@ use std::io::Read; use std::io; use std::cmp; use std::collections::VecDeque; -use std::convert; - -/// This trait specifies rustls's precise requirements doing writes with -/// vectored IO. -/// -/// The purpose of vectored IO is to pass contigious output in many blocks -/// to the kernel without either coalescing it in user-mode (by allocating -/// and copying) or making many system calls. -/// -/// We don't directly use types from the vecio crate because the traits -/// don't compose well: the most useful trait (`Rawv`) is hard to test -/// with (it can't be implemented without an FD) and implies a readable -/// source too. You will have to write a trivial adaptor struct which -/// glues either `vecio::Rawv` or `vecio::Writev` to this trait. See -/// the rustls examples. -pub trait WriteV { - /// Writes as much data from `vbytes` as possible, returning - /// the number of bytes written. - fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result; -} /// This is a byte buffer that is built from a vector /// of byte vectors. This avoids extra copies when @@ -132,50 +112,12 @@ impl ChunkVecBuffer { return Ok(0); } - let used = wr.write(&self.chunks[0])?; + let used = wr.write_vectored(&self.chunks.iter() + .map(|ch| io::IoSlice::new(ch)) + .collect::>())?; self.consume(used); Ok(used) } - - pub fn writev_to(&mut self, wr: &mut dyn WriteV) -> io::Result { - if self.is_empty() { - return Ok(0); - } - - let used = { - let chunks = self.chunks.iter() - .map(convert::AsRef::as_ref) - .collect::>(); - - wr.writev(&chunks)? - }; - self.consume(used); - Ok(used) - } -} - -/// This is a simple wrapper around an object -/// which implements `std::io::Write` in order to autoimplement `WriteV`. -/// It uses the `write_vectored` method from `std::io::Write` in order -/// to do this. -pub struct WriteVAdapter(T); - -impl WriteVAdapter { - /// build an adapter from a Write object - pub fn new(inner: T) -> Self { - WriteVAdapter(inner) - } -} - -impl WriteV for WriteVAdapter { - fn writev(&mut self, buffers: &[&[u8]]) -> io::Result { - self.0.write_vectored( - &buffers - .iter() - .map(|b| io::IoSlice::new(b)) - .collect::>(), - ) - } } #[cfg(test)] diff --git a/rustls/tests/api.rs b/rustls/tests/api.rs index 6af56491..8f006f37 100644 --- a/rustls/tests/api.rs +++ b/rustls/tests/api.rs @@ -6,7 +6,7 @@ use std::mem; use std::fmt; use std::env; use std::error::Error; -use std::io::{self, Write, Read}; +use std::io::{self, Write, Read, IoSlice}; use rustls; @@ -842,6 +842,23 @@ fn server_respects_buffer_limit_pre_handshake() { check_read(&mut client, b"01234567890123456789012345678901"); } +#[test] +fn server_respects_buffer_limit_pre_handshake_with_vectored_write() { + let (mut client, mut server) = make_pair(KeyType::RSA); + + server.set_buffer_limit(32); + + assert_eq!(server.write_vectored(&[IoSlice::new(b"01234567890123456789"), + IoSlice::new(b"01234567890123456789")]).unwrap(), + 32); + + do_handshake(&mut client, &mut server); + transfer(&mut server, &mut client); + client.process_new_packets().unwrap(); + + check_read(&mut client, b"01234567890123456789012345678901"); +} + #[test] fn server_respects_buffer_limit_post_handshake() { let (mut client, mut server) = make_pair(KeyType::RSA); @@ -875,6 +892,23 @@ fn client_respects_buffer_limit_pre_handshake() { check_read(&mut server, b"01234567890123456789012345678901"); } +#[test] +fn client_respects_buffer_limit_pre_handshake_with_vectored_write() { + let (mut client, mut server) = make_pair(KeyType::RSA); + + client.set_buffer_limit(32); + + assert_eq!(client.write_vectored(&[IoSlice::new(b"01234567890123456789"), + IoSlice::new(b"01234567890123456789")]).unwrap(), + 32); + + do_handshake(&mut client, &mut server); + transfer(&mut client, &mut server); + server.process_new_packets().unwrap(); + + check_read(&mut server, b"01234567890123456789012345678901"); +} + #[test] fn client_respects_buffer_limit_post_handshake() { let (mut client, mut server) = make_pair(KeyType::RSA); @@ -894,7 +928,6 @@ fn client_respects_buffer_limit_post_handshake() { struct OtherSession<'a> { sess: &'a mut dyn Session, pub reads: usize, - pub writes: usize, pub writevs: Vec>, fail_ok: bool, pub short_writes: bool, @@ -906,7 +939,6 @@ impl<'a> OtherSession<'a> { OtherSession { sess, reads: 0, - writes: 0, writevs: vec![], fail_ok: false, short_writes: false, @@ -929,27 +961,15 @@ impl<'a> io::Read for OtherSession<'a> { } impl<'a> io::Write for OtherSession<'a> { - fn write(&mut self, mut b: &[u8]) -> io::Result { - self.writes += 1; - let l = self.sess.read_tls(b.by_ref())?; - let rc = self.sess.process_new_packets(); - - if !self.fail_ok { - rc.unwrap(); - } else if rc.is_err() { - self.last_error = rc.err(); - } - - Ok(l) + fn write(&mut self, _: &[u8]) -> io::Result { + unreachable!() } fn flush(&mut self) -> io::Result<()> { Ok(()) } -} -impl<'a> rustls::WriteV for OtherSession<'a> { - fn writev(&mut self, b: &[&[u8]]) -> io::Result { + fn write_vectored<'b>(&mut self, b: &[io::IoSlice<'b>]) -> io::Result { let mut total = 0; let mut lengths = vec![]; for bytes in b { @@ -1012,7 +1032,8 @@ fn client_complete_io_for_write() { let mut pipe = OtherSession::new(&mut server); let (rdlen, wrlen) = client.complete_io(&mut pipe).unwrap(); assert!(rdlen == 0 && wrlen > 0); - assert_eq!(pipe.writes, 2); + println!("{:?}", pipe.writevs); + assert_eq!(pipe.writevs, vec![ vec![ 42, 42 ] ]); } check_read(&mut server, b"0123456789012345678901234567890123456789"); } @@ -1071,7 +1092,7 @@ fn server_complete_io_for_write() { let mut pipe = OtherSession::new(&mut client); let (rdlen, wrlen) = server.complete_io(&mut pipe).unwrap(); assert!(rdlen == 0 && wrlen > 0); - assert_eq!(pipe.writes, 2); + assert_eq!(pipe.writevs, vec![ vec![ 42, 42 ] ]); } check_read(&mut client, b"0123456789012345678901234567890123456789"); } @@ -1774,7 +1795,7 @@ fn vectored_write_for_server_appdata() { server.write(b"01234567890123456789").unwrap(); { let mut pipe = OtherSession::new(&mut client); - let wrlen = server.writev_tls(&mut pipe).unwrap(); + let wrlen = server.write_tls(&mut pipe).unwrap(); assert_eq!(84, wrlen); assert_eq!(pipe.writevs, vec![vec![42, 42]]); } @@ -1790,7 +1811,7 @@ fn vectored_write_for_client_appdata() { client.write(b"01234567890123456789").unwrap(); { let mut pipe = OtherSession::new(&mut server); - let wrlen = client.writev_tls(&mut pipe).unwrap(); + let wrlen = client.write_tls(&mut pipe).unwrap(); assert_eq!(84, wrlen); assert_eq!(pipe.writevs, vec![vec![42, 42]]); } @@ -1808,7 +1829,7 @@ fn vectored_write_for_server_handshake() { server.process_new_packets().unwrap(); { let mut pipe = OtherSession::new(&mut client); - let wrlen = server.writev_tls(&mut pipe).unwrap(); + let wrlen = server.write_tls(&mut pipe).unwrap(); // don't assert exact sizes here, to avoid a brittle test assert!(wrlen > 4000); // its pretty big (contains cert chain) assert_eq!(pipe.writevs.len(), 1); // only one writev @@ -1820,7 +1841,7 @@ fn vectored_write_for_server_handshake() { server.process_new_packets().unwrap(); { let mut pipe = OtherSession::new(&mut client); - let wrlen = server.writev_tls(&mut pipe).unwrap(); + let wrlen = server.write_tls(&mut pipe).unwrap(); assert_eq!(wrlen, 177); assert_eq!(pipe.writevs, vec![vec![103, 42, 32]]); } @@ -1838,7 +1859,7 @@ fn vectored_write_for_client_handshake() { client.write(b"0123456789").unwrap(); { let mut pipe = OtherSession::new(&mut server); - let wrlen = client.writev_tls(&mut pipe).unwrap(); + let wrlen = client.write_tls(&mut pipe).unwrap(); // don't assert exact sizes here, to avoid a brittle test assert!(wrlen > 200); // just the client hello assert_eq!(pipe.writevs.len(), 1); // only one writev @@ -1850,7 +1871,7 @@ fn vectored_write_for_client_handshake() { { let mut pipe = OtherSession::new(&mut server); - let wrlen = client.writev_tls(&mut pipe).unwrap(); + let wrlen = client.write_tls(&mut pipe).unwrap(); assert_eq!(wrlen, 138); // CCS, finished, then two application datas assert_eq!(pipe.writevs, vec![vec![6, 58, 42, 32]]); @@ -1873,12 +1894,12 @@ fn vectored_write_with_slow_client() { { let mut pipe = OtherSession::new(&mut client); pipe.short_writes = true; - let wrlen = server.writev_tls(&mut pipe).unwrap() + - server.writev_tls(&mut pipe).unwrap() + - server.writev_tls(&mut pipe).unwrap() + - server.writev_tls(&mut pipe).unwrap() + - server.writev_tls(&mut pipe).unwrap() + - server.writev_tls(&mut pipe).unwrap(); + let wrlen = server.write_tls(&mut pipe).unwrap() + + server.write_tls(&mut pipe).unwrap() + + server.write_tls(&mut pipe).unwrap() + + server.write_tls(&mut pipe).unwrap() + + server.write_tls(&mut pipe).unwrap() + + server.write_tls(&mut pipe).unwrap(); assert_eq!(42, wrlen); assert_eq!(pipe.writevs, vec![vec![21], vec![10], vec![5], vec![3], vec![3]]); } @@ -2201,21 +2222,30 @@ fn test_client_does_not_offer_sha1() { #[test] fn test_client_mtu_reduction() { - fn collect_write_lengths(client: &mut ClientSession) -> Vec { - let mut r = Vec::new(); - let mut buf = [0u8; 128]; + struct CollectWrites { + writevs: Vec>, + } - loop { - let sz = client.write_tls(&mut buf.as_mut()) - .unwrap(); - r.push(sz); - assert!(sz <= 64); - if sz < 64 { - break; - } + impl io::Write for CollectWrites { + fn write(&mut self, _: &[u8]) -> io::Result { panic!() } + fn flush(&mut self) -> io::Result<()> { panic!() } + fn write_vectored<'b>(&mut self, b: &[io::IoSlice<'b>]) -> io::Result { + let writes = b.iter() + .map(|slice| slice.len()) + .collect::>(); + let len = writes.iter().sum(); + self.writevs.push(writes); + Ok(len) } + } - r + fn collect_write_lengths(client: &mut ClientSession) -> Vec { + let mut collector = CollectWrites { writevs: vec![] }; + + client.write_tls(&mut collector) + .unwrap(); + assert_eq!(collector.writevs.len(), 1); + collector.writevs[0].clone() } for kt in ALL_KEY_TYPES.iter() { @@ -2224,6 +2254,7 @@ fn test_client_mtu_reduction() { let mut client = ClientSession::new(&Arc::new(client_config), dns_name("localhost")); let writes = collect_write_lengths(&mut client); + println!("writes at mtu=64: {:?}", writes); assert!(writes.iter().all(|x| *x <= 64)); assert!(writes.len() > 1); }