From 2f02ddc21b39d31ccacc6a1396ac96d29aa6b867 Mon Sep 17 00:00:00 2001 From: Eloi DEMOLIS Date: Tue, 6 Feb 2024 01:45:56 +0100 Subject: [PATCH] Create type OutboundChunks for OutboundMessage The ConnectionCommon::write_vectored was implemented by processing each chunk, fragmenting them and wrapping each fragment in a OutboundMessage before encrypting and sending it as separate TLS frames. For very fragmented payloads this generates a lot of very small payloads with most of the data being TLS headers. OutboundChunks can contain an arbitrary amount of fragmented chunks. This allows write_vectored to process all its chunks at once, fragmenting it in place if needed and wrapping it in a OutboundMessage. All the chunks are merged in a contiguous vector (taking atvantage of an already existent copy) before being encrypted and sent as a single TLS frame. Signed-off-by: Eloi DEMOLIS Co-Authored-By: Emmanuel Bosquet --- provider-example/src/aead.rs | 4 +- rustls/src/client/client_conn.rs | 2 +- rustls/src/common_state.rs | 67 ++++---- rustls/src/conn.rs | 30 ++-- rustls/src/conn/unbuffered.rs | 2 +- rustls/src/crypto/aws_lc_rs/tls12.rs | 4 +- rustls/src/crypto/aws_lc_rs/tls13.rs | 4 +- rustls/src/crypto/ring/tls12.rs | 4 +- rustls/src/crypto/ring/tls13.rs | 2 +- rustls/src/msgs/fragmenter.rs | 89 ++++++++-- rustls/src/msgs/message.rs | 239 ++++++++++++++++++++++++++- rustls/src/vecbuf.rs | 16 +- rustls/tests/api.rs | 29 ++++ 13 files changed, 419 insertions(+), 73 deletions(-) diff --git a/provider-example/src/aead.rs b/provider-example/src/aead.rs index 0483cf99..97840211 100644 --- a/provider-example/src/aead.rs +++ b/provider-example/src/aead.rs @@ -93,7 +93,7 @@ impl cipher::MessageEncrypter for Tls13Cipher { // construct a TLSInnerPlaintext let mut payload = Vec::with_capacity(total_len); - payload.extend_from_slice(m.payload); + m.payload.copy_to_vec(&mut payload); payload.push(m.typ.get_u8()); let nonce = chacha20poly1305::Nonce::from(cipher::Nonce::new(&self.1, seq).0); @@ -145,7 +145,7 @@ impl cipher::MessageEncrypter for Tls12Cipher { let total_len = self.encrypted_payload_len(m.payload.len()); let mut payload = Vec::with_capacity(total_len); - payload.extend_from_slice(m.payload); + m.payload.copy_to_vec(&mut payload); let nonce = chacha20poly1305::Nonce::from(cipher::Nonce::new(&self.1, seq).0); let aad = cipher::make_tls12_aad(seq, m.typ, m.version, payload.len()); diff --git a/rustls/src/client/client_conn.rs b/rustls/src/client/client_conn.rs index 08aff079..fdc8b555 100644 --- a/rustls/src/client/client_conn.rs +++ b/rustls/src/client/client_conn.rs @@ -787,7 +787,7 @@ impl MayEncryptEarlyData<'_> { self.conn .core .common_state - .write_plaintext(&early_data[..allowed], outgoing_tls) + .write_plaintext(early_data[..allowed].into(), outgoing_tls) .map_err(|e| e.into()) } } diff --git a/rustls/src/common_state.rs b/rustls/src/common_state.rs index 0ae7d11a..69f7fae0 100644 --- a/rustls/src/common_state.rs +++ b/rustls/src/common_state.rs @@ -8,7 +8,8 @@ use crate::msgs::enums::{AlertLevel, KeyUpdateRequest}; use crate::msgs::fragmenter::MessageFragmenter; use crate::msgs::handshake::CertificateChain; use crate::msgs::message::{ - BorrowedPlainMessage, Message, MessagePayload, OpaqueMessage, OutboundMessage, PlainMessage, + BorrowedPlainMessage, Message, MessagePayload, OpaqueMessage, OutboundChunks, OutboundMessage, + PlainMessage, }; use crate::quic; use crate::record_layer; @@ -188,27 +189,29 @@ impl CommonState { /// all the data. pub(crate) fn buffer_plaintext( &mut self, - data: &[u8], + payload: OutboundChunks<'_>, sendable_plaintext: &mut ChunkVecBuffer, ) -> usize { self.perhaps_write_key_update(); - self.send_plain(data, Limit::Yes, sendable_plaintext) + self.send_plain(payload, Limit::Yes, sendable_plaintext) } pub(crate) fn write_plaintext( &mut self, - plaintext: &[u8], + payload: OutboundChunks<'_>, outgoing_tls: &mut [u8], ) -> Result { - if plaintext.is_empty() { + if payload.is_empty() { return Ok(0); } - let fragments = self.message_fragmenter.fragment_slice( - ContentType::ApplicationData, - ProtocolVersion::TLSv1_2, - plaintext, - ); + let fragments = self + .message_fragmenter + .fragment_payload( + ContentType::ApplicationData, + ProtocolVersion::TLSv1_2, + payload.clone(), + ); let remaining_encryptions = self .record_layer @@ -226,11 +229,13 @@ impl CommonState { fragments, )?; - let fragments = self.message_fragmenter.fragment_slice( - ContentType::ApplicationData, - ProtocolVersion::TLSv1_2, - plaintext, - ); + let fragments = self + .message_fragmenter + .fragment_payload( + ContentType::ApplicationData, + ProtocolVersion::TLSv1_2, + payload, + ); let opt_msg = self.queued_key_update_message.take(); let written = self.write_fragments(outgoing_tls, opt_msg, fragments); @@ -247,7 +252,7 @@ impl CommonState { return 0; } - self.send_appdata_encrypt(data, Limit::Yes) + self.send_appdata_encrypt(data.into(), Limit::Yes) } // Changing the keys must not span any fragmented handshake @@ -277,7 +282,7 @@ impl CommonState { } /// Like send_msg_encrypt, but operate on an appdata directly. - fn send_appdata_encrypt(&mut self, payload: &[u8], limit: Limit) -> usize { + fn send_appdata_encrypt(&mut self, payload: OutboundChunks<'_>, limit: Limit) -> usize { // Here, the limit on sendable_tls applies to encrypted data, // but we're respecting it for plaintext data -- so we'll // be out by whatever the cipher+record overhead is. That's a @@ -289,11 +294,13 @@ impl CommonState { Limit::No => payload.len(), }; - let iter = self.message_fragmenter.fragment_slice( - ContentType::ApplicationData, - ProtocolVersion::TLSv1_2, - &payload[..len], - ); + let iter = self + .message_fragmenter + .fragment_payload( + ContentType::ApplicationData, + ProtocolVersion::TLSv1_2, + payload.split_at(len).0, + ); for m in iter { self.send_single_fragment(m); } @@ -328,7 +335,7 @@ impl CommonState { /// be less than `data.len()` if buffer limits were exceeded. fn send_plain( &mut self, - data: &[u8], + payload: OutboundChunks<'_>, limit: Limit, sendable_plaintext: &mut ChunkVecBuffer, ) -> usize { @@ -336,25 +343,25 @@ impl CommonState { // If we haven't completed handshaking, buffer // plaintext to send once we do. let len = match limit { - Limit::Yes => sendable_plaintext.append_limited_copy(data), - Limit::No => sendable_plaintext.append(data.to_vec()), + Limit::Yes => sendable_plaintext.append_limited_copy(payload), + Limit::No => sendable_plaintext.append(payload.to_vec()), }; return len; } - self.send_plain_non_buffering(data, limit) + self.send_plain_non_buffering(payload, limit) } - fn send_plain_non_buffering(&mut self, data: &[u8], limit: Limit) -> usize { + fn send_plain_non_buffering(&mut self, payload: OutboundChunks<'_>, limit: Limit) -> usize { debug_assert!(self.may_send_application_data); debug_assert!(self.record_layer.is_encrypting()); - if data.is_empty() { + if payload.is_empty() { // Don't send empty fragments. return 0; } - self.send_appdata_encrypt(data, limit) + self.send_appdata_encrypt(payload, limit) } /// Mark the connection as ready to send application data. @@ -386,7 +393,7 @@ impl CommonState { } while let Some(buf) = sendable_plaintext.pop() { - self.send_plain_non_buffering(&buf, Limit::No); + self.send_plain_non_buffering(buf.as_slice().into(), Limit::No); } } diff --git a/rustls/src/conn.rs b/rustls/src/conn.rs index bdd0faef..c88456d8 100644 --- a/rustls/src/conn.rs +++ b/rustls/src/conn.rs @@ -5,11 +5,12 @@ use crate::error::{Error, PeerMisbehaved}; use crate::log::trace; use crate::msgs::deframer::{Deframed, DeframerSliceBuffer, DeframerVecBuffer, MessageDeframer}; use crate::msgs::handshake::Random; -use crate::msgs::message::{InboundMessage, Message, MessagePayload}; +use crate::msgs::message::{InboundMessage, Message, MessagePayload, OutboundChunks}; use crate::suites::{ExtractedSecrets, PartiallyExtractedSecrets}; use crate::vecbuf::ChunkVecBuffer; use alloc::boxed::Box; +use alloc::vec::Vec; use core::fmt::Debug; use core::mem; use core::ops::{Deref, DerefMut}; @@ -257,18 +258,27 @@ impl PlaintextSink for ConnectionCommon { Ok(self .core .common_state - .buffer_plaintext(buf, &mut self.sendable_plaintext)) + .buffer_plaintext(buf.into(), &mut self.sendable_plaintext)) } fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { - let mut sz = 0; - for buf in bufs { - sz += self - .core - .common_state - .buffer_plaintext(buf, &mut self.sendable_plaintext); - } - Ok(sz) + let payload_owner: Vec<&[u8]>; + let payload = match bufs.len() { + 0 => return Ok(0), + 1 => OutboundChunks::Single(bufs[0].deref()), + _ => { + payload_owner = bufs + .iter() + .map(|io_slice| io_slice.deref()) + .collect(); + + OutboundChunks::new(&payload_owner) + } + }; + Ok(self + .core + .common_state + .buffer_plaintext(payload, &mut self.sendable_plaintext)) } fn flush(&mut self) -> io::Result<()> { diff --git a/rustls/src/conn/unbuffered.rs b/rustls/src/conn/unbuffered.rs index 16059240..b4946dfd 100644 --- a/rustls/src/conn/unbuffered.rs +++ b/rustls/src/conn/unbuffered.rs @@ -400,7 +400,7 @@ impl WriteTraffic<'_, Data> { self.conn .core .common_state - .write_plaintext(application_data, outgoing_tls) + .write_plaintext(application_data.into(), outgoing_tls) } /// Encrypts a close_notify warning alert in `outgoing_tls` diff --git a/rustls/src/crypto/aws_lc_rs/tls12.rs b/rustls/src/crypto/aws_lc_rs/tls12.rs index e7903637..94764d7f 100644 --- a/rustls/src/crypto/aws_lc_rs/tls12.rs +++ b/rustls/src/crypto/aws_lc_rs/tls12.rs @@ -309,7 +309,7 @@ impl MessageEncrypter for GcmMessageEncrypter { let total_len = self.encrypted_payload_len(msg.payload.len()); let mut payload = Vec::with_capacity(total_len); payload.extend_from_slice(&nonce.as_ref()[4..]); - payload.extend_from_slice(msg.payload); + msg.payload.copy_to_vec(&mut payload); self.enc_key .seal_in_place_separate_tag(nonce, aad, &mut payload[GCM_EXPLICIT_NONCE_LEN..]) @@ -385,7 +385,7 @@ impl MessageEncrypter for ChaCha20Poly1305MessageEncrypter { let total_len = self.encrypted_payload_len(msg.payload.len()); let mut buf = Vec::with_capacity(total_len); - buf.extend_from_slice(msg.payload); + msg.payload.copy_to_vec(&mut buf); self.enc_key .seal_in_place_append_tag(nonce, aad, &mut buf) diff --git a/rustls/src/crypto/aws_lc_rs/tls13.rs b/rustls/src/crypto/aws_lc_rs/tls13.rs index d6dacbea..abcdf3a5 100644 --- a/rustls/src/crypto/aws_lc_rs/tls13.rs +++ b/rustls/src/crypto/aws_lc_rs/tls13.rs @@ -223,7 +223,7 @@ impl MessageEncrypter for AeadMessageEncrypter { fn encrypt(&mut self, msg: OutboundMessage, seq: u64) -> Result { let total_len = self.encrypted_payload_len(msg.payload.len()); let mut payload = Vec::with_capacity(total_len); - payload.extend_from_slice(msg.payload); + msg.payload.copy_to_vec(&mut payload); msg.typ.encode(&mut payload); let nonce = aead::Nonce::assume_unique_for_key(Nonce::new(&self.iv, seq).0); @@ -279,7 +279,7 @@ impl MessageEncrypter for GcmMessageEncrypter { fn encrypt(&mut self, msg: OutboundMessage, seq: u64) -> Result { let total_len = msg.payload.len() + 1 + self.enc_key.algorithm().tag_len(); let mut payload = Vec::with_capacity(total_len); - payload.extend_from_slice(msg.payload); + msg.payload.copy_to_vec(&mut payload); msg.typ.encode(&mut payload); let nonce = aead::Nonce::assume_unique_for_key(Nonce::new(&self.iv, seq).0); diff --git a/rustls/src/crypto/ring/tls12.rs b/rustls/src/crypto/ring/tls12.rs index bff49cce..a01da6ef 100644 --- a/rustls/src/crypto/ring/tls12.rs +++ b/rustls/src/crypto/ring/tls12.rs @@ -293,7 +293,7 @@ impl MessageEncrypter for GcmMessageEncrypter { let total_len = self.encrypted_payload_len(msg.payload.len()); let mut payload = Vec::with_capacity(total_len); payload.extend_from_slice(&nonce.as_ref()[4..]); - payload.extend_from_slice(msg.payload); + msg.payload.copy_to_vec(&mut payload); self.enc_key .seal_in_place_separate_tag(nonce, aad, &mut payload[GCM_EXPLICIT_NONCE_LEN..]) @@ -369,7 +369,7 @@ impl MessageEncrypter for ChaCha20Poly1305MessageEncrypter { let total_len = self.encrypted_payload_len(msg.payload.len()); let mut buf = Vec::with_capacity(total_len); - buf.extend_from_slice(msg.payload); + msg.payload.copy_to_vec(&mut buf); self.enc_key .seal_in_place_append_tag(nonce, aad, &mut buf) diff --git a/rustls/src/crypto/ring/tls13.rs b/rustls/src/crypto/ring/tls13.rs index 96510557..4876b55e 100644 --- a/rustls/src/crypto/ring/tls13.rs +++ b/rustls/src/crypto/ring/tls13.rs @@ -195,7 +195,7 @@ impl MessageEncrypter for Tls13MessageEncrypter { fn encrypt(&mut self, msg: OutboundMessage, seq: u64) -> Result { let total_len = self.encrypted_payload_len(msg.payload.len()); let mut payload = Vec::with_capacity(total_len); - payload.extend_from_slice(msg.payload); + msg.payload.copy_to_vec(&mut payload); msg.typ.encode(&mut payload); let nonce = aead::Nonce::assume_unique_for_key(Nonce::new(&self.iv, seq).0); diff --git a/rustls/src/msgs/fragmenter.rs b/rustls/src/msgs/fragmenter.rs index 6cd72608..d9314134 100644 --- a/rustls/src/msgs/fragmenter.rs +++ b/rustls/src/msgs/fragmenter.rs @@ -1,6 +1,6 @@ use crate::enums::ContentType; use crate::enums::ProtocolVersion; -use crate::msgs::message::{OutboundMessage, PlainMessage}; +use crate::msgs::message::{OutboundChunks, OutboundMessage, PlainMessage}; use crate::Error; pub(crate) const MAX_FRAGMENT_LEN: usize = 16384; pub(crate) const PACKET_OVERHEAD: usize = 1 + 2 + 2; @@ -27,24 +27,22 @@ impl MessageFragmenter { &self, msg: &'a PlainMessage, ) -> impl Iterator> + 'a { - self.fragment_slice(msg.typ, msg.version, msg.payload.bytes()) + self.fragment_payload(msg.typ, msg.version, msg.payload.bytes().into()) } /// Enqueue borrowed fragments of (version, typ, payload) which /// are no longer than max_frag onto the `out` deque. - pub(crate) fn fragment_slice<'a>( + pub(crate) fn fragment_payload<'a>( &self, typ: ContentType, version: ProtocolVersion, - payload: &'a [u8], + payload: OutboundChunks<'a>, ) -> impl ExactSizeIterator> { - payload - .chunks(self.max_frag) - .map(move |payload| OutboundMessage { - typ, - version, - payload, - }) + Chunker::new(payload, self.max_frag).map(move |payload| OutboundMessage { + typ, + version, + payload, + }) } /// Set the maximum fragment size that will be produced. @@ -65,13 +63,47 @@ impl MessageFragmenter { } } +/// An iterator over borrowed fragments of a payload +struct Chunker<'a> { + payload: OutboundChunks<'a>, + limit: usize, +} + +impl<'a> Chunker<'a> { + fn new(payload: OutboundChunks<'a>, limit: usize) -> Self { + Self { payload, limit } + } +} + +impl<'a> Iterator for Chunker<'a> { + type Item = OutboundChunks<'a>; + + fn next(&mut self) -> Option { + if self.payload.is_empty() { + return None; + } + + let (before, after) = self.payload.split_at(self.limit); + self.payload = after; + Some(before) + } +} + +impl<'a> ExactSizeIterator for Chunker<'a> { + fn len(&self) -> usize { + (self.payload.len() + self.limit - 1) / self.limit + } +} + #[cfg(test)] mod tests { use super::{MessageFragmenter, PACKET_OVERHEAD}; use crate::enums::ContentType; use crate::enums::ProtocolVersion; use crate::msgs::base::Payload; - use crate::msgs::message::{BorrowedPlainMessage, OutboundMessage, PlainMessage}; + use crate::msgs::message::{ + BorrowedPlainMessage, OutboundChunks, OutboundMessage, PlainMessage, + }; fn msg_eq( m: &OutboundMessage, @@ -82,7 +114,7 @@ mod tests { ) { assert_eq!(&m.typ, typ); assert_eq!(&m.version, version); - assert_eq!(m.payload, bytes); + assert_eq!(m.payload.to_vec(), bytes); let buf = m.to_unencrypted_opaque().encode(); @@ -159,4 +191,35 @@ mod tests { b"\x01\x02\x03\x04\x05\x06\x07\x08", ); } + + #[test] + fn fragment_multiple_slices() { + let typ = ContentType::Handshake; + let version = ProtocolVersion::TLSv1_2; + let payload_owner: Vec<&[u8]> = vec![&[b'a'; 8], &[b'b'; 12], &[b'c'; 32], &[b'd'; 20]]; + let borrowed_payload = OutboundChunks::new(&payload_owner); + let mut frag = MessageFragmenter::default(); + frag.set_max_fragment_size(Some(37)) // 32 + packet overhead + .unwrap(); + + let fragments = frag + .fragment_payload(typ, version, borrowed_payload) + .collect::>(); + assert_eq!(fragments.len(), 3); + msg_eq( + &fragments[0], + 37, + &typ, + &version, + b"aaaaaaaabbbbbbbbbbbbcccccccccccc", + ); + msg_eq( + &fragments[1], + 37, + &typ, + &version, + b"ccccccccccccccccccccdddddddddddd", + ); + msg_eq(&fragments[2], 13, &typ, &version, b"dddddddd"); + } } diff --git a/rustls/src/msgs/message.rs b/rustls/src/msgs/message.rs index 9ab670eb..38d1234f 100644 --- a/rustls/src/msgs/message.rs +++ b/rustls/src/msgs/message.rs @@ -347,7 +347,7 @@ impl PlainMessage { OutboundMessage { version: self.version, typ: self.typ, - payload: self.payload.bytes(), + payload: self.payload.bytes().into(), } } } @@ -468,7 +468,7 @@ impl BorrowedPlainMessage for InboundMessage<'_> { pub struct OutboundMessage<'a> { pub typ: ContentType, pub version: ProtocolVersion, - pub payload: &'a [u8], + pub payload: OutboundChunks<'a>, } impl BorrowedPlainMessage for OutboundMessage<'_> { @@ -520,6 +520,126 @@ pub trait BorrowedPlainMessage: Sized { fn version(&self) -> ProtocolVersion; } +#[derive(Debug, Clone)] +/// A collection of borrowed plaintext slices. +/// Warning: OutboundChunks does not guarantee that the simplest variant is used. +/// Multiple can hold non fragmented or empty payloads. +pub enum OutboundChunks<'a> { + /// A single byte slice. Contrary to `Multiple`, this uses a single pointer indirection + Single(&'a [u8]), + /// A collection of chunks (byte slices) + /// and cursors to single out a fragmented range of bytes. + /// OutboundChunks assumes that start <= end + Multiple { + chunks: &'a [&'a [u8]], + start: usize, + end: usize, + }, +} + +impl<'a> OutboundChunks<'a> { + /// Create a payload from a slice of byte slices. + /// If fragmented the cursors are added by default: start = 0, end = length + pub fn new(chunks: &'a [&'a [u8]]) -> Self { + if chunks.len() == 1 { + Self::Single(chunks[0]) + } else { + Self::Multiple { + chunks, + start: 0, + end: chunks + .iter() + .map(|chunk| chunk.len()) + .sum(), + } + } + } + + /// Create a payload with a single empty slice + pub fn new_empty() -> Self { + Self::Single(&[]) + } + + /// Flatten the slice of byte slices to an owned vector of bytes + pub fn to_vec(&self) -> Vec { + let mut vec = Vec::with_capacity(self.len()); + self.copy_to_vec(&mut vec); + vec + } + + /// Append all bytes to a vector + pub fn copy_to_vec(&self, vec: &mut Vec) { + match *self { + Self::Single(chunk) => vec.extend_from_slice(chunk), + Self::Multiple { chunks, start, end } => { + let mut size = 0; + for chunk in chunks.iter() { + let psize = size; + let len = chunk.len(); + size += len; + if size <= start || psize >= end { + continue; + } + let start = if psize < start { start - psize } else { 0 }; + let end = if end - psize < len { end - psize } else { len }; + vec.extend_from_slice(&chunk[start..end]); + } + } + } + } + + /// Split self in two, around an index + /// Works similarly to `split_at` in the core library, except it doesn't panic if out of bound + pub fn split_at(&self, mid: usize) -> (Self, Self) { + match *self { + Self::Single(chunk) => { + let mid = Ord::min(mid, chunk.len()); + (Self::Single(&chunk[..mid]), Self::Single(&chunk[mid..])) + } + Self::Multiple { chunks, start, end } => { + let mid = Ord::min(start + mid, end); + ( + Self::Multiple { + chunks, + start, + end: mid, + }, + Self::Multiple { + chunks, + start: mid, + end, + }, + ) + } + } + } + + /// Returns true if the payload is empty + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the cumulative length of all chunks + pub fn len(&self) -> usize { + match self { + Self::Single(chunk) => chunk.len(), + Self::Multiple { start, end, .. } => end - start, + } + } +} + +impl<'a> From<&'a [u8]> for OutboundChunks<'a> { + fn from(payload: &'a [u8]) -> Self { + Self::Single(payload) + } +} + +impl<'a, const N: usize> From<&'a [u8; N]> for OutboundChunks<'a> { + fn from(payload: &'a [u8; N]) -> Self { + Self::Single(payload) + } +} + #[derive(Debug)] pub enum MessageError { TooShortForHeader, @@ -529,3 +649,118 @@ pub enum MessageError { InvalidContentType, UnknownProtocolVersion, } + +#[cfg(test)] +mod tests { + use std::{println, vec}; + + use super::*; + + #[test] + fn split_at_with_single_slice() { + let owner: &[u8] = &[0, 1, 2, 3, 4, 5, 6, 7]; + let borrowed_payload = OutboundChunks::Single(owner); + + let (before, after) = borrowed_payload.split_at(6); + println!("before:{:?}\nafter:{:?}", before, after); + assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5]); + assert_eq!(after.to_vec(), &[6, 7]); + } + + #[test] + fn split_at_with_multiple_slices() { + let owner: Vec<&[u8]> = vec![&[0, 1, 2, 3], &[4, 5], &[6, 7, 8], &[9, 10, 11, 12]]; + let borrowed_payload = OutboundChunks::new(&owner); + + let (before, after) = borrowed_payload.split_at(3); + println!("before:{:?}\nafter:{:?}", before, after); + assert_eq!(before.to_vec(), &[0, 1, 2]); + assert_eq!(after.to_vec(), &[3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); + + let (before, after) = borrowed_payload.split_at(8); + println!("before:{:?}\nafter:{:?}", before, after); + assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5, 6, 7]); + assert_eq!(after.to_vec(), &[8, 9, 10, 11, 12]); + + let (before, after) = borrowed_payload.split_at(11); + println!("before:{:?}\nafter:{:?}", before, after); + assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + assert_eq!(after.to_vec(), &[11, 12]); + } + + #[test] + fn split_out_of_bounds() { + let owner: Vec<&[u8]> = vec![&[0, 1, 2, 3], &[4, 5], &[6, 7, 8], &[9, 10, 11, 12]]; + + let single_payload = OutboundChunks::Single(owner[0]); + let (before, after) = single_payload.split_at(17); + println!("before:{:?}\nafter:{:?}", before, after); + assert_eq!(before.to_vec(), &[0, 1, 2, 3]); + assert!(after.is_empty()); + + let multiple_payload = OutboundChunks::new(&owner); + let (before, after) = multiple_payload.split_at(17); + println!("before:{:?}\nafter:{:?}", before, after); + assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); + assert!(after.is_empty()); + + let empty_payload = OutboundChunks::new_empty(); + let (before, after) = empty_payload.split_at(17); + println!("before:{:?}\nafter:{:?}", before, after); + assert!(before.is_empty()); + assert!(after.is_empty()); + } + + #[test] + fn empty_slices_mixed() { + let owner: Vec<&[u8]> = vec![&[], &[], &[0], &[], &[1, 2], &[], &[3], &[4], &[], &[]]; + let mut borrowed_payload = OutboundChunks::new(&owner); + let mut fragment_count = 0; + let mut fragment; + let expected_fragments: &[&[u8]] = &[&[0, 1], &[2, 3], &[4]]; + + while !borrowed_payload.is_empty() { + (fragment, borrowed_payload) = borrowed_payload.split_at(2); + println!("{fragment:?}"); + assert_eq!(&expected_fragments[fragment_count], &fragment.to_vec()); + fragment_count += 1; + } + assert_eq!(fragment_count, expected_fragments.len()); + } + + #[test] + fn exhaustive_splitting() { + let owner: Vec = (0..127).collect(); + let slices = (0..7) + .map(|i| &owner[((1 << i) - 1)..((1 << (i + 1)) - 1)]) + .collect::>(); + let payload = OutboundChunks::new(&slices); + + assert_eq!(payload.to_vec(), owner); + println!("{:#?}", payload); + + for start in 0..128 { + for end in start..128 { + for mid in 0..(end - start) { + let witness = owner[start..end].split_at(mid); + let split_payload = payload + .split_at(end) + .0 + .split_at(start) + .1 + .split_at(mid); + assert_eq!( + witness.0, + split_payload.0.to_vec(), + "start: {start}, mid:{mid}, end:{end}" + ); + assert_eq!( + witness.1, + split_payload.1.to_vec(), + "start: {start}, mid:{mid}, end:{end}" + ); + } + } + } + } +} diff --git a/rustls/src/vecbuf.rs b/rustls/src/vecbuf.rs index c7b99082..55ff6a54 100644 --- a/rustls/src/vecbuf.rs +++ b/rustls/src/vecbuf.rs @@ -4,6 +4,8 @@ use core::cmp; use std::io; use std::io::Read; +use crate::msgs::message::OutboundChunks; + /// This is a byte buffer that is built from a vector /// of byte vectors. This avoids extra copies when /// appending a new byte vector, at the expense of @@ -66,9 +68,9 @@ impl ChunkVecBuffer { /// Append a copy of `bytes`, perhaps a prefix if /// we're near the limit. - pub(crate) fn append_limited_copy(&mut self, bytes: &[u8]) -> usize { - let take = self.apply_limit(bytes.len()); - self.append(bytes[..take].to_vec()); + pub(crate) fn append_limited_copy(&mut self, payload: OutboundChunks<'_>) -> usize { + let take = self.apply_limit(payload.len()); + self.append(payload.split_at(take).0.to_vec()); take } @@ -155,10 +157,10 @@ mod tests { #[test] fn short_append_copy_with_limit() { let mut cvb = ChunkVecBuffer::new(Some(12)); - assert_eq!(cvb.append_limited_copy(b"hello"), 5); - assert_eq!(cvb.append_limited_copy(b"world"), 5); - assert_eq!(cvb.append_limited_copy(b"hello"), 2); - assert_eq!(cvb.append_limited_copy(b"world"), 0); + assert_eq!(cvb.append_limited_copy(b"hello".into()), 5); + assert_eq!(cvb.append_limited_copy(b"world".into()), 5); + assert_eq!(cvb.append_limited_copy(b"hello".into()), 2); + assert_eq!(cvb.append_limited_copy(b"world".into()), 0); let mut buf = [0u8; 12]; assert_eq!(cvb.read(&mut buf).unwrap(), 12); diff --git a/rustls/tests/api.rs b/rustls/tests/api.rs index f4e3d7c1..a994f908 100644 --- a/rustls/tests/api.rs +++ b/rustls/tests/api.rs @@ -2294,6 +2294,35 @@ fn test_server_stream_read(stream_kind: StreamKind, read_kind: ReadKind) { } } +#[test] +fn test_client_write_and_vectored_write_equivalence() { + let (mut client, mut server) = make_pair(KeyType::Rsa); + do_handshake(&mut client, &mut server); + + const N: usize = 1000; + + let data_chunked: Vec = std::iter::repeat(IoSlice::new(b"A")) + .take(N) + .collect(); + let bytes_written_chunked = client + .writer() + .write_vectored(&data_chunked) + .unwrap(); + let bytes_sent_chunked = transfer(&mut client, &mut server); + println!("write_vectored returned {bytes_written_chunked} and sent {bytes_sent_chunked}"); + + let data_contiguous = &[b'A'; N]; + let bytes_written_contiguous = client + .writer() + .write(data_contiguous) + .unwrap(); + let bytes_sent_contiguous = transfer(&mut client, &mut server); + println!("write returned {bytes_written_contiguous} and sent {bytes_sent_contiguous}"); + + assert_eq!(bytes_written_chunked, bytes_written_contiguous); + assert_eq!(bytes_sent_chunked, bytes_sent_contiguous); +} + struct FailsWrites { errkind: io::ErrorKind, after: usize,