Add support for vectored IO

This is abstract: behind a trivial rustls-specific trait so it
can be tested and doesn't rely on implementation details of vecio.
This commit is contained in:
Joseph Birr-Pixton 2018-05-28 19:05:16 +01:00
parent c9099c5b26
commit ae70e4a9e1
11 changed files with 274 additions and 18 deletions

View File

@ -33,6 +33,7 @@ serde_derive = "1.0"
webpki-roots = "0.14.0"
ct-logs = "0.3"
regex = "0.2"
vecio = "0.1"
[[example]]
name = "bogo_shim"

View File

@ -18,6 +18,8 @@ Rustls is currently in development and hence unstable. [Here's what I'm working
- Add support for `SSLKEYLOGFILE`; not enabled by default.
- Add support for basic usage in QUIC.
- `ServerConfig::set_single_cert` and company now report errors.
- Add support for vectored IO: `writev_tls` can now be used to
optimise system call usage.
* 0.12.0 (2018-01-06):
- New API for learning negotiated cipher suite.
- Move TLS1.3 support from draft 18 to 22.

View File

@ -22,6 +22,10 @@ extern crate rustls;
extern crate webpki;
extern crate webpki_roots;
extern crate ct_logs;
extern crate vecio;
mod util;
use util::WriteVAdapter;
use rustls::Session;
@ -144,7 +148,7 @@ impl TlsClient {
}
fn do_write(&mut self) {
self.tls_session.write_tls(&mut self.socket).unwrap();
self.tls_session.writev_tls(&mut WriteVAdapter::new(&mut self.socket)).unwrap();
}
fn register(&self, poll: &mut mio::Poll) {

View File

@ -18,12 +18,15 @@ extern crate docopt;
use docopt::Docopt;
extern crate env_logger;
extern crate vecio;
extern crate rustls;
use rustls::{RootCertStore, Session, NoClientAuth, AllowAnyAuthenticatedClient,
AllowAnyAnonymousOrAuthenticatedClient};
mod util;
use util::WriteVAdapter;
// Token for our listening socket.
const LISTENER: mio::Token = mio::Token(0);
@ -301,7 +304,7 @@ impl Connection {
}
fn do_tls_write(&mut self) {
let rc = self.tls_session.write_tls(&mut self.socket);
let rc = self.tls_session.writev_tls(&mut WriteVAdapter::new(&mut self.socket));
if rc.is_err() {
error!("write failed {:?}", rc);
self.closing = true;

20
examples/util/mod.rs Normal file
View File

@ -0,0 +1,20 @@
use std::io;
use vecio::Rawv;
use rustls;
/// This glues our `rustls::WriteV` trait to `vecio::Rawv`.
pub struct WriteVAdapter<'a> {
rawv: &'a mut Rawv
}
impl<'a> WriteVAdapter<'a> {
pub fn new(rawv: &'a mut Rawv) -> WriteVAdapter<'a> {
WriteVAdapter { rawv }
}
}
impl<'a> rustls::WriteV for WriteVAdapter<'a> {
fn writev(&mut self, bytes: &[&[u8]]) -> io::Result<usize> {
self.rawv.writev(bytes)
}
}

View File

@ -13,6 +13,7 @@ use anchors;
use sign;
use error::TLSError;
use key;
use vecbuf::WriteV;
use std::sync::Arc;
use std::io;
@ -463,6 +464,10 @@ impl Session for ClientSession {
self.imp.common.write_tls(wr)
}
fn writev_tls(&mut self, wr: &mut WriteV) -> io::Result<usize> {
self.imp.common.writev_tls(wr)
}
fn process_new_packets(&mut self) -> Result<(), TLSError> {
self.imp.process_new_packets()
}

View File

@ -301,6 +301,7 @@ pub use verify::{NoClientAuth, AllowAnyAuthenticatedClient,
pub use suites::{ALL_CIPHERSUITES, SupportedCipherSuite};
pub use key::{Certificate, PrivateKey};
pub use keylog::{KeyLog, NoKeyLog, KeyLogFile};
pub use vecbuf::WriteV;
/// Message signing interfaces and implementations.
pub mod sign;

View File

@ -9,6 +9,8 @@ use error::TLSError;
use sign;
use verify;
use key;
use vecbuf::WriteV;
use webpki;
use std::sync::Arc;
@ -456,6 +458,10 @@ impl Session for ServerSession {
self.imp.common.write_tls(wr)
}
fn writev_tls(&mut self, wr: &mut WriteV) -> io::Result<usize> {
self.imp.common.writev_tls(wr)
}
fn process_new_packets(&mut self) -> Result<(), TLSError> {
self.imp.process_new_packets()
}

View File

@ -11,7 +11,7 @@ use msgs::enums::KeyUpdateRequest;
use error::TLSError;
use suites::SupportedCipherSuite;
use cipher::{MessageDecrypter, MessageEncrypter, self};
use vecbuf::ChunkVecBuffer;
use vecbuf::{ChunkVecBuffer, WriteV};
use key;
use key_schedule::{SecretKind, KeySchedule};
use prf;
@ -51,6 +51,11 @@ pub trait Session: quic::QuicExt + Read + Write + Send + Sync {
/// [`wants_write`]: #tymethod.wants_write
fn write_tls(&mut self, wr: &mut Write) -> Result<usize, io::Error>;
/// Like `write_tls`, but writes potentially many records in one
/// go via `wr`; a `rustls::WriteV`. This function has the same semantics
/// as `write_tls` otherwise.
fn writev_tls(&mut self, wr: &mut WriteV) -> Result<usize, io::Error>;
/// Processes any new packets read by a previous call to `read_tls`.
/// Errors from this function relate to TLS protocol errors, and
/// are fatal to the session. Future calls after an error will do
@ -638,6 +643,10 @@ impl SessionCommon {
self.sendable_tls.write_to(wr)
}
pub fn writev_tls(&mut self, wr: &mut WriteV) -> io::Result<usize> {
self.sendable_tls.writev_to(wr)
}
/// Send plaintext application data, fragmenting and
/// encrypting it as it goes out.
///

View File

@ -3,6 +3,25 @@ use std::io;
use std::cmp;
use std::collections::VecDeque;
/// This trait specifies rustls's precise requirements doing writes with
/// vectored IO.
///
/// The purpose of vectored IO is to pass contigious output in many blocks
/// to the kernel without either coalescing it in user-mode (by allocating
/// and copying) or making many system calls.
///
/// We don't directly use types from the vecio crate because the traits
/// don't compose well: the most useful trait (`Rawv`) is hard to test
/// with (it can't be implemented without an FD) and implies a readable
/// source too. You will have to write a trivial adaptor struct which
/// glues either `vecio::Rawv` or `vecio::Writev` to this trait. See
/// the rustls examples.
pub trait WriteV {
/// Writes as much data from `vbytes` as possible, returning
/// the number of bytes written.
fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result<usize>;
}
/// This is a byte buffer that is built from a vector
/// of byte vectors. This avoids extra copies when
/// appending a new byte vector, at the expense of
@ -87,33 +106,49 @@ impl ChunkVecBuffer {
while offs < buf.len() && !self.is_empty() {
let used = self.chunks[0].as_slice().read(&mut buf[offs..])?;
if used == self.chunks[0].len() {
self.take_one();
} else {
self.chunks[0] = self.chunks[0].split_off(used);
}
self.consume(used);
offs += used;
}
Ok(offs)
}
/// Read data of this object, passing it `wr`
fn consume(&mut self, mut used: usize) {
while used > 0 && !self.is_empty() {
if used >= self.chunks[0].len() {
used -= self.chunks[0].len();
self.take_one();
} else {
self.chunks[0] = self.chunks[0].split_off(used);
used = 0;
}
}
}
/// Read data out of this object, passing it `wr`
pub fn write_to(&mut self, wr: &mut io::Write) -> io::Result<usize> {
// would desperately like writev support here!
if self.is_empty() {
return Ok(0);
}
let used = wr.write(&self.chunks[0])?;
self.consume(used);
Ok(used)
}
if used == self.chunks[0].len() {
self.take_one();
} else {
self.chunks[0] = self.chunks[0].split_off(used);
pub fn writev_to(&mut self, wr: &mut WriteV) -> io::Result<usize> {
if self.is_empty() {
return Ok(0);
}
let used = {
let chunks = self.chunks.iter()
.map(|ch| ch.as_ref())
.collect::<Vec<&[u8]>>();
wr.writev(&chunks)?
};
self.consume(used);
Ok(used)
}
}

View File

@ -692,17 +692,29 @@ struct OtherSession<'a> {
sess: &'a mut Session,
pub reads: usize,
pub writes: usize,
pub writevs: Vec<Vec<usize>>,
fail_ok: bool,
pub short_writes: bool,
pub last_error: Option<rustls::TLSError>,
}
impl<'a> OtherSession<'a> {
fn new(sess: &'a mut Session) -> OtherSession<'a> {
OtherSession { sess, reads: 0, writes: 0, fail_ok: false, last_error: None, }
OtherSession {
sess,
reads: 0,
writes: 0,
writevs: vec![],
fail_ok: false,
short_writes: false,
last_error: None,
}
}
fn new_fails(sess: &'a mut Session) -> OtherSession<'a> {
OtherSession { sess, reads: 0, writes: 0, fail_ok: true, last_error: None, }
let mut os = OtherSession::new(sess);
os.fail_ok = true;
os
}
}
@ -733,6 +745,37 @@ impl<'a> io::Write for OtherSession<'a> {
}
}
impl<'a> rustls::WriteV for OtherSession<'a> {
fn writev(&mut self, b: &[&[u8]]) -> io::Result<usize> {
let mut total = 0;
let mut lengths = vec![];
for bytes in b {
let write_len = if self.short_writes {
if bytes.len() > 5 { bytes.len() / 2 } else { bytes.len() }
} else {
bytes.len()
};
let l = self.sess.read_tls(&mut io::Cursor::new(&bytes[..write_len]))?;
lengths.push(l);
total += l;
if bytes.len() != l {
break;
}
}
let rc = self.sess.process_new_packets();
if !self.fail_ok {
rc.unwrap();
} else if rc.is_err() {
self.last_error = rc.err();
}
self.writevs.push(lengths);
Ok(total)
}
}
#[test]
fn client_complete_io_for_handshake() {
let mut client = ClientSession::new(&Arc::new(make_client_config()), dns_name("localhost"));
@ -1356,3 +1399,130 @@ fn key_log_for_tls13() {
assert_eq!(client_resume_log[3], server_resume_log[3]);
assert_eq!(client_resume_log[4], server_resume_log[4]);
}
#[test]
fn vectored_write_for_server_appdata() {
let mut client = ClientSession::new(&Arc::new(make_client_config()), dns_name("localhost"));
let mut server = ServerSession::new(&Arc::new(make_server_config()));
do_handshake(&mut client, &mut server);
server.write(b"01234567890123456789").unwrap();
server.write(b"01234567890123456789").unwrap();
{
let mut pipe = OtherSession::new(&mut client);
let wrlen = server.writev_tls(&mut pipe).unwrap();
assert_eq!(84, wrlen);
assert_eq!(pipe.writevs, vec![vec![42, 42]]);
}
check_read(&mut client, b"0123456789012345678901234567890123456789");
}
#[test]
fn vectored_write_for_client_appdata() {
let mut client = ClientSession::new(&Arc::new(make_client_config()), dns_name("localhost"));
let mut server = ServerSession::new(&Arc::new(make_server_config()));
do_handshake(&mut client, &mut server);
client.write(b"01234567890123456789").unwrap();
client.write(b"01234567890123456789").unwrap();
{
let mut pipe = OtherSession::new(&mut server);
let wrlen = client.writev_tls(&mut pipe).unwrap();
assert_eq!(84, wrlen);
assert_eq!(pipe.writevs, vec![vec![42, 42]]);
}
check_read(&mut server, b"0123456789012345678901234567890123456789");
}
#[test]
fn vectored_write_for_server_handshake() {
let mut client = ClientSession::new(&Arc::new(make_client_config()), dns_name("localhost"));
let mut server = ServerSession::new(&Arc::new(make_server_config()));
server.write(b"01234567890123456789").unwrap();
server.write(b"0123456789").unwrap();
transfer(&mut client, &mut server);
server.process_new_packets().unwrap();
{
let mut pipe = OtherSession::new(&mut client);
let wrlen = server.writev_tls(&mut pipe).unwrap();
// don't assert exact sizes here, to avoid a brittle test
assert!(wrlen > 5000); // its pretty big (contains cert chain)
assert_eq!(pipe.writevs.len(), 1); // only one writev
assert!(pipe.writevs[0].len() > 3); // at least a server hello/cert/serverkx
}
client.process_new_packets().unwrap();
transfer(&mut client, &mut server);
server.process_new_packets().unwrap();
{
let mut pipe = OtherSession::new(&mut client);
let wrlen = server.writev_tls(&mut pipe).unwrap();
assert_eq!(wrlen, 74);
assert_eq!(pipe.writevs, vec![vec![42, 32]]);
}
assert_eq!(server.is_handshaking(), false);
assert_eq!(client.is_handshaking(), false);
check_read(&mut client, b"012345678901234567890123456789");
}
#[test]
fn vectored_write_for_client_handshake() {
let mut client = ClientSession::new(&Arc::new(make_client_config()), dns_name("localhost"));
let mut server = ServerSession::new(&Arc::new(make_server_config()));
client.write(b"01234567890123456789").unwrap();
client.write(b"0123456789").unwrap();
{
let mut pipe = OtherSession::new(&mut server);
let wrlen = client.writev_tls(&mut pipe).unwrap();
// don't assert exact sizes here, to avoid a brittle test
assert!(wrlen > 200); // just the client hello
assert_eq!(pipe.writevs.len(), 1); // only one writev
assert!(pipe.writevs[0].len() == 1); // only a client hello
}
transfer(&mut server, &mut client);
client.process_new_packets().unwrap();
{
let mut pipe = OtherSession::new(&mut server);
let wrlen = client.writev_tls(&mut pipe).unwrap();
assert_eq!(wrlen, 138);
// CCS, finished, then two application datas
assert_eq!(pipe.writevs, vec![vec![6, 58, 42, 32]]);
}
assert_eq!(server.is_handshaking(), false);
assert_eq!(client.is_handshaking(), false);
check_read(&mut server, b"012345678901234567890123456789");
}
#[test]
fn vectored_write_with_slow_client() {
let mut client = ClientSession::new(&Arc::new(make_client_config()), dns_name("localhost"));
let mut server = ServerSession::new(&Arc::new(make_server_config()));
client.set_buffer_limit(32);
do_handshake(&mut client, &mut server);
server.write(b"01234567890123456789").unwrap();
{
let mut pipe = OtherSession::new(&mut client);
pipe.short_writes = true;
let wrlen = server.writev_tls(&mut pipe).unwrap() +
server.writev_tls(&mut pipe).unwrap() +
server.writev_tls(&mut pipe).unwrap() +
server.writev_tls(&mut pipe).unwrap() +
server.writev_tls(&mut pipe).unwrap() +
server.writev_tls(&mut pipe).unwrap();
assert_eq!(42, wrlen);
assert_eq!(pipe.writevs, vec![vec![21], vec![10], vec![5], vec![3], vec![3]]);
}
check_read(&mut client, b"01234567890123456789");
}