mirror of https://github.com/ctz/rustls
Remove writev_tls; use std::io::Write::write_vectored
This is just a whole lot better.
This commit is contained in:
parent
2912dbffde
commit
e680b3b6c7
|
@ -28,6 +28,8 @@ If you'd like to help out, please see [CONTRIBUTING.md](CONTRIBUTING.md).
|
|||
- Performance improvements. Thanks to @nviennot.
|
||||
- Fixed client authentication being unduly rejected by client when server
|
||||
uses the superseded certificate_types field of CertificateRequest.
|
||||
- *Breaking API change*: The writev_tls API has been removed, in favour
|
||||
of using vectored IO support now offered by std::io::Write.
|
||||
* 0.17.0 (2020-02-22):
|
||||
- *Breaking API change*: ALPN protocols offered by the client are passed
|
||||
to the server certificate resolution trait (`ResolvesServerCert`).
|
||||
|
|
|
@ -28,7 +28,6 @@ regex = "1.0"
|
|||
serde = "1.0"
|
||||
serde_derive = "1.0"
|
||||
tempfile = "3.0"
|
||||
vecio = "0.1"
|
||||
webpki-roots = "0.19"
|
||||
|
||||
[[example]]
|
||||
|
|
|
@ -23,9 +23,6 @@ use webpki;
|
|||
use webpki_roots;
|
||||
use ct_logs;
|
||||
|
||||
|
||||
mod util;
|
||||
|
||||
use rustls::Session;
|
||||
|
||||
const CLIENT: mio::Token = mio::Token(0);
|
||||
|
@ -146,17 +143,10 @@ impl TlsClient {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn do_write(&mut self) {
|
||||
self.tls_session.write_tls(&mut self.socket).unwrap();
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn do_write(&mut self) {
|
||||
use crate::util::WriteVAdapter;
|
||||
self.tls_session.writev_tls(&mut WriteVAdapter::new(&mut self.socket)).unwrap();
|
||||
}
|
||||
|
||||
fn register(&mut self, registry: &mio::Registry) {
|
||||
let interest = self.ready_interest();
|
||||
registry.register(&mut self.socket, CLIENT, interest).unwrap();
|
||||
|
|
|
@ -24,8 +24,6 @@ use rustls;
|
|||
use rustls::{RootCertStore, Session, NoClientAuth, AllowAnyAuthenticatedClient,
|
||||
AllowAnyAnonymousOrAuthenticatedClient};
|
||||
|
||||
mod util;
|
||||
|
||||
// Token for our listening socket.
|
||||
const LISTENER: mio::Token = mio::Token(0);
|
||||
|
||||
|
@ -310,17 +308,10 @@ impl Connection {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn tls_write(&mut self) -> io::Result<usize> {
|
||||
self.tls_session.write_tls(&mut self.socket)
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn tls_write(&mut self) -> io::Result<usize> {
|
||||
use crate::util::WriteVAdapter;
|
||||
self.tls_session.writev_tls(&mut WriteVAdapter::new(&mut self.socket))
|
||||
}
|
||||
|
||||
fn do_tls_write_and_handle_error(&mut self) {
|
||||
let rc = self.tls_write();
|
||||
if rc.is_err() {
|
||||
|
|
|
@ -1,20 +0,0 @@
|
|||
use std::io;
|
||||
use vecio::Rawv;
|
||||
use rustls;
|
||||
|
||||
/// This glues our `rustls::WriteV` trait to `vecio::Rawv`.
|
||||
pub struct WriteVAdapter<'a> {
|
||||
rawv: &'a mut dyn Rawv
|
||||
}
|
||||
|
||||
impl<'a> WriteVAdapter<'a> {
|
||||
pub fn new(rawv: &'a mut dyn Rawv) -> WriteVAdapter<'a> {
|
||||
WriteVAdapter { rawv }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> rustls::WriteV for WriteVAdapter<'a> {
|
||||
fn writev(&mut self, bytes: &[&[u8]]) -> io::Result<usize> {
|
||||
self.rawv.writev(bytes)
|
||||
}
|
||||
}
|
|
@ -13,7 +13,6 @@ use crate::anchors;
|
|||
use crate::sign;
|
||||
use crate::error::TLSError;
|
||||
use crate::key;
|
||||
use crate::vecbuf::WriteV;
|
||||
#[cfg(feature = "logging")]
|
||||
use crate::log::trace;
|
||||
|
||||
|
@ -666,10 +665,6 @@ impl Session for ClientSession {
|
|||
self.imp.common.write_tls(wr)
|
||||
}
|
||||
|
||||
fn writev_tls(&mut self, wr: &mut dyn WriteV) -> io::Result<usize> {
|
||||
self.imp.common.writev_tls(wr)
|
||||
}
|
||||
|
||||
fn process_new_packets(&mut self) -> Result<(), TLSError> {
|
||||
self.imp.process_new_packets()
|
||||
}
|
||||
|
|
|
@ -280,7 +280,6 @@ pub use crate::verify::{NoClientAuth, AllowAnyAuthenticatedClient,
|
|||
pub use crate::suites::{ALL_CIPHERSUITES, BulkAlgorithm, SupportedCipherSuite};
|
||||
pub use crate::key::{Certificate, PrivateKey};
|
||||
pub use crate::keylog::{KeyLog, NoKeyLog, KeyLogFile};
|
||||
pub use crate::vecbuf::{WriteV, WriteVAdapter};
|
||||
|
||||
/// Message signing interfaces and implementations.
|
||||
pub mod sign;
|
||||
|
|
|
@ -10,7 +10,6 @@ use crate::error::TLSError;
|
|||
use crate::sign;
|
||||
use crate::verify;
|
||||
use crate::key;
|
||||
use crate::vecbuf::WriteV;
|
||||
#[cfg(feature = "logging")]
|
||||
use crate::log::trace;
|
||||
|
||||
|
@ -577,10 +576,6 @@ impl Session for ServerSession {
|
|||
self.imp.common.write_tls(wr)
|
||||
}
|
||||
|
||||
fn writev_tls(&mut self, wr: &mut dyn WriteV) -> io::Result<usize> {
|
||||
self.imp.common.writev_tls(wr)
|
||||
}
|
||||
|
||||
fn process_new_packets(&mut self) -> Result<(), TLSError> {
|
||||
self.imp.process_new_packets()
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ use crate::msgs::enums::{ContentType, ProtocolVersion, AlertDescription, AlertLe
|
|||
use crate::error::TLSError;
|
||||
use crate::suites::SupportedCipherSuite;
|
||||
use crate::cipher;
|
||||
use crate::vecbuf::{ChunkVecBuffer, WriteV};
|
||||
use crate::vecbuf::ChunkVecBuffer;
|
||||
use crate::key;
|
||||
use crate::prf;
|
||||
use crate::rand;
|
||||
|
@ -52,11 +52,6 @@ pub trait Session: quic::QuicExt + Read + Write + Send + Sync {
|
|||
/// [`wants_write`]: #tymethod.wants_write
|
||||
fn write_tls(&mut self, wr: &mut dyn Write) -> Result<usize, io::Error>;
|
||||
|
||||
/// Like `write_tls`, but writes potentially many records in one
|
||||
/// go via `wr`; a `rustls::WriteV`. This function has the same semantics
|
||||
/// as `write_tls` otherwise.
|
||||
fn writev_tls(&mut self, wr: &mut dyn WriteV) -> Result<usize, io::Error>;
|
||||
|
||||
/// Processes any new packets read by a previous call to `read_tls`.
|
||||
/// Errors from this function relate to TLS protocol errors, and
|
||||
/// are fatal to the session. Future calls after an error will do
|
||||
|
@ -591,10 +586,6 @@ impl SessionCommon {
|
|||
self.sendable_tls.write_to(wr)
|
||||
}
|
||||
|
||||
pub fn writev_tls(&mut self, wr: &mut dyn WriteV) -> io::Result<usize> {
|
||||
self.sendable_tls.writev_to(wr)
|
||||
}
|
||||
|
||||
/// Send plaintext application data, fragmenting and
|
||||
/// encrypting it as it goes out.
|
||||
///
|
||||
|
|
|
@ -2,26 +2,6 @@ use std::io::Read;
|
|||
use std::io;
|
||||
use std::cmp;
|
||||
use std::collections::VecDeque;
|
||||
use std::convert;
|
||||
|
||||
/// This trait specifies rustls's precise requirements doing writes with
|
||||
/// vectored IO.
|
||||
///
|
||||
/// The purpose of vectored IO is to pass contigious output in many blocks
|
||||
/// to the kernel without either coalescing it in user-mode (by allocating
|
||||
/// and copying) or making many system calls.
|
||||
///
|
||||
/// We don't directly use types from the vecio crate because the traits
|
||||
/// don't compose well: the most useful trait (`Rawv`) is hard to test
|
||||
/// with (it can't be implemented without an FD) and implies a readable
|
||||
/// source too. You will have to write a trivial adaptor struct which
|
||||
/// glues either `vecio::Rawv` or `vecio::Writev` to this trait. See
|
||||
/// the rustls examples.
|
||||
pub trait WriteV {
|
||||
/// Writes as much data from `vbytes` as possible, returning
|
||||
/// the number of bytes written.
|
||||
fn writev(&mut self, vbytes: &[&[u8]]) -> io::Result<usize>;
|
||||
}
|
||||
|
||||
/// This is a byte buffer that is built from a vector
|
||||
/// of byte vectors. This avoids extra copies when
|
||||
|
@ -132,50 +112,12 @@ impl ChunkVecBuffer {
|
|||
return Ok(0);
|
||||
}
|
||||
|
||||
let used = wr.write(&self.chunks[0])?;
|
||||
let used = wr.write_vectored(&self.chunks.iter()
|
||||
.map(|ch| io::IoSlice::new(ch))
|
||||
.collect::<Vec<io::IoSlice>>())?;
|
||||
self.consume(used);
|
||||
Ok(used)
|
||||
}
|
||||
|
||||
pub fn writev_to(&mut self, wr: &mut dyn WriteV) -> io::Result<usize> {
|
||||
if self.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let used = {
|
||||
let chunks = self.chunks.iter()
|
||||
.map(convert::AsRef::as_ref)
|
||||
.collect::<Vec<&[u8]>>();
|
||||
|
||||
wr.writev(&chunks)?
|
||||
};
|
||||
self.consume(used);
|
||||
Ok(used)
|
||||
}
|
||||
}
|
||||
|
||||
/// This is a simple wrapper around an object
|
||||
/// which implements `std::io::Write` in order to autoimplement `WriteV`.
|
||||
/// It uses the `write_vectored` method from `std::io::Write` in order
|
||||
/// to do this.
|
||||
pub struct WriteVAdapter<T: io::Write>(T);
|
||||
|
||||
impl<T: io::Write> WriteVAdapter<T> {
|
||||
/// build an adapter from a Write object
|
||||
pub fn new(inner: T) -> Self {
|
||||
WriteVAdapter(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: io::Write> WriteV for WriteVAdapter<T> {
|
||||
fn writev(&mut self, buffers: &[&[u8]]) -> io::Result<usize> {
|
||||
self.0.write_vectored(
|
||||
&buffers
|
||||
.iter()
|
||||
.map(|b| io::IoSlice::new(b))
|
||||
.collect::<Vec<io::IoSlice>>(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -6,7 +6,7 @@ use std::mem;
|
|||
use std::fmt;
|
||||
use std::env;
|
||||
use std::error::Error;
|
||||
use std::io::{self, Write, Read};
|
||||
use std::io::{self, Write, Read, IoSlice};
|
||||
|
||||
use rustls;
|
||||
|
||||
|
@ -842,6 +842,23 @@ fn server_respects_buffer_limit_pre_handshake() {
|
|||
check_read(&mut client, b"01234567890123456789012345678901");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_respects_buffer_limit_pre_handshake_with_vectored_write() {
|
||||
let (mut client, mut server) = make_pair(KeyType::RSA);
|
||||
|
||||
server.set_buffer_limit(32);
|
||||
|
||||
assert_eq!(server.write_vectored(&[IoSlice::new(b"01234567890123456789"),
|
||||
IoSlice::new(b"01234567890123456789")]).unwrap(),
|
||||
32);
|
||||
|
||||
do_handshake(&mut client, &mut server);
|
||||
transfer(&mut server, &mut client);
|
||||
client.process_new_packets().unwrap();
|
||||
|
||||
check_read(&mut client, b"01234567890123456789012345678901");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_respects_buffer_limit_post_handshake() {
|
||||
let (mut client, mut server) = make_pair(KeyType::RSA);
|
||||
|
@ -875,6 +892,23 @@ fn client_respects_buffer_limit_pre_handshake() {
|
|||
check_read(&mut server, b"01234567890123456789012345678901");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn client_respects_buffer_limit_pre_handshake_with_vectored_write() {
|
||||
let (mut client, mut server) = make_pair(KeyType::RSA);
|
||||
|
||||
client.set_buffer_limit(32);
|
||||
|
||||
assert_eq!(client.write_vectored(&[IoSlice::new(b"01234567890123456789"),
|
||||
IoSlice::new(b"01234567890123456789")]).unwrap(),
|
||||
32);
|
||||
|
||||
do_handshake(&mut client, &mut server);
|
||||
transfer(&mut client, &mut server);
|
||||
server.process_new_packets().unwrap();
|
||||
|
||||
check_read(&mut server, b"01234567890123456789012345678901");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn client_respects_buffer_limit_post_handshake() {
|
||||
let (mut client, mut server) = make_pair(KeyType::RSA);
|
||||
|
@ -894,7 +928,6 @@ fn client_respects_buffer_limit_post_handshake() {
|
|||
struct OtherSession<'a> {
|
||||
sess: &'a mut dyn Session,
|
||||
pub reads: usize,
|
||||
pub writes: usize,
|
||||
pub writevs: Vec<Vec<usize>>,
|
||||
fail_ok: bool,
|
||||
pub short_writes: bool,
|
||||
|
@ -906,7 +939,6 @@ impl<'a> OtherSession<'a> {
|
|||
OtherSession {
|
||||
sess,
|
||||
reads: 0,
|
||||
writes: 0,
|
||||
writevs: vec![],
|
||||
fail_ok: false,
|
||||
short_writes: false,
|
||||
|
@ -929,27 +961,15 @@ impl<'a> io::Read for OtherSession<'a> {
|
|||
}
|
||||
|
||||
impl<'a> io::Write for OtherSession<'a> {
|
||||
fn write(&mut self, mut b: &[u8]) -> io::Result<usize> {
|
||||
self.writes += 1;
|
||||
let l = self.sess.read_tls(b.by_ref())?;
|
||||
let rc = self.sess.process_new_packets();
|
||||
|
||||
if !self.fail_ok {
|
||||
rc.unwrap();
|
||||
} else if rc.is_err() {
|
||||
self.last_error = rc.err();
|
||||
}
|
||||
|
||||
Ok(l)
|
||||
fn write(&mut self, _: &[u8]) -> io::Result<usize> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> rustls::WriteV for OtherSession<'a> {
|
||||
fn writev(&mut self, b: &[&[u8]]) -> io::Result<usize> {
|
||||
fn write_vectored<'b>(&mut self, b: &[io::IoSlice<'b>]) -> io::Result<usize> {
|
||||
let mut total = 0;
|
||||
let mut lengths = vec![];
|
||||
for bytes in b {
|
||||
|
@ -1012,7 +1032,8 @@ fn client_complete_io_for_write() {
|
|||
let mut pipe = OtherSession::new(&mut server);
|
||||
let (rdlen, wrlen) = client.complete_io(&mut pipe).unwrap();
|
||||
assert!(rdlen == 0 && wrlen > 0);
|
||||
assert_eq!(pipe.writes, 2);
|
||||
println!("{:?}", pipe.writevs);
|
||||
assert_eq!(pipe.writevs, vec![ vec![ 42, 42 ] ]);
|
||||
}
|
||||
check_read(&mut server, b"0123456789012345678901234567890123456789");
|
||||
}
|
||||
|
@ -1071,7 +1092,7 @@ fn server_complete_io_for_write() {
|
|||
let mut pipe = OtherSession::new(&mut client);
|
||||
let (rdlen, wrlen) = server.complete_io(&mut pipe).unwrap();
|
||||
assert!(rdlen == 0 && wrlen > 0);
|
||||
assert_eq!(pipe.writes, 2);
|
||||
assert_eq!(pipe.writevs, vec![ vec![ 42, 42 ] ]);
|
||||
}
|
||||
check_read(&mut client, b"0123456789012345678901234567890123456789");
|
||||
}
|
||||
|
@ -1774,7 +1795,7 @@ fn vectored_write_for_server_appdata() {
|
|||
server.write(b"01234567890123456789").unwrap();
|
||||
{
|
||||
let mut pipe = OtherSession::new(&mut client);
|
||||
let wrlen = server.writev_tls(&mut pipe).unwrap();
|
||||
let wrlen = server.write_tls(&mut pipe).unwrap();
|
||||
assert_eq!(84, wrlen);
|
||||
assert_eq!(pipe.writevs, vec![vec![42, 42]]);
|
||||
}
|
||||
|
@ -1790,7 +1811,7 @@ fn vectored_write_for_client_appdata() {
|
|||
client.write(b"01234567890123456789").unwrap();
|
||||
{
|
||||
let mut pipe = OtherSession::new(&mut server);
|
||||
let wrlen = client.writev_tls(&mut pipe).unwrap();
|
||||
let wrlen = client.write_tls(&mut pipe).unwrap();
|
||||
assert_eq!(84, wrlen);
|
||||
assert_eq!(pipe.writevs, vec![vec![42, 42]]);
|
||||
}
|
||||
|
@ -1808,7 +1829,7 @@ fn vectored_write_for_server_handshake() {
|
|||
server.process_new_packets().unwrap();
|
||||
{
|
||||
let mut pipe = OtherSession::new(&mut client);
|
||||
let wrlen = server.writev_tls(&mut pipe).unwrap();
|
||||
let wrlen = server.write_tls(&mut pipe).unwrap();
|
||||
// don't assert exact sizes here, to avoid a brittle test
|
||||
assert!(wrlen > 4000); // its pretty big (contains cert chain)
|
||||
assert_eq!(pipe.writevs.len(), 1); // only one writev
|
||||
|
@ -1820,7 +1841,7 @@ fn vectored_write_for_server_handshake() {
|
|||
server.process_new_packets().unwrap();
|
||||
{
|
||||
let mut pipe = OtherSession::new(&mut client);
|
||||
let wrlen = server.writev_tls(&mut pipe).unwrap();
|
||||
let wrlen = server.write_tls(&mut pipe).unwrap();
|
||||
assert_eq!(wrlen, 177);
|
||||
assert_eq!(pipe.writevs, vec![vec![103, 42, 32]]);
|
||||
}
|
||||
|
@ -1838,7 +1859,7 @@ fn vectored_write_for_client_handshake() {
|
|||
client.write(b"0123456789").unwrap();
|
||||
{
|
||||
let mut pipe = OtherSession::new(&mut server);
|
||||
let wrlen = client.writev_tls(&mut pipe).unwrap();
|
||||
let wrlen = client.write_tls(&mut pipe).unwrap();
|
||||
// don't assert exact sizes here, to avoid a brittle test
|
||||
assert!(wrlen > 200); // just the client hello
|
||||
assert_eq!(pipe.writevs.len(), 1); // only one writev
|
||||
|
@ -1850,7 +1871,7 @@ fn vectored_write_for_client_handshake() {
|
|||
|
||||
{
|
||||
let mut pipe = OtherSession::new(&mut server);
|
||||
let wrlen = client.writev_tls(&mut pipe).unwrap();
|
||||
let wrlen = client.write_tls(&mut pipe).unwrap();
|
||||
assert_eq!(wrlen, 138);
|
||||
// CCS, finished, then two application datas
|
||||
assert_eq!(pipe.writevs, vec![vec![6, 58, 42, 32]]);
|
||||
|
@ -1873,12 +1894,12 @@ fn vectored_write_with_slow_client() {
|
|||
{
|
||||
let mut pipe = OtherSession::new(&mut client);
|
||||
pipe.short_writes = true;
|
||||
let wrlen = server.writev_tls(&mut pipe).unwrap() +
|
||||
server.writev_tls(&mut pipe).unwrap() +
|
||||
server.writev_tls(&mut pipe).unwrap() +
|
||||
server.writev_tls(&mut pipe).unwrap() +
|
||||
server.writev_tls(&mut pipe).unwrap() +
|
||||
server.writev_tls(&mut pipe).unwrap();
|
||||
let wrlen = server.write_tls(&mut pipe).unwrap() +
|
||||
server.write_tls(&mut pipe).unwrap() +
|
||||
server.write_tls(&mut pipe).unwrap() +
|
||||
server.write_tls(&mut pipe).unwrap() +
|
||||
server.write_tls(&mut pipe).unwrap() +
|
||||
server.write_tls(&mut pipe).unwrap();
|
||||
assert_eq!(42, wrlen);
|
||||
assert_eq!(pipe.writevs, vec![vec![21], vec![10], vec![5], vec![3], vec![3]]);
|
||||
}
|
||||
|
@ -2201,21 +2222,30 @@ fn test_client_does_not_offer_sha1() {
|
|||
|
||||
#[test]
|
||||
fn test_client_mtu_reduction() {
|
||||
fn collect_write_lengths(client: &mut ClientSession) -> Vec<usize> {
|
||||
let mut r = Vec::new();
|
||||
let mut buf = [0u8; 128];
|
||||
struct CollectWrites {
|
||||
writevs: Vec<Vec<usize>>,
|
||||
}
|
||||
|
||||
loop {
|
||||
let sz = client.write_tls(&mut buf.as_mut())
|
||||
.unwrap();
|
||||
r.push(sz);
|
||||
assert!(sz <= 64);
|
||||
if sz < 64 {
|
||||
break;
|
||||
}
|
||||
impl io::Write for CollectWrites {
|
||||
fn write(&mut self, _: &[u8]) -> io::Result<usize> { panic!() }
|
||||
fn flush(&mut self) -> io::Result<()> { panic!() }
|
||||
fn write_vectored<'b>(&mut self, b: &[io::IoSlice<'b>]) -> io::Result<usize> {
|
||||
let writes = b.iter()
|
||||
.map(|slice| slice.len())
|
||||
.collect::<Vec<usize>>();
|
||||
let len = writes.iter().sum();
|
||||
self.writevs.push(writes);
|
||||
Ok(len)
|
||||
}
|
||||
}
|
||||
|
||||
r
|
||||
fn collect_write_lengths(client: &mut ClientSession) -> Vec<usize> {
|
||||
let mut collector = CollectWrites { writevs: vec![] };
|
||||
|
||||
client.write_tls(&mut collector)
|
||||
.unwrap();
|
||||
assert_eq!(collector.writevs.len(), 1);
|
||||
collector.writevs[0].clone()
|
||||
}
|
||||
|
||||
for kt in ALL_KEY_TYPES.iter() {
|
||||
|
@ -2224,6 +2254,7 @@ fn test_client_mtu_reduction() {
|
|||
|
||||
let mut client = ClientSession::new(&Arc::new(client_config), dns_name("localhost"));
|
||||
let writes = collect_write_lengths(&mut client);
|
||||
println!("writes at mtu=64: {:?}", writes);
|
||||
assert!(writes.iter().all(|x| *x <= 64));
|
||||
assert!(writes.len() > 1);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue