mirror of https://github.com/ctz/rustls
Create type OutboundChunks for OutboundMessage
The ConnectionCommon<T>::write_vectored was implemented by processing each chunk, fragmenting them and wrapping each fragment in a OutboundMessage before encrypting and sending it as separate TLS frames. For very fragmented payloads this generates a lot of very small payloads with most of the data being TLS headers. OutboundChunks can contain an arbitrary amount of fragmented chunks. This allows write_vectored to process all its chunks at once, fragmenting it in place if needed and wrapping it in a OutboundMessage. All the chunks are merged in a contiguous vector (taking atvantage of an already existent copy) before being encrypted and sent as a single TLS frame. Signed-off-by: Eloi DEMOLIS <eloi.demolis@clever-cloud.com> Co-Authored-By: Emmanuel Bosquet <bjokac@gmail.com>
This commit is contained in:
parent
9af53f25f9
commit
2f02ddc21b
|
@ -93,7 +93,7 @@ impl cipher::MessageEncrypter for Tls13Cipher {
|
|||
|
||||
// construct a TLSInnerPlaintext
|
||||
let mut payload = Vec::with_capacity(total_len);
|
||||
payload.extend_from_slice(m.payload);
|
||||
m.payload.copy_to_vec(&mut payload);
|
||||
payload.push(m.typ.get_u8());
|
||||
|
||||
let nonce = chacha20poly1305::Nonce::from(cipher::Nonce::new(&self.1, seq).0);
|
||||
|
@ -145,7 +145,7 @@ impl cipher::MessageEncrypter for Tls12Cipher {
|
|||
let total_len = self.encrypted_payload_len(m.payload.len());
|
||||
|
||||
let mut payload = Vec::with_capacity(total_len);
|
||||
payload.extend_from_slice(m.payload);
|
||||
m.payload.copy_to_vec(&mut payload);
|
||||
|
||||
let nonce = chacha20poly1305::Nonce::from(cipher::Nonce::new(&self.1, seq).0);
|
||||
let aad = cipher::make_tls12_aad(seq, m.typ, m.version, payload.len());
|
||||
|
|
|
@ -787,7 +787,7 @@ impl MayEncryptEarlyData<'_> {
|
|||
self.conn
|
||||
.core
|
||||
.common_state
|
||||
.write_plaintext(&early_data[..allowed], outgoing_tls)
|
||||
.write_plaintext(early_data[..allowed].into(), outgoing_tls)
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,7 +8,8 @@ use crate::msgs::enums::{AlertLevel, KeyUpdateRequest};
|
|||
use crate::msgs::fragmenter::MessageFragmenter;
|
||||
use crate::msgs::handshake::CertificateChain;
|
||||
use crate::msgs::message::{
|
||||
BorrowedPlainMessage, Message, MessagePayload, OpaqueMessage, OutboundMessage, PlainMessage,
|
||||
BorrowedPlainMessage, Message, MessagePayload, OpaqueMessage, OutboundChunks, OutboundMessage,
|
||||
PlainMessage,
|
||||
};
|
||||
use crate::quic;
|
||||
use crate::record_layer;
|
||||
|
@ -188,27 +189,29 @@ impl CommonState {
|
|||
/// all the data.
|
||||
pub(crate) fn buffer_plaintext(
|
||||
&mut self,
|
||||
data: &[u8],
|
||||
payload: OutboundChunks<'_>,
|
||||
sendable_plaintext: &mut ChunkVecBuffer,
|
||||
) -> usize {
|
||||
self.perhaps_write_key_update();
|
||||
self.send_plain(data, Limit::Yes, sendable_plaintext)
|
||||
self.send_plain(payload, Limit::Yes, sendable_plaintext)
|
||||
}
|
||||
|
||||
pub(crate) fn write_plaintext(
|
||||
&mut self,
|
||||
plaintext: &[u8],
|
||||
payload: OutboundChunks<'_>,
|
||||
outgoing_tls: &mut [u8],
|
||||
) -> Result<usize, EncryptError> {
|
||||
if plaintext.is_empty() {
|
||||
if payload.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let fragments = self.message_fragmenter.fragment_slice(
|
||||
ContentType::ApplicationData,
|
||||
ProtocolVersion::TLSv1_2,
|
||||
plaintext,
|
||||
);
|
||||
let fragments = self
|
||||
.message_fragmenter
|
||||
.fragment_payload(
|
||||
ContentType::ApplicationData,
|
||||
ProtocolVersion::TLSv1_2,
|
||||
payload.clone(),
|
||||
);
|
||||
|
||||
let remaining_encryptions = self
|
||||
.record_layer
|
||||
|
@ -226,11 +229,13 @@ impl CommonState {
|
|||
fragments,
|
||||
)?;
|
||||
|
||||
let fragments = self.message_fragmenter.fragment_slice(
|
||||
ContentType::ApplicationData,
|
||||
ProtocolVersion::TLSv1_2,
|
||||
plaintext,
|
||||
);
|
||||
let fragments = self
|
||||
.message_fragmenter
|
||||
.fragment_payload(
|
||||
ContentType::ApplicationData,
|
||||
ProtocolVersion::TLSv1_2,
|
||||
payload,
|
||||
);
|
||||
|
||||
let opt_msg = self.queued_key_update_message.take();
|
||||
let written = self.write_fragments(outgoing_tls, opt_msg, fragments);
|
||||
|
@ -247,7 +252,7 @@ impl CommonState {
|
|||
return 0;
|
||||
}
|
||||
|
||||
self.send_appdata_encrypt(data, Limit::Yes)
|
||||
self.send_appdata_encrypt(data.into(), Limit::Yes)
|
||||
}
|
||||
|
||||
// Changing the keys must not span any fragmented handshake
|
||||
|
@ -277,7 +282,7 @@ impl CommonState {
|
|||
}
|
||||
|
||||
/// Like send_msg_encrypt, but operate on an appdata directly.
|
||||
fn send_appdata_encrypt(&mut self, payload: &[u8], limit: Limit) -> usize {
|
||||
fn send_appdata_encrypt(&mut self, payload: OutboundChunks<'_>, limit: Limit) -> usize {
|
||||
// Here, the limit on sendable_tls applies to encrypted data,
|
||||
// but we're respecting it for plaintext data -- so we'll
|
||||
// be out by whatever the cipher+record overhead is. That's a
|
||||
|
@ -289,11 +294,13 @@ impl CommonState {
|
|||
Limit::No => payload.len(),
|
||||
};
|
||||
|
||||
let iter = self.message_fragmenter.fragment_slice(
|
||||
ContentType::ApplicationData,
|
||||
ProtocolVersion::TLSv1_2,
|
||||
&payload[..len],
|
||||
);
|
||||
let iter = self
|
||||
.message_fragmenter
|
||||
.fragment_payload(
|
||||
ContentType::ApplicationData,
|
||||
ProtocolVersion::TLSv1_2,
|
||||
payload.split_at(len).0,
|
||||
);
|
||||
for m in iter {
|
||||
self.send_single_fragment(m);
|
||||
}
|
||||
|
@ -328,7 +335,7 @@ impl CommonState {
|
|||
/// be less than `data.len()` if buffer limits were exceeded.
|
||||
fn send_plain(
|
||||
&mut self,
|
||||
data: &[u8],
|
||||
payload: OutboundChunks<'_>,
|
||||
limit: Limit,
|
||||
sendable_plaintext: &mut ChunkVecBuffer,
|
||||
) -> usize {
|
||||
|
@ -336,25 +343,25 @@ impl CommonState {
|
|||
// If we haven't completed handshaking, buffer
|
||||
// plaintext to send once we do.
|
||||
let len = match limit {
|
||||
Limit::Yes => sendable_plaintext.append_limited_copy(data),
|
||||
Limit::No => sendable_plaintext.append(data.to_vec()),
|
||||
Limit::Yes => sendable_plaintext.append_limited_copy(payload),
|
||||
Limit::No => sendable_plaintext.append(payload.to_vec()),
|
||||
};
|
||||
return len;
|
||||
}
|
||||
|
||||
self.send_plain_non_buffering(data, limit)
|
||||
self.send_plain_non_buffering(payload, limit)
|
||||
}
|
||||
|
||||
fn send_plain_non_buffering(&mut self, data: &[u8], limit: Limit) -> usize {
|
||||
fn send_plain_non_buffering(&mut self, payload: OutboundChunks<'_>, limit: Limit) -> usize {
|
||||
debug_assert!(self.may_send_application_data);
|
||||
debug_assert!(self.record_layer.is_encrypting());
|
||||
|
||||
if data.is_empty() {
|
||||
if payload.is_empty() {
|
||||
// Don't send empty fragments.
|
||||
return 0;
|
||||
}
|
||||
|
||||
self.send_appdata_encrypt(data, limit)
|
||||
self.send_appdata_encrypt(payload, limit)
|
||||
}
|
||||
|
||||
/// Mark the connection as ready to send application data.
|
||||
|
@ -386,7 +393,7 @@ impl CommonState {
|
|||
}
|
||||
|
||||
while let Some(buf) = sendable_plaintext.pop() {
|
||||
self.send_plain_non_buffering(&buf, Limit::No);
|
||||
self.send_plain_non_buffering(buf.as_slice().into(), Limit::No);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,11 +5,12 @@ use crate::error::{Error, PeerMisbehaved};
|
|||
use crate::log::trace;
|
||||
use crate::msgs::deframer::{Deframed, DeframerSliceBuffer, DeframerVecBuffer, MessageDeframer};
|
||||
use crate::msgs::handshake::Random;
|
||||
use crate::msgs::message::{InboundMessage, Message, MessagePayload};
|
||||
use crate::msgs::message::{InboundMessage, Message, MessagePayload, OutboundChunks};
|
||||
use crate::suites::{ExtractedSecrets, PartiallyExtractedSecrets};
|
||||
use crate::vecbuf::ChunkVecBuffer;
|
||||
|
||||
use alloc::boxed::Box;
|
||||
use alloc::vec::Vec;
|
||||
use core::fmt::Debug;
|
||||
use core::mem;
|
||||
use core::ops::{Deref, DerefMut};
|
||||
|
@ -257,18 +258,27 @@ impl<T> PlaintextSink for ConnectionCommon<T> {
|
|||
Ok(self
|
||||
.core
|
||||
.common_state
|
||||
.buffer_plaintext(buf, &mut self.sendable_plaintext))
|
||||
.buffer_plaintext(buf.into(), &mut self.sendable_plaintext))
|
||||
}
|
||||
|
||||
fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
|
||||
let mut sz = 0;
|
||||
for buf in bufs {
|
||||
sz += self
|
||||
.core
|
||||
.common_state
|
||||
.buffer_plaintext(buf, &mut self.sendable_plaintext);
|
||||
}
|
||||
Ok(sz)
|
||||
let payload_owner: Vec<&[u8]>;
|
||||
let payload = match bufs.len() {
|
||||
0 => return Ok(0),
|
||||
1 => OutboundChunks::Single(bufs[0].deref()),
|
||||
_ => {
|
||||
payload_owner = bufs
|
||||
.iter()
|
||||
.map(|io_slice| io_slice.deref())
|
||||
.collect();
|
||||
|
||||
OutboundChunks::new(&payload_owner)
|
||||
}
|
||||
};
|
||||
Ok(self
|
||||
.core
|
||||
.common_state
|
||||
.buffer_plaintext(payload, &mut self.sendable_plaintext))
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
|
|
|
@ -400,7 +400,7 @@ impl<Data> WriteTraffic<'_, Data> {
|
|||
self.conn
|
||||
.core
|
||||
.common_state
|
||||
.write_plaintext(application_data, outgoing_tls)
|
||||
.write_plaintext(application_data.into(), outgoing_tls)
|
||||
}
|
||||
|
||||
/// Encrypts a close_notify warning alert in `outgoing_tls`
|
||||
|
|
|
@ -309,7 +309,7 @@ impl MessageEncrypter for GcmMessageEncrypter {
|
|||
let total_len = self.encrypted_payload_len(msg.payload.len());
|
||||
let mut payload = Vec::with_capacity(total_len);
|
||||
payload.extend_from_slice(&nonce.as_ref()[4..]);
|
||||
payload.extend_from_slice(msg.payload);
|
||||
msg.payload.copy_to_vec(&mut payload);
|
||||
|
||||
self.enc_key
|
||||
.seal_in_place_separate_tag(nonce, aad, &mut payload[GCM_EXPLICIT_NONCE_LEN..])
|
||||
|
@ -385,7 +385,7 @@ impl MessageEncrypter for ChaCha20Poly1305MessageEncrypter {
|
|||
|
||||
let total_len = self.encrypted_payload_len(msg.payload.len());
|
||||
let mut buf = Vec::with_capacity(total_len);
|
||||
buf.extend_from_slice(msg.payload);
|
||||
msg.payload.copy_to_vec(&mut buf);
|
||||
|
||||
self.enc_key
|
||||
.seal_in_place_append_tag(nonce, aad, &mut buf)
|
||||
|
|
|
@ -223,7 +223,7 @@ impl MessageEncrypter for AeadMessageEncrypter {
|
|||
fn encrypt(&mut self, msg: OutboundMessage, seq: u64) -> Result<OpaqueMessage, Error> {
|
||||
let total_len = self.encrypted_payload_len(msg.payload.len());
|
||||
let mut payload = Vec::with_capacity(total_len);
|
||||
payload.extend_from_slice(msg.payload);
|
||||
msg.payload.copy_to_vec(&mut payload);
|
||||
msg.typ.encode(&mut payload);
|
||||
|
||||
let nonce = aead::Nonce::assume_unique_for_key(Nonce::new(&self.iv, seq).0);
|
||||
|
@ -279,7 +279,7 @@ impl MessageEncrypter for GcmMessageEncrypter {
|
|||
fn encrypt(&mut self, msg: OutboundMessage, seq: u64) -> Result<OpaqueMessage, Error> {
|
||||
let total_len = msg.payload.len() + 1 + self.enc_key.algorithm().tag_len();
|
||||
let mut payload = Vec::with_capacity(total_len);
|
||||
payload.extend_from_slice(msg.payload);
|
||||
msg.payload.copy_to_vec(&mut payload);
|
||||
msg.typ.encode(&mut payload);
|
||||
|
||||
let nonce = aead::Nonce::assume_unique_for_key(Nonce::new(&self.iv, seq).0);
|
||||
|
|
|
@ -293,7 +293,7 @@ impl MessageEncrypter for GcmMessageEncrypter {
|
|||
let total_len = self.encrypted_payload_len(msg.payload.len());
|
||||
let mut payload = Vec::with_capacity(total_len);
|
||||
payload.extend_from_slice(&nonce.as_ref()[4..]);
|
||||
payload.extend_from_slice(msg.payload);
|
||||
msg.payload.copy_to_vec(&mut payload);
|
||||
|
||||
self.enc_key
|
||||
.seal_in_place_separate_tag(nonce, aad, &mut payload[GCM_EXPLICIT_NONCE_LEN..])
|
||||
|
@ -369,7 +369,7 @@ impl MessageEncrypter for ChaCha20Poly1305MessageEncrypter {
|
|||
|
||||
let total_len = self.encrypted_payload_len(msg.payload.len());
|
||||
let mut buf = Vec::with_capacity(total_len);
|
||||
buf.extend_from_slice(msg.payload);
|
||||
msg.payload.copy_to_vec(&mut buf);
|
||||
|
||||
self.enc_key
|
||||
.seal_in_place_append_tag(nonce, aad, &mut buf)
|
||||
|
|
|
@ -195,7 +195,7 @@ impl MessageEncrypter for Tls13MessageEncrypter {
|
|||
fn encrypt(&mut self, msg: OutboundMessage, seq: u64) -> Result<OpaqueMessage, Error> {
|
||||
let total_len = self.encrypted_payload_len(msg.payload.len());
|
||||
let mut payload = Vec::with_capacity(total_len);
|
||||
payload.extend_from_slice(msg.payload);
|
||||
msg.payload.copy_to_vec(&mut payload);
|
||||
msg.typ.encode(&mut payload);
|
||||
|
||||
let nonce = aead::Nonce::assume_unique_for_key(Nonce::new(&self.iv, seq).0);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::enums::ContentType;
|
||||
use crate::enums::ProtocolVersion;
|
||||
use crate::msgs::message::{OutboundMessage, PlainMessage};
|
||||
use crate::msgs::message::{OutboundChunks, OutboundMessage, PlainMessage};
|
||||
use crate::Error;
|
||||
pub(crate) const MAX_FRAGMENT_LEN: usize = 16384;
|
||||
pub(crate) const PACKET_OVERHEAD: usize = 1 + 2 + 2;
|
||||
|
@ -27,24 +27,22 @@ impl MessageFragmenter {
|
|||
&self,
|
||||
msg: &'a PlainMessage,
|
||||
) -> impl Iterator<Item = OutboundMessage<'a>> + 'a {
|
||||
self.fragment_slice(msg.typ, msg.version, msg.payload.bytes())
|
||||
self.fragment_payload(msg.typ, msg.version, msg.payload.bytes().into())
|
||||
}
|
||||
|
||||
/// Enqueue borrowed fragments of (version, typ, payload) which
|
||||
/// are no longer than max_frag onto the `out` deque.
|
||||
pub(crate) fn fragment_slice<'a>(
|
||||
pub(crate) fn fragment_payload<'a>(
|
||||
&self,
|
||||
typ: ContentType,
|
||||
version: ProtocolVersion,
|
||||
payload: &'a [u8],
|
||||
payload: OutboundChunks<'a>,
|
||||
) -> impl ExactSizeIterator<Item = OutboundMessage<'a>> {
|
||||
payload
|
||||
.chunks(self.max_frag)
|
||||
.map(move |payload| OutboundMessage {
|
||||
typ,
|
||||
version,
|
||||
payload,
|
||||
})
|
||||
Chunker::new(payload, self.max_frag).map(move |payload| OutboundMessage {
|
||||
typ,
|
||||
version,
|
||||
payload,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set the maximum fragment size that will be produced.
|
||||
|
@ -65,13 +63,47 @@ impl MessageFragmenter {
|
|||
}
|
||||
}
|
||||
|
||||
/// An iterator over borrowed fragments of a payload
|
||||
struct Chunker<'a> {
|
||||
payload: OutboundChunks<'a>,
|
||||
limit: usize,
|
||||
}
|
||||
|
||||
impl<'a> Chunker<'a> {
|
||||
fn new(payload: OutboundChunks<'a>, limit: usize) -> Self {
|
||||
Self { payload, limit }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for Chunker<'a> {
|
||||
type Item = OutboundChunks<'a>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.payload.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let (before, after) = self.payload.split_at(self.limit);
|
||||
self.payload = after;
|
||||
Some(before)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ExactSizeIterator for Chunker<'a> {
|
||||
fn len(&self) -> usize {
|
||||
(self.payload.len() + self.limit - 1) / self.limit
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{MessageFragmenter, PACKET_OVERHEAD};
|
||||
use crate::enums::ContentType;
|
||||
use crate::enums::ProtocolVersion;
|
||||
use crate::msgs::base::Payload;
|
||||
use crate::msgs::message::{BorrowedPlainMessage, OutboundMessage, PlainMessage};
|
||||
use crate::msgs::message::{
|
||||
BorrowedPlainMessage, OutboundChunks, OutboundMessage, PlainMessage,
|
||||
};
|
||||
|
||||
fn msg_eq(
|
||||
m: &OutboundMessage,
|
||||
|
@ -82,7 +114,7 @@ mod tests {
|
|||
) {
|
||||
assert_eq!(&m.typ, typ);
|
||||
assert_eq!(&m.version, version);
|
||||
assert_eq!(m.payload, bytes);
|
||||
assert_eq!(m.payload.to_vec(), bytes);
|
||||
|
||||
let buf = m.to_unencrypted_opaque().encode();
|
||||
|
||||
|
@ -159,4 +191,35 @@ mod tests {
|
|||
b"\x01\x02\x03\x04\x05\x06\x07\x08",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fragment_multiple_slices() {
|
||||
let typ = ContentType::Handshake;
|
||||
let version = ProtocolVersion::TLSv1_2;
|
||||
let payload_owner: Vec<&[u8]> = vec![&[b'a'; 8], &[b'b'; 12], &[b'c'; 32], &[b'd'; 20]];
|
||||
let borrowed_payload = OutboundChunks::new(&payload_owner);
|
||||
let mut frag = MessageFragmenter::default();
|
||||
frag.set_max_fragment_size(Some(37)) // 32 + packet overhead
|
||||
.unwrap();
|
||||
|
||||
let fragments = frag
|
||||
.fragment_payload(typ, version, borrowed_payload)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(fragments.len(), 3);
|
||||
msg_eq(
|
||||
&fragments[0],
|
||||
37,
|
||||
&typ,
|
||||
&version,
|
||||
b"aaaaaaaabbbbbbbbbbbbcccccccccccc",
|
||||
);
|
||||
msg_eq(
|
||||
&fragments[1],
|
||||
37,
|
||||
&typ,
|
||||
&version,
|
||||
b"ccccccccccccccccccccdddddddddddd",
|
||||
);
|
||||
msg_eq(&fragments[2], 13, &typ, &version, b"dddddddd");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -347,7 +347,7 @@ impl PlainMessage {
|
|||
OutboundMessage {
|
||||
version: self.version,
|
||||
typ: self.typ,
|
||||
payload: self.payload.bytes(),
|
||||
payload: self.payload.bytes().into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -468,7 +468,7 @@ impl BorrowedPlainMessage for InboundMessage<'_> {
|
|||
pub struct OutboundMessage<'a> {
|
||||
pub typ: ContentType,
|
||||
pub version: ProtocolVersion,
|
||||
pub payload: &'a [u8],
|
||||
pub payload: OutboundChunks<'a>,
|
||||
}
|
||||
|
||||
impl BorrowedPlainMessage for OutboundMessage<'_> {
|
||||
|
@ -520,6 +520,126 @@ pub trait BorrowedPlainMessage: Sized {
|
|||
fn version(&self) -> ProtocolVersion;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// A collection of borrowed plaintext slices.
|
||||
/// Warning: OutboundChunks does not guarantee that the simplest variant is used.
|
||||
/// Multiple can hold non fragmented or empty payloads.
|
||||
pub enum OutboundChunks<'a> {
|
||||
/// A single byte slice. Contrary to `Multiple`, this uses a single pointer indirection
|
||||
Single(&'a [u8]),
|
||||
/// A collection of chunks (byte slices)
|
||||
/// and cursors to single out a fragmented range of bytes.
|
||||
/// OutboundChunks assumes that start <= end
|
||||
Multiple {
|
||||
chunks: &'a [&'a [u8]],
|
||||
start: usize,
|
||||
end: usize,
|
||||
},
|
||||
}
|
||||
|
||||
impl<'a> OutboundChunks<'a> {
|
||||
/// Create a payload from a slice of byte slices.
|
||||
/// If fragmented the cursors are added by default: start = 0, end = length
|
||||
pub fn new(chunks: &'a [&'a [u8]]) -> Self {
|
||||
if chunks.len() == 1 {
|
||||
Self::Single(chunks[0])
|
||||
} else {
|
||||
Self::Multiple {
|
||||
chunks,
|
||||
start: 0,
|
||||
end: chunks
|
||||
.iter()
|
||||
.map(|chunk| chunk.len())
|
||||
.sum(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a payload with a single empty slice
|
||||
pub fn new_empty() -> Self {
|
||||
Self::Single(&[])
|
||||
}
|
||||
|
||||
/// Flatten the slice of byte slices to an owned vector of bytes
|
||||
pub fn to_vec(&self) -> Vec<u8> {
|
||||
let mut vec = Vec::with_capacity(self.len());
|
||||
self.copy_to_vec(&mut vec);
|
||||
vec
|
||||
}
|
||||
|
||||
/// Append all bytes to a vector
|
||||
pub fn copy_to_vec(&self, vec: &mut Vec<u8>) {
|
||||
match *self {
|
||||
Self::Single(chunk) => vec.extend_from_slice(chunk),
|
||||
Self::Multiple { chunks, start, end } => {
|
||||
let mut size = 0;
|
||||
for chunk in chunks.iter() {
|
||||
let psize = size;
|
||||
let len = chunk.len();
|
||||
size += len;
|
||||
if size <= start || psize >= end {
|
||||
continue;
|
||||
}
|
||||
let start = if psize < start { start - psize } else { 0 };
|
||||
let end = if end - psize < len { end - psize } else { len };
|
||||
vec.extend_from_slice(&chunk[start..end]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Split self in two, around an index
|
||||
/// Works similarly to `split_at` in the core library, except it doesn't panic if out of bound
|
||||
pub fn split_at(&self, mid: usize) -> (Self, Self) {
|
||||
match *self {
|
||||
Self::Single(chunk) => {
|
||||
let mid = Ord::min(mid, chunk.len());
|
||||
(Self::Single(&chunk[..mid]), Self::Single(&chunk[mid..]))
|
||||
}
|
||||
Self::Multiple { chunks, start, end } => {
|
||||
let mid = Ord::min(start + mid, end);
|
||||
(
|
||||
Self::Multiple {
|
||||
chunks,
|
||||
start,
|
||||
end: mid,
|
||||
},
|
||||
Self::Multiple {
|
||||
chunks,
|
||||
start: mid,
|
||||
end,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the payload is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
/// Returns the cumulative length of all chunks
|
||||
pub fn len(&self) -> usize {
|
||||
match self {
|
||||
Self::Single(chunk) => chunk.len(),
|
||||
Self::Multiple { start, end, .. } => end - start,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a [u8]> for OutboundChunks<'a> {
|
||||
fn from(payload: &'a [u8]) -> Self {
|
||||
Self::Single(payload)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const N: usize> From<&'a [u8; N]> for OutboundChunks<'a> {
|
||||
fn from(payload: &'a [u8; N]) -> Self {
|
||||
Self::Single(payload)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum MessageError {
|
||||
TooShortForHeader,
|
||||
|
@ -529,3 +649,118 @@ pub enum MessageError {
|
|||
InvalidContentType,
|
||||
UnknownProtocolVersion,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{println, vec};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn split_at_with_single_slice() {
|
||||
let owner: &[u8] = &[0, 1, 2, 3, 4, 5, 6, 7];
|
||||
let borrowed_payload = OutboundChunks::Single(owner);
|
||||
|
||||
let (before, after) = borrowed_payload.split_at(6);
|
||||
println!("before:{:?}\nafter:{:?}", before, after);
|
||||
assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5]);
|
||||
assert_eq!(after.to_vec(), &[6, 7]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn split_at_with_multiple_slices() {
|
||||
let owner: Vec<&[u8]> = vec![&[0, 1, 2, 3], &[4, 5], &[6, 7, 8], &[9, 10, 11, 12]];
|
||||
let borrowed_payload = OutboundChunks::new(&owner);
|
||||
|
||||
let (before, after) = borrowed_payload.split_at(3);
|
||||
println!("before:{:?}\nafter:{:?}", before, after);
|
||||
assert_eq!(before.to_vec(), &[0, 1, 2]);
|
||||
assert_eq!(after.to_vec(), &[3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
|
||||
|
||||
let (before, after) = borrowed_payload.split_at(8);
|
||||
println!("before:{:?}\nafter:{:?}", before, after);
|
||||
assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5, 6, 7]);
|
||||
assert_eq!(after.to_vec(), &[8, 9, 10, 11, 12]);
|
||||
|
||||
let (before, after) = borrowed_payload.split_at(11);
|
||||
println!("before:{:?}\nafter:{:?}", before, after);
|
||||
assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
|
||||
assert_eq!(after.to_vec(), &[11, 12]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn split_out_of_bounds() {
|
||||
let owner: Vec<&[u8]> = vec![&[0, 1, 2, 3], &[4, 5], &[6, 7, 8], &[9, 10, 11, 12]];
|
||||
|
||||
let single_payload = OutboundChunks::Single(owner[0]);
|
||||
let (before, after) = single_payload.split_at(17);
|
||||
println!("before:{:?}\nafter:{:?}", before, after);
|
||||
assert_eq!(before.to_vec(), &[0, 1, 2, 3]);
|
||||
assert!(after.is_empty());
|
||||
|
||||
let multiple_payload = OutboundChunks::new(&owner);
|
||||
let (before, after) = multiple_payload.split_at(17);
|
||||
println!("before:{:?}\nafter:{:?}", before, after);
|
||||
assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
|
||||
assert!(after.is_empty());
|
||||
|
||||
let empty_payload = OutboundChunks::new_empty();
|
||||
let (before, after) = empty_payload.split_at(17);
|
||||
println!("before:{:?}\nafter:{:?}", before, after);
|
||||
assert!(before.is_empty());
|
||||
assert!(after.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_slices_mixed() {
|
||||
let owner: Vec<&[u8]> = vec![&[], &[], &[0], &[], &[1, 2], &[], &[3], &[4], &[], &[]];
|
||||
let mut borrowed_payload = OutboundChunks::new(&owner);
|
||||
let mut fragment_count = 0;
|
||||
let mut fragment;
|
||||
let expected_fragments: &[&[u8]] = &[&[0, 1], &[2, 3], &[4]];
|
||||
|
||||
while !borrowed_payload.is_empty() {
|
||||
(fragment, borrowed_payload) = borrowed_payload.split_at(2);
|
||||
println!("{fragment:?}");
|
||||
assert_eq!(&expected_fragments[fragment_count], &fragment.to_vec());
|
||||
fragment_count += 1;
|
||||
}
|
||||
assert_eq!(fragment_count, expected_fragments.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exhaustive_splitting() {
|
||||
let owner: Vec<u8> = (0..127).collect();
|
||||
let slices = (0..7)
|
||||
.map(|i| &owner[((1 << i) - 1)..((1 << (i + 1)) - 1)])
|
||||
.collect::<Vec<_>>();
|
||||
let payload = OutboundChunks::new(&slices);
|
||||
|
||||
assert_eq!(payload.to_vec(), owner);
|
||||
println!("{:#?}", payload);
|
||||
|
||||
for start in 0..128 {
|
||||
for end in start..128 {
|
||||
for mid in 0..(end - start) {
|
||||
let witness = owner[start..end].split_at(mid);
|
||||
let split_payload = payload
|
||||
.split_at(end)
|
||||
.0
|
||||
.split_at(start)
|
||||
.1
|
||||
.split_at(mid);
|
||||
assert_eq!(
|
||||
witness.0,
|
||||
split_payload.0.to_vec(),
|
||||
"start: {start}, mid:{mid}, end:{end}"
|
||||
);
|
||||
assert_eq!(
|
||||
witness.1,
|
||||
split_payload.1.to_vec(),
|
||||
"start: {start}, mid:{mid}, end:{end}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,8 @@ use core::cmp;
|
|||
use std::io;
|
||||
use std::io::Read;
|
||||
|
||||
use crate::msgs::message::OutboundChunks;
|
||||
|
||||
/// This is a byte buffer that is built from a vector
|
||||
/// of byte vectors. This avoids extra copies when
|
||||
/// appending a new byte vector, at the expense of
|
||||
|
@ -66,9 +68,9 @@ impl ChunkVecBuffer {
|
|||
|
||||
/// Append a copy of `bytes`, perhaps a prefix if
|
||||
/// we're near the limit.
|
||||
pub(crate) fn append_limited_copy(&mut self, bytes: &[u8]) -> usize {
|
||||
let take = self.apply_limit(bytes.len());
|
||||
self.append(bytes[..take].to_vec());
|
||||
pub(crate) fn append_limited_copy(&mut self, payload: OutboundChunks<'_>) -> usize {
|
||||
let take = self.apply_limit(payload.len());
|
||||
self.append(payload.split_at(take).0.to_vec());
|
||||
take
|
||||
}
|
||||
|
||||
|
@ -155,10 +157,10 @@ mod tests {
|
|||
#[test]
|
||||
fn short_append_copy_with_limit() {
|
||||
let mut cvb = ChunkVecBuffer::new(Some(12));
|
||||
assert_eq!(cvb.append_limited_copy(b"hello"), 5);
|
||||
assert_eq!(cvb.append_limited_copy(b"world"), 5);
|
||||
assert_eq!(cvb.append_limited_copy(b"hello"), 2);
|
||||
assert_eq!(cvb.append_limited_copy(b"world"), 0);
|
||||
assert_eq!(cvb.append_limited_copy(b"hello".into()), 5);
|
||||
assert_eq!(cvb.append_limited_copy(b"world".into()), 5);
|
||||
assert_eq!(cvb.append_limited_copy(b"hello".into()), 2);
|
||||
assert_eq!(cvb.append_limited_copy(b"world".into()), 0);
|
||||
|
||||
let mut buf = [0u8; 12];
|
||||
assert_eq!(cvb.read(&mut buf).unwrap(), 12);
|
||||
|
|
|
@ -2294,6 +2294,35 @@ fn test_server_stream_read(stream_kind: StreamKind, read_kind: ReadKind) {
|
|||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_write_and_vectored_write_equivalence() {
|
||||
let (mut client, mut server) = make_pair(KeyType::Rsa);
|
||||
do_handshake(&mut client, &mut server);
|
||||
|
||||
const N: usize = 1000;
|
||||
|
||||
let data_chunked: Vec<IoSlice> = std::iter::repeat(IoSlice::new(b"A"))
|
||||
.take(N)
|
||||
.collect();
|
||||
let bytes_written_chunked = client
|
||||
.writer()
|
||||
.write_vectored(&data_chunked)
|
||||
.unwrap();
|
||||
let bytes_sent_chunked = transfer(&mut client, &mut server);
|
||||
println!("write_vectored returned {bytes_written_chunked} and sent {bytes_sent_chunked}");
|
||||
|
||||
let data_contiguous = &[b'A'; N];
|
||||
let bytes_written_contiguous = client
|
||||
.writer()
|
||||
.write(data_contiguous)
|
||||
.unwrap();
|
||||
let bytes_sent_contiguous = transfer(&mut client, &mut server);
|
||||
println!("write returned {bytes_written_contiguous} and sent {bytes_sent_contiguous}");
|
||||
|
||||
assert_eq!(bytes_written_chunked, bytes_written_contiguous);
|
||||
assert_eq!(bytes_sent_chunked, bytes_sent_contiguous);
|
||||
}
|
||||
|
||||
struct FailsWrites {
|
||||
errkind: io::ErrorKind,
|
||||
after: usize,
|
||||
|
|
Loading…
Reference in New Issue