From ae70e4a9e1842e1381593cafccc0309190a574c0 Mon Sep 17 00:00:00 2001 From: Joseph Birr-Pixton Date: Mon, 28 May 2018 19:05:16 +0100 Subject: [PATCH] Add support for vectored IO This is abstract: behind a trivial rustls-specific trait so it can be tested and doesn't rely on implementation details of vecio. --- Cargo.toml | 1 + README.md | 2 + examples/tlsclient.rs | 6 +- examples/tlsserver.rs | 7 +- examples/util/mod.rs | 20 +++++ src/client/mod.rs | 5 ++ src/lib.rs | 1 + src/server/mod.rs | 6 ++ src/session.rs | 11 ++- src/vecbuf.rs | 59 +++++++++++--- tests/api.rs | 174 +++++++++++++++++++++++++++++++++++++++++- 11 files changed, 274 insertions(+), 18 deletions(-) create mode 100644 examples/util/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 66cfd2dc..edcf56d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ serde_derive = "1.0" webpki-roots = "0.14.0" ct-logs = "0.3" regex = "0.2" +vecio = "0.1" [[example]] name = "bogo_shim" diff --git a/README.md b/README.md index e6d4e016..1ba2fc29 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,8 @@ Rustls is currently in development and hence unstable. [Here's what I'm working - Add support for `SSLKEYLOGFILE`; not enabled by default. - Add support for basic usage in QUIC. - `ServerConfig::set_single_cert` and company now report errors. + - Add support for vectored IO: `writev_tls` can now be used to + optimise system call usage. * 0.12.0 (2018-01-06): - New API for learning negotiated cipher suite. - Move TLS1.3 support from draft 18 to 22. diff --git a/examples/tlsclient.rs b/examples/tlsclient.rs index e2a6895a..32d2b898 100644 --- a/examples/tlsclient.rs +++ b/examples/tlsclient.rs @@ -22,6 +22,10 @@ extern crate rustls; extern crate webpki; extern crate webpki_roots; extern crate ct_logs; +extern crate vecio; + +mod util; +use util::WriteVAdapter; use rustls::Session; @@ -144,7 +148,7 @@ impl TlsClient { } fn do_write(&mut self) { - self.tls_session.write_tls(&mut self.socket).unwrap(); + self.tls_session.writev_tls(&mut WriteVAdapter::new(&mut self.socket)).unwrap(); } fn register(&self, poll: &mut mio::Poll) { diff --git a/examples/tlsserver.rs b/examples/tlsserver.rs index d28d789e..0dac5999 100644 --- a/examples/tlsserver.rs +++ b/examples/tlsserver.rs @@ -18,12 +18,15 @@ extern crate docopt; use docopt::Docopt; extern crate env_logger; - +extern crate vecio; extern crate rustls; use rustls::{RootCertStore, Session, NoClientAuth, AllowAnyAuthenticatedClient, AllowAnyAnonymousOrAuthenticatedClient}; +mod util; +use util::WriteVAdapter; + // Token for our listening socket. const LISTENER: mio::Token = mio::Token(0); @@ -301,7 +304,7 @@ impl Connection { } fn do_tls_write(&mut self) { - let rc = self.tls_session.write_tls(&mut self.socket); + let rc = self.tls_session.writev_tls(&mut WriteVAdapter::new(&mut self.socket)); if rc.is_err() { error!("write failed {:?}", rc); self.closing = true; diff --git a/examples/util/mod.rs b/examples/util/mod.rs new file mode 100644 index 00000000..175732c8 --- /dev/null +++ b/examples/util/mod.rs @@ -0,0 +1,20 @@ +use std::io; +use vecio::Rawv; +use rustls; + +/// This glues our `rustls::WriteV` trait to `vecio::Rawv`. +pub struct WriteVAdapter<'a> { + rawv: &'a mut Rawv +} + +impl<'a> WriteVAdapter<'a> { + pub fn new(rawv: &'a mut Rawv) -> WriteVAdapter<'a> { + WriteVAdapter { rawv } + } +} + +impl<'a> rustls::WriteV for WriteVAdapter<'a> { + fn writev(&mut self, bytes: &[&[u8]]) -> io::Result { + self.rawv.writev(bytes) + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs index ff6126c2..38aae75c 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -13,6 +13,7 @@ use anchors; use sign; use error::TLSError; use key; +use vecbuf::WriteV; use std::sync::Arc; use std::io; @@ -463,6 +464,10 @@ impl Session for ClientSession { self.imp.common.write_tls(wr) } + fn writev_tls(&mut self, wr: &mut WriteV) -> io::Result { + self.imp.common.writev_tls(wr) + } + fn process_new_packets(&mut self) -> Result<(), TLSError> { self.imp.process_new_packets() } diff --git a/src/lib.rs b/src/lib.rs index 9e573b38..5f2d8fd4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -301,6 +301,7 @@ pub use verify::{NoClientAuth, AllowAnyAuthenticatedClient, pub use suites::{ALL_CIPHERSUITES, SupportedCipherSuite}; pub use key::{Certificate, PrivateKey}; pub use keylog::{KeyLog, NoKeyLog, KeyLogFile}; +pub use vecbuf::WriteV; /// Message signing interfaces and implementations. pub mod sign; diff --git a/src/server/mod.rs b/src/server/mod.rs index ddd7748f..e8d12f15 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -9,6 +9,8 @@ use error::TLSError; use sign; use verify; use key; +use vecbuf::WriteV; + use webpki; use std::sync::Arc; @@ -456,6 +458,10 @@ impl Session for ServerSession { self.imp.common.write_tls(wr) } + fn writev_tls(&mut self, wr: &mut WriteV) -> io::Result { + self.imp.common.writev_tls(wr) + } + fn process_new_packets(&mut self) -> Result<(), TLSError> { self.imp.process_new_packets() } diff --git a/src/session.rs b/src/session.rs index fc8be9b3..90075647 100644 --- a/src/session.rs +++ b/src/session.rs @@ -11,7 +11,7 @@ use msgs::enums::KeyUpdateRequest; use error::TLSError; use suites::SupportedCipherSuite; use cipher::{MessageDecrypter, MessageEncrypter, self}; -use vecbuf::ChunkVecBuffer; +use vecbuf::{ChunkVecBuffer, WriteV}; use key; use key_schedule::{SecretKind, KeySchedule}; use prf; @@ -51,6 +51,11 @@ pub trait Session: quic::QuicExt + Read + Write + Send + Sync { /// [`wants_write`]: #tymethod.wants_write fn write_tls(&mut self, wr: &mut Write) -> Result; + /// Like `write_tls`, but writes potentially many records in one + /// go via `wr`; a `rustls::WriteV`. This function has the same semantics + /// as `write_tls` otherwise. + fn writev_tls(&mut self, wr: &mut WriteV) -> Result; + /// Processes any new packets read by a previous call to `read_tls`. /// Errors from this function relate to TLS protocol errors, and /// are fatal to the session. Future calls after an error will do @@ -638,6 +643,10 @@ impl SessionCommon { self.sendable_tls.write_to(wr) } + pub fn writev_tls(&mut self, wr: &mut WriteV) -> io::Result { + self.sendable_tls.writev_to(wr) + } + /// Send plaintext application data, fragmenting and /// encrypting it as it goes out. /// diff --git a/src/vecbuf.rs b/src/vecbuf.rs index e48add57..21d466ef 100644 --- a/src/vecbuf.rs +++ b/src/vecbuf.rs @@ -3,6 +3,25 @@ use std::io; use std::cmp; use std::collections::VecDeque; +/// This trait specifies rustls's precise requirements doing writes with +/// vectored IO. +/// +/// The purpose of vectored IO is to pass contigious output in many blocks +/// to the kernel without either coalescing it in user-mode (by allocating +/// and copying) or making many system calls. +/// +/// We don't directly use types from the vecio crate because the traits +/// don't compose well: the most useful trait (`Rawv`) is hard to test +/// with (it can't be implemented without an FD) and implies a readable +/// source too. You will have to write a trivial adaptor struct which +/// glues either `vecio::Rawv` or `vecio::Writev` to this trait. See +/// the rustls examples. +pub trait WriteV { + /// Writes as much data from `vbytes` as possible, returning + /// the number of bytes written. + fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result; +} + /// This is a byte buffer that is built from a vector /// of byte vectors. This avoids extra copies when /// appending a new byte vector, at the expense of @@ -87,33 +106,49 @@ impl ChunkVecBuffer { while offs < buf.len() && !self.is_empty() { let used = self.chunks[0].as_slice().read(&mut buf[offs..])?; - if used == self.chunks[0].len() { - self.take_one(); - } else { - self.chunks[0] = self.chunks[0].split_off(used); - } - + self.consume(used); offs += used; } Ok(offs) } - /// Read data of this object, passing it `wr` + fn consume(&mut self, mut used: usize) { + while used > 0 && !self.is_empty() { + if used >= self.chunks[0].len() { + used -= self.chunks[0].len(); + self.take_one(); + } else { + self.chunks[0] = self.chunks[0].split_off(used); + used = 0; + } + } + } + + /// Read data out of this object, passing it `wr` pub fn write_to(&mut self, wr: &mut io::Write) -> io::Result { - // would desperately like writev support here! if self.is_empty() { return Ok(0); } let used = wr.write(&self.chunks[0])?; + self.consume(used); + Ok(used) + } - if used == self.chunks[0].len() { - self.take_one(); - } else { - self.chunks[0] = self.chunks[0].split_off(used); + pub fn writev_to(&mut self, wr: &mut WriteV) -> io::Result { + if self.is_empty() { + return Ok(0); } + let used = { + let chunks = self.chunks.iter() + .map(|ch| ch.as_ref()) + .collect::>(); + + wr.writev(&chunks)? + }; + self.consume(used); Ok(used) } } diff --git a/tests/api.rs b/tests/api.rs index e73052ad..e5ef493b 100644 --- a/tests/api.rs +++ b/tests/api.rs @@ -692,17 +692,29 @@ struct OtherSession<'a> { sess: &'a mut Session, pub reads: usize, pub writes: usize, + pub writevs: Vec>, fail_ok: bool, + pub short_writes: bool, pub last_error: Option, } impl<'a> OtherSession<'a> { fn new(sess: &'a mut Session) -> OtherSession<'a> { - OtherSession { sess, reads: 0, writes: 0, fail_ok: false, last_error: None, } + OtherSession { + sess, + reads: 0, + writes: 0, + writevs: vec![], + fail_ok: false, + short_writes: false, + last_error: None, + } } fn new_fails(sess: &'a mut Session) -> OtherSession<'a> { - OtherSession { sess, reads: 0, writes: 0, fail_ok: true, last_error: None, } + let mut os = OtherSession::new(sess); + os.fail_ok = true; + os } } @@ -733,6 +745,37 @@ impl<'a> io::Write for OtherSession<'a> { } } +impl<'a> rustls::WriteV for OtherSession<'a> { + fn writev(&mut self, b: &[&[u8]]) -> io::Result { + let mut total = 0; + let mut lengths = vec![]; + for bytes in b { + let write_len = if self.short_writes { + if bytes.len() > 5 { bytes.len() / 2 } else { bytes.len() } + } else { + bytes.len() + }; + + let l = self.sess.read_tls(&mut io::Cursor::new(&bytes[..write_len]))?; + lengths.push(l); + total += l; + if bytes.len() != l { + break; + } + } + + let rc = self.sess.process_new_packets(); + if !self.fail_ok { + rc.unwrap(); + } else if rc.is_err() { + self.last_error = rc.err(); + } + + self.writevs.push(lengths); + Ok(total) + } +} + #[test] fn client_complete_io_for_handshake() { let mut client = ClientSession::new(&Arc::new(make_client_config()), dns_name("localhost")); @@ -1356,3 +1399,130 @@ fn key_log_for_tls13() { assert_eq!(client_resume_log[3], server_resume_log[3]); assert_eq!(client_resume_log[4], server_resume_log[4]); } + +#[test] +fn vectored_write_for_server_appdata() { + let mut client = ClientSession::new(&Arc::new(make_client_config()), dns_name("localhost")); + let mut server = ServerSession::new(&Arc::new(make_server_config())); + + do_handshake(&mut client, &mut server); + + server.write(b"01234567890123456789").unwrap(); + server.write(b"01234567890123456789").unwrap(); + { + let mut pipe = OtherSession::new(&mut client); + let wrlen = server.writev_tls(&mut pipe).unwrap(); + assert_eq!(84, wrlen); + assert_eq!(pipe.writevs, vec![vec![42, 42]]); + } + check_read(&mut client, b"0123456789012345678901234567890123456789"); +} + +#[test] +fn vectored_write_for_client_appdata() { + let mut client = ClientSession::new(&Arc::new(make_client_config()), dns_name("localhost")); + let mut server = ServerSession::new(&Arc::new(make_server_config())); + + do_handshake(&mut client, &mut server); + + client.write(b"01234567890123456789").unwrap(); + client.write(b"01234567890123456789").unwrap(); + { + let mut pipe = OtherSession::new(&mut server); + let wrlen = client.writev_tls(&mut pipe).unwrap(); + assert_eq!(84, wrlen); + assert_eq!(pipe.writevs, vec![vec![42, 42]]); + } + check_read(&mut server, b"0123456789012345678901234567890123456789"); +} + +#[test] +fn vectored_write_for_server_handshake() { + let mut client = ClientSession::new(&Arc::new(make_client_config()), dns_name("localhost")); + let mut server = ServerSession::new(&Arc::new(make_server_config())); + + server.write(b"01234567890123456789").unwrap(); + server.write(b"0123456789").unwrap(); + + transfer(&mut client, &mut server); + server.process_new_packets().unwrap(); + { + let mut pipe = OtherSession::new(&mut client); + let wrlen = server.writev_tls(&mut pipe).unwrap(); + // don't assert exact sizes here, to avoid a brittle test + assert!(wrlen > 5000); // its pretty big (contains cert chain) + assert_eq!(pipe.writevs.len(), 1); // only one writev + assert!(pipe.writevs[0].len() > 3); // at least a server hello/cert/serverkx + } + + client.process_new_packets().unwrap(); + transfer(&mut client, &mut server); + server.process_new_packets().unwrap(); + { + let mut pipe = OtherSession::new(&mut client); + let wrlen = server.writev_tls(&mut pipe).unwrap(); + assert_eq!(wrlen, 74); + assert_eq!(pipe.writevs, vec![vec![42, 32]]); + } + + assert_eq!(server.is_handshaking(), false); + assert_eq!(client.is_handshaking(), false); + check_read(&mut client, b"012345678901234567890123456789"); +} + +#[test] +fn vectored_write_for_client_handshake() { + let mut client = ClientSession::new(&Arc::new(make_client_config()), dns_name("localhost")); + let mut server = ServerSession::new(&Arc::new(make_server_config())); + + client.write(b"01234567890123456789").unwrap(); + client.write(b"0123456789").unwrap(); + { + let mut pipe = OtherSession::new(&mut server); + let wrlen = client.writev_tls(&mut pipe).unwrap(); + // don't assert exact sizes here, to avoid a brittle test + assert!(wrlen > 200); // just the client hello + assert_eq!(pipe.writevs.len(), 1); // only one writev + assert!(pipe.writevs[0].len() == 1); // only a client hello + } + + transfer(&mut server, &mut client); + client.process_new_packets().unwrap(); + + { + let mut pipe = OtherSession::new(&mut server); + let wrlen = client.writev_tls(&mut pipe).unwrap(); + assert_eq!(wrlen, 138); + // CCS, finished, then two application datas + assert_eq!(pipe.writevs, vec![vec![6, 58, 42, 32]]); + } + + assert_eq!(server.is_handshaking(), false); + assert_eq!(client.is_handshaking(), false); + check_read(&mut server, b"012345678901234567890123456789"); +} + +#[test] +fn vectored_write_with_slow_client() { + let mut client = ClientSession::new(&Arc::new(make_client_config()), dns_name("localhost")); + let mut server = ServerSession::new(&Arc::new(make_server_config())); + + client.set_buffer_limit(32); + + do_handshake(&mut client, &mut server); + server.write(b"01234567890123456789").unwrap(); + + { + let mut pipe = OtherSession::new(&mut client); + pipe.short_writes = true; + let wrlen = server.writev_tls(&mut pipe).unwrap() + + server.writev_tls(&mut pipe).unwrap() + + server.writev_tls(&mut pipe).unwrap() + + server.writev_tls(&mut pipe).unwrap() + + server.writev_tls(&mut pipe).unwrap() + + server.writev_tls(&mut pipe).unwrap(); + assert_eq!(42, wrlen); + assert_eq!(pipe.writevs, vec![vec![21], vec![10], vec![5], vec![3], vec![3]]); + } + check_read(&mut client, b"01234567890123456789"); +}