Remove writev_tls; use std::io::Write::write_vectored

This is just a whole lot better.
This commit is contained in:
Joseph Birr-Pixton 2020-05-15 21:39:22 +01:00
parent 2912dbffde
commit e680b3b6c7
11 changed files with 81 additions and 166 deletions

View File

@ -28,6 +28,8 @@ If you'd like to help out, please see [CONTRIBUTING.md](CONTRIBUTING.md).
- Performance improvements. Thanks to @nviennot.
- Fixed client authentication being unduly rejected by client when server
uses the superseded certificate_types field of CertificateRequest.
- *Breaking API change*: The writev_tls API has been removed, in favour
of using vectored IO support now offered by std::io::Write.
* 0.17.0 (2020-02-22):
- *Breaking API change*: ALPN protocols offered by the client are passed
to the server certificate resolution trait (`ResolvesServerCert`).

View File

@ -28,7 +28,6 @@ regex = "1.0"
serde = "1.0"
serde_derive = "1.0"
tempfile = "3.0"
vecio = "0.1"
webpki-roots = "0.19"
[[example]]

View File

@ -23,9 +23,6 @@ use webpki;
use webpki_roots;
use ct_logs;
mod util;
use rustls::Session;
const CLIENT: mio::Token = mio::Token(0);
@ -146,17 +143,10 @@ impl TlsClient {
}
}
#[cfg(target_os = "windows")]
fn do_write(&mut self) {
self.tls_session.write_tls(&mut self.socket).unwrap();
}
#[cfg(not(target_os = "windows"))]
fn do_write(&mut self) {
use crate::util::WriteVAdapter;
self.tls_session.writev_tls(&mut WriteVAdapter::new(&mut self.socket)).unwrap();
}
fn register(&mut self, registry: &mio::Registry) {
let interest = self.ready_interest();
registry.register(&mut self.socket, CLIENT, interest).unwrap();

View File

@ -24,8 +24,6 @@ use rustls;
use rustls::{RootCertStore, Session, NoClientAuth, AllowAnyAuthenticatedClient,
AllowAnyAnonymousOrAuthenticatedClient};
mod util;
// Token for our listening socket.
const LISTENER: mio::Token = mio::Token(0);
@ -310,17 +308,10 @@ impl Connection {
}
}
#[cfg(target_os = "windows")]
fn tls_write(&mut self) -> io::Result<usize> {
self.tls_session.write_tls(&mut self.socket)
}
#[cfg(not(target_os = "windows"))]
fn tls_write(&mut self) -> io::Result<usize> {
use crate::util::WriteVAdapter;
self.tls_session.writev_tls(&mut WriteVAdapter::new(&mut self.socket))
}
fn do_tls_write_and_handle_error(&mut self) {
let rc = self.tls_write();
if rc.is_err() {

View File

@ -1,20 +0,0 @@
use std::io;
use vecio::Rawv;
use rustls;
/// This glues our `rustls::WriteV` trait to `vecio::Rawv`.
pub struct WriteVAdapter<'a> {
rawv: &'a mut dyn Rawv
}
impl<'a> WriteVAdapter<'a> {
pub fn new(rawv: &'a mut dyn Rawv) -> WriteVAdapter<'a> {
WriteVAdapter { rawv }
}
}
impl<'a> rustls::WriteV for WriteVAdapter<'a> {
fn writev(&mut self, bytes: &[&[u8]]) -> io::Result<usize> {
self.rawv.writev(bytes)
}
}

View File

@ -13,7 +13,6 @@ use crate::anchors;
use crate::sign;
use crate::error::TLSError;
use crate::key;
use crate::vecbuf::WriteV;
#[cfg(feature = "logging")]
use crate::log::trace;
@ -666,10 +665,6 @@ impl Session for ClientSession {
self.imp.common.write_tls(wr)
}
fn writev_tls(&mut self, wr: &mut dyn WriteV) -> io::Result<usize> {
self.imp.common.writev_tls(wr)
}
fn process_new_packets(&mut self) -> Result<(), TLSError> {
self.imp.process_new_packets()
}

View File

@ -280,7 +280,6 @@ pub use crate::verify::{NoClientAuth, AllowAnyAuthenticatedClient,
pub use crate::suites::{ALL_CIPHERSUITES, BulkAlgorithm, SupportedCipherSuite};
pub use crate::key::{Certificate, PrivateKey};
pub use crate::keylog::{KeyLog, NoKeyLog, KeyLogFile};
pub use crate::vecbuf::{WriteV, WriteVAdapter};
/// Message signing interfaces and implementations.
pub mod sign;

View File

@ -10,7 +10,6 @@ use crate::error::TLSError;
use crate::sign;
use crate::verify;
use crate::key;
use crate::vecbuf::WriteV;
#[cfg(feature = "logging")]
use crate::log::trace;
@ -577,10 +576,6 @@ impl Session for ServerSession {
self.imp.common.write_tls(wr)
}
fn writev_tls(&mut self, wr: &mut dyn WriteV) -> io::Result<usize> {
self.imp.common.writev_tls(wr)
}
fn process_new_packets(&mut self) -> Result<(), TLSError> {
self.imp.process_new_packets()
}

View File

@ -10,7 +10,7 @@ use crate::msgs::enums::{ContentType, ProtocolVersion, AlertDescription, AlertLe
use crate::error::TLSError;
use crate::suites::SupportedCipherSuite;
use crate::cipher;
use crate::vecbuf::{ChunkVecBuffer, WriteV};
use crate::vecbuf::ChunkVecBuffer;
use crate::key;
use crate::prf;
use crate::rand;
@ -52,11 +52,6 @@ pub trait Session: quic::QuicExt + Read + Write + Send + Sync {
/// [`wants_write`]: #tymethod.wants_write
fn write_tls(&mut self, wr: &mut dyn Write) -> Result<usize, io::Error>;
/// Like `write_tls`, but writes potentially many records in one
/// go via `wr`; a `rustls::WriteV`. This function has the same semantics
/// as `write_tls` otherwise.
fn writev_tls(&mut self, wr: &mut dyn WriteV) -> Result<usize, io::Error>;
/// Processes any new packets read by a previous call to `read_tls`.
/// Errors from this function relate to TLS protocol errors, and
/// are fatal to the session. Future calls after an error will do
@ -591,10 +586,6 @@ impl SessionCommon {
self.sendable_tls.write_to(wr)
}
pub fn writev_tls(&mut self, wr: &mut dyn WriteV) -> io::Result<usize> {
self.sendable_tls.writev_to(wr)
}
/// Send plaintext application data, fragmenting and
/// encrypting it as it goes out.
///

View File

@ -2,26 +2,6 @@ use std::io::Read;
use std::io;
use std::cmp;
use std::collections::VecDeque;
use std::convert;
/// This trait specifies rustls's precise requirements doing writes with
/// vectored IO.
///
/// The purpose of vectored IO is to pass contigious output in many blocks
/// to the kernel without either coalescing it in user-mode (by allocating
/// and copying) or making many system calls.
///
/// We don't directly use types from the vecio crate because the traits
/// don't compose well: the most useful trait (`Rawv`) is hard to test
/// with (it can't be implemented without an FD) and implies a readable
/// source too. You will have to write a trivial adaptor struct which
/// glues either `vecio::Rawv` or `vecio::Writev` to this trait. See
/// the rustls examples.
pub trait WriteV {
/// Writes as much data from `vbytes` as possible, returning
/// the number of bytes written.
fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result<usize>;
}
/// This is a byte buffer that is built from a vector
/// of byte vectors. This avoids extra copies when
@ -132,50 +112,12 @@ impl ChunkVecBuffer {
return Ok(0);
}
let used = wr.write(&self.chunks[0])?;
let used = wr.write_vectored(&self.chunks.iter()
.map(|ch| io::IoSlice::new(ch))
.collect::<Vec<io::IoSlice>>())?;
self.consume(used);
Ok(used)
}
pub fn writev_to(&mut self, wr: &mut dyn WriteV) -> io::Result<usize> {
if self.is_empty() {
return Ok(0);
}
let used = {
let chunks = self.chunks.iter()
.map(convert::AsRef::as_ref)
.collect::<Vec<&[u8]>>();
wr.writev(&chunks)?
};
self.consume(used);
Ok(used)
}
}
/// This is a simple wrapper around an object
/// which implements `std::io::Write` in order to autoimplement `WriteV`.
/// It uses the `write_vectored` method from `std::io::Write` in order
/// to do this.
pub struct WriteVAdapter<T: io::Write>(T);
impl<T: io::Write> WriteVAdapter<T> {
/// build an adapter from a Write object
pub fn new(inner: T) -> Self {
WriteVAdapter(inner)
}
}
impl<T: io::Write> WriteV for WriteVAdapter<T> {
fn writev(&mut self, buffers: &[&[u8]]) -> io::Result<usize> {
self.0.write_vectored(
&buffers
.iter()
.map(|b| io::IoSlice::new(b))
.collect::<Vec<io::IoSlice>>(),
)
}
}
#[cfg(test)]

View File

@ -6,7 +6,7 @@ use std::mem;
use std::fmt;
use std::env;
use std::error::Error;
use std::io::{self, Write, Read};
use std::io::{self, Write, Read, IoSlice};
use rustls;
@ -842,6 +842,23 @@ fn server_respects_buffer_limit_pre_handshake() {
check_read(&mut client, b"01234567890123456789012345678901");
}
#[test]
fn server_respects_buffer_limit_pre_handshake_with_vectored_write() {
let (mut client, mut server) = make_pair(KeyType::RSA);
server.set_buffer_limit(32);
assert_eq!(server.write_vectored(&[IoSlice::new(b"01234567890123456789"),
IoSlice::new(b"01234567890123456789")]).unwrap(),
32);
do_handshake(&mut client, &mut server);
transfer(&mut server, &mut client);
client.process_new_packets().unwrap();
check_read(&mut client, b"01234567890123456789012345678901");
}
#[test]
fn server_respects_buffer_limit_post_handshake() {
let (mut client, mut server) = make_pair(KeyType::RSA);
@ -875,6 +892,23 @@ fn client_respects_buffer_limit_pre_handshake() {
check_read(&mut server, b"01234567890123456789012345678901");
}
#[test]
fn client_respects_buffer_limit_pre_handshake_with_vectored_write() {
let (mut client, mut server) = make_pair(KeyType::RSA);
client.set_buffer_limit(32);
assert_eq!(client.write_vectored(&[IoSlice::new(b"01234567890123456789"),
IoSlice::new(b"01234567890123456789")]).unwrap(),
32);
do_handshake(&mut client, &mut server);
transfer(&mut client, &mut server);
server.process_new_packets().unwrap();
check_read(&mut server, b"01234567890123456789012345678901");
}
#[test]
fn client_respects_buffer_limit_post_handshake() {
let (mut client, mut server) = make_pair(KeyType::RSA);
@ -894,7 +928,6 @@ fn client_respects_buffer_limit_post_handshake() {
struct OtherSession<'a> {
sess: &'a mut dyn Session,
pub reads: usize,
pub writes: usize,
pub writevs: Vec<Vec<usize>>,
fail_ok: bool,
pub short_writes: bool,
@ -906,7 +939,6 @@ impl<'a> OtherSession<'a> {
OtherSession {
sess,
reads: 0,
writes: 0,
writevs: vec![],
fail_ok: false,
short_writes: false,
@ -929,27 +961,15 @@ impl<'a> io::Read for OtherSession<'a> {
}
impl<'a> io::Write for OtherSession<'a> {
fn write(&mut self, mut b: &[u8]) -> io::Result<usize> {
self.writes += 1;
let l = self.sess.read_tls(b.by_ref())?;
let rc = self.sess.process_new_packets();
if !self.fail_ok {
rc.unwrap();
} else if rc.is_err() {
self.last_error = rc.err();
}
Ok(l)
fn write(&mut self, _: &[u8]) -> io::Result<usize> {
unreachable!()
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl<'a> rustls::WriteV for OtherSession<'a> {
fn writev(&mut self, b: &[&[u8]]) -> io::Result<usize> {
fn write_vectored<'b>(&mut self, b: &[io::IoSlice<'b>]) -> io::Result<usize> {
let mut total = 0;
let mut lengths = vec![];
for bytes in b {
@ -1012,7 +1032,8 @@ fn client_complete_io_for_write() {
let mut pipe = OtherSession::new(&mut server);
let (rdlen, wrlen) = client.complete_io(&mut pipe).unwrap();
assert!(rdlen == 0 && wrlen > 0);
assert_eq!(pipe.writes, 2);
println!("{:?}", pipe.writevs);
assert_eq!(pipe.writevs, vec![ vec![ 42, 42 ] ]);
}
check_read(&mut server, b"0123456789012345678901234567890123456789");
}
@ -1071,7 +1092,7 @@ fn server_complete_io_for_write() {
let mut pipe = OtherSession::new(&mut client);
let (rdlen, wrlen) = server.complete_io(&mut pipe).unwrap();
assert!(rdlen == 0 && wrlen > 0);
assert_eq!(pipe.writes, 2);
assert_eq!(pipe.writevs, vec![ vec![ 42, 42 ] ]);
}
check_read(&mut client, b"0123456789012345678901234567890123456789");
}
@ -1774,7 +1795,7 @@ fn vectored_write_for_server_appdata() {
server.write(b"01234567890123456789").unwrap();
{
let mut pipe = OtherSession::new(&mut client);
let wrlen = server.writev_tls(&mut pipe).unwrap();
let wrlen = server.write_tls(&mut pipe).unwrap();
assert_eq!(84, wrlen);
assert_eq!(pipe.writevs, vec![vec![42, 42]]);
}
@ -1790,7 +1811,7 @@ fn vectored_write_for_client_appdata() {
client.write(b"01234567890123456789").unwrap();
{
let mut pipe = OtherSession::new(&mut server);
let wrlen = client.writev_tls(&mut pipe).unwrap();
let wrlen = client.write_tls(&mut pipe).unwrap();
assert_eq!(84, wrlen);
assert_eq!(pipe.writevs, vec![vec![42, 42]]);
}
@ -1808,7 +1829,7 @@ fn vectored_write_for_server_handshake() {
server.process_new_packets().unwrap();
{
let mut pipe = OtherSession::new(&mut client);
let wrlen = server.writev_tls(&mut pipe).unwrap();
let wrlen = server.write_tls(&mut pipe).unwrap();
// don't assert exact sizes here, to avoid a brittle test
assert!(wrlen > 4000); // its pretty big (contains cert chain)
assert_eq!(pipe.writevs.len(), 1); // only one writev
@ -1820,7 +1841,7 @@ fn vectored_write_for_server_handshake() {
server.process_new_packets().unwrap();
{
let mut pipe = OtherSession::new(&mut client);
let wrlen = server.writev_tls(&mut pipe).unwrap();
let wrlen = server.write_tls(&mut pipe).unwrap();
assert_eq!(wrlen, 177);
assert_eq!(pipe.writevs, vec![vec![103, 42, 32]]);
}
@ -1838,7 +1859,7 @@ fn vectored_write_for_client_handshake() {
client.write(b"0123456789").unwrap();
{
let mut pipe = OtherSession::new(&mut server);
let wrlen = client.writev_tls(&mut pipe).unwrap();
let wrlen = client.write_tls(&mut pipe).unwrap();
// don't assert exact sizes here, to avoid a brittle test
assert!(wrlen > 200); // just the client hello
assert_eq!(pipe.writevs.len(), 1); // only one writev
@ -1850,7 +1871,7 @@ fn vectored_write_for_client_handshake() {
{
let mut pipe = OtherSession::new(&mut server);
let wrlen = client.writev_tls(&mut pipe).unwrap();
let wrlen = client.write_tls(&mut pipe).unwrap();
assert_eq!(wrlen, 138);
// CCS, finished, then two application datas
assert_eq!(pipe.writevs, vec![vec![6, 58, 42, 32]]);
@ -1873,12 +1894,12 @@ fn vectored_write_with_slow_client() {
{
let mut pipe = OtherSession::new(&mut client);
pipe.short_writes = true;
let wrlen = server.writev_tls(&mut pipe).unwrap() +
server.writev_tls(&mut pipe).unwrap() +
server.writev_tls(&mut pipe).unwrap() +
server.writev_tls(&mut pipe).unwrap() +
server.writev_tls(&mut pipe).unwrap() +
server.writev_tls(&mut pipe).unwrap();
let wrlen = server.write_tls(&mut pipe).unwrap() +
server.write_tls(&mut pipe).unwrap() +
server.write_tls(&mut pipe).unwrap() +
server.write_tls(&mut pipe).unwrap() +
server.write_tls(&mut pipe).unwrap() +
server.write_tls(&mut pipe).unwrap();
assert_eq!(42, wrlen);
assert_eq!(pipe.writevs, vec![vec![21], vec![10], vec![5], vec![3], vec![3]]);
}
@ -2201,21 +2222,30 @@ fn test_client_does_not_offer_sha1() {
#[test]
fn test_client_mtu_reduction() {
fn collect_write_lengths(client: &mut ClientSession) -> Vec<usize> {
let mut r = Vec::new();
let mut buf = [0u8; 128];
struct CollectWrites {
writevs: Vec<Vec<usize>>,
}
loop {
let sz = client.write_tls(&mut buf.as_mut())
.unwrap();
r.push(sz);
assert!(sz <= 64);
if sz < 64 {
break;
}
impl io::Write for CollectWrites {
fn write(&mut self, _: &[u8]) -> io::Result<usize> { panic!() }
fn flush(&mut self) -> io::Result<()> { panic!() }
fn write_vectored<'b>(&mut self, b: &[io::IoSlice<'b>]) -> io::Result<usize> {
let writes = b.iter()
.map(|slice| slice.len())
.collect::<Vec<usize>>();
let len = writes.iter().sum();
self.writevs.push(writes);
Ok(len)
}
}
r
fn collect_write_lengths(client: &mut ClientSession) -> Vec<usize> {
let mut collector = CollectWrites { writevs: vec![] };
client.write_tls(&mut collector)
.unwrap();
assert_eq!(collector.writevs.len(), 1);
collector.writevs[0].clone()
}
for kt in ALL_KEY_TYPES.iter() {
@ -2224,6 +2254,7 @@ fn test_client_mtu_reduction() {
let mut client = ClientSession::new(&Arc::new(client_config), dns_name("localhost"));
let writes = collect_write_lengths(&mut client);
println!("writes at mtu=64: {:?}", writes);
assert!(writes.iter().all(|x| *x <= 64));
assert!(writes.len() > 1);
}