extract out MessageDeframer buffer

This commit is contained in:
Jorge Aparicio 2023-11-10 18:58:29 +01:00 committed by Daniel McCarney
parent 838304ad5e
commit 20b8daecca
1 changed files with 110 additions and 40 deletions

View File

@ -1,5 +1,6 @@
use alloc::vec::Vec;
use core::ops::Range;
use core::slice::SliceIndex;
use std::io;
use super::base::Payload;
@ -22,16 +23,10 @@ pub struct MessageDeframer {
/// the deframer cannot recover.
last_error: Option<Error>,
/// Buffer of data read from the socket, in the process of being parsed into messages.
///
/// For buffer size management, checkout out the `read()` method.
buf: Vec<u8>,
/// If we're in the middle of joining a handshake payload, this is the metadata.
joining_hs: Option<HandshakePayloadMeta>,
/// What size prefix of `buf` is used.
used: usize,
buffer: DeframerVecBuffer,
}
impl MessageDeframer {
@ -47,7 +42,7 @@ impl MessageDeframer {
) -> Result<Option<Deframed>, Error> {
if let Some(last_err) = self.last_error.clone() {
return Err(last_err);
} else if self.used == 0 {
} else if self.buffer.is_empty() {
return Ok(None);
}
@ -72,7 +67,7 @@ impl MessageDeframer {
// Does our `buf` contain a full message? It does if it is big enough to
// contain a header, and that header has a length which falls within `buf`.
// If so, deframe it and place the message onto the frames output queue.
let mut rd = codec::Reader::init(&self.buf[start..self.used]);
let mut rd = codec::Reader::init(self.buffer.filled_get(start..));
let m = match OpaqueMessage::read(&mut rd) {
Ok(m) => m,
Err(msg_err) => {
@ -116,7 +111,7 @@ impl MessageDeframer {
};
if self.joining_hs.is_none() && allowed_plaintext {
// This is unencrypted. We check the contents later.
self.discard(end);
self.buffer.discard(end);
return Ok(Some(Deframed {
want_close_before_decrypt: false,
aligned: true,
@ -143,7 +138,7 @@ impl MessageDeframer {
));
}
Ok(None) => {
self.discard(end);
self.buffer.discard(end);
continue;
}
Err(e) => return Err(e),
@ -160,7 +155,7 @@ impl MessageDeframer {
// If it's not a handshake message, just return it -- no joining necessary.
if msg.typ != ContentType::Handshake {
let end = start + rd.used();
self.discard(end);
self.buffer.discard(end);
return Ok(Some(Deframed {
want_close_before_decrypt: false,
aligned: true,
@ -184,7 +179,10 @@ impl MessageDeframer {
let message = PlainMessage {
typ: ContentType::Handshake,
version: meta.version,
payload: Payload::new(&self.buf[meta.payload.start..meta.payload.start + expected_len]),
payload: Payload::new(
self.buffer
.filled_get(meta.payload.start..meta.payload.start + expected_len),
),
};
// But before we return, update the `joining_hs` state to skip past this payload.
@ -193,13 +191,16 @@ impl MessageDeframer {
// the payload start to point past the payload we're about to yield, and update the
// `expected_len` to match the state of that remaining payload.
meta.payload.start += expected_len;
meta.expected_len = payload_size(&self.buf[meta.payload.start..meta.payload.end])?;
meta.expected_len = payload_size(
self.buffer
.filled_get(meta.payload.start..meta.payload.end),
)?;
} else {
// Otherwise, we've yielded the last handshake payload in the buffer, so we can
// discard all of the bytes that we're previously buffered as handshake data.
let end = meta.message.end;
self.joining_hs = None;
self.discard(end);
self.buffer.discard(end);
}
Ok(Some(Deframed {
@ -221,17 +222,19 @@ impl MessageDeframer {
/// Allow pushing handshake messages directly into the buffer.
pub(crate) fn push(&mut self, version: ProtocolVersion, payload: &[u8]) -> Result<(), Error> {
if self.used > 0 && self.joining_hs.is_none() {
if !self.buffer.is_empty() && self.joining_hs.is_none() {
return Err(Error::General(
"cannot push QUIC messages into unrelated connection".into(),
));
} else if let Err(err) = self.prepare_read() {
} else if let Err(err) = self
.buffer
.prepare_read(self.joining_hs.is_some())
{
return Err(Error::General(err.into()));
}
let end = self.used + payload.len();
let end = self.buffer.len() + payload.len();
self.append_hs(version, payload, end, true)?;
self.used = end;
Ok(())
}
@ -252,15 +255,17 @@ impl MessageDeframer {
// We're joining a handshake message to the previous one here.
// Write it into the buffer and update the metadata.
let dst = &mut self.buf[meta.payload.end..meta.payload.end + payload.len()];
dst.copy_from_slice(payload);
self.buffer
.copy(payload, meta.payload.end, quic);
meta.message.end = end;
meta.payload.end += payload.len();
// If we haven't parsed the payload size yet, try to do so now.
if meta.expected_len.is_none() {
meta.expected_len =
payload_size(&self.buf[meta.payload.start..meta.payload.end])?;
meta.expected_len = payload_size(
self.buffer
.filled_get(meta.payload.start..meta.payload.end),
)?;
}
meta
@ -270,8 +275,7 @@ impl MessageDeframer {
// Write it into the buffer and create the metadata.
let expected_len = payload_size(payload)?;
let dst = &mut self.buf[..payload.len()];
dst.copy_from_slice(payload);
self.buffer.copy(payload, 0, quic);
self.joining_hs
.insert(HandshakePayloadMeta {
message: Range { start: 0, end },
@ -288,7 +292,7 @@ impl MessageDeframer {
Ok(match meta.expected_len {
Some(len) if len <= meta.payload.len() => HandshakePayloadState::Complete(len),
_ => match self.used > meta.message.end {
_ => match self.buffer.len() > meta.message.end {
true => HandshakePayloadState::Continue,
false => HandshakePayloadState::Blocked,
},
@ -298,7 +302,10 @@ impl MessageDeframer {
/// Read some bytes from `rd`, and add them to our internal buffer.
#[allow(clippy::comparison_chain)]
pub fn read(&mut self, rd: &mut dyn io::Read) -> io::Result<usize> {
if let Err(err) = self.prepare_read() {
if let Err(err) = self
.buffer
.prepare_read(self.joining_hs.is_some())
{
return Err(io::Error::new(io::ErrorKind::InvalidData, err));
}
@ -306,13 +313,33 @@ impl MessageDeframer {
// we get a message with a length field out of range here,
// we do a zero length read. That looks like an EOF to
// the next layer up, which is fine.
let new_bytes = rd.read(&mut self.buf[self.used..])?;
self.used += new_bytes;
let new_bytes = rd.read(self.buffer.unfilled())?;
self.buffer.advance(new_bytes);
Ok(new_bytes)
}
/// Returns true if we have messages for the caller
/// to process, either whole messages in our output
/// queue or partial messages in our buffer.
pub fn has_pending(&self) -> bool {
!self.buffer.is_empty()
}
}
#[derive(Default, Debug)]
struct DeframerVecBuffer {
/// Buffer of data read from the socket, in the process of being parsed into messages.
///
/// For buffer size management, checkout out the [`DeframerVecBuffer::prepare_read()`] method.
buf: Vec<u8>,
/// What size prefix of `buf` is used.
used: usize,
}
impl DeframerVecBuffer {
/// Resize the internal `buf` if necessary for reading more bytes.
fn prepare_read(&mut self) -> Result<(), &'static str> {
fn prepare_read(&mut self, is_joining_hs: bool) -> Result<(), &'static str> {
// We allow a maximum of 64k of buffered data for handshake messages only. Enforce this
// by varying the maximum allowed buffer size here based on whether a prefix of a
// handshake payload is currently being buffered. Given that the first read of such a
@ -320,9 +347,9 @@ impl MessageDeframer {
// larger buffer size. Once the large message and any following handshake messages in
// the same flight have been consumed, `pop()` will call `discard()` to reset `used`.
// At this point, the buffer resizing logic below should reduce the buffer size.
let allow_max = match self.joining_hs {
Some(_) => MAX_HANDSHAKE_SIZE as usize,
None => OpaqueMessage::MAX_WIRE_SIZE,
let allow_max = match is_joining_hs {
true => MAX_HANDSHAKE_SIZE as usize,
false => OpaqueMessage::MAX_WIRE_SIZE,
};
if self.used >= allow_max {
@ -345,11 +372,23 @@ impl MessageDeframer {
Ok(())
}
/// Returns true if we have messages for the caller
/// to process, either whole messages in our output
/// queue or partial messages in our buffer.
pub fn has_pending(&self) -> bool {
self.used > 0
/// Copies from the `src` buffer into this buffer at the requested index
///
/// If `quic` is true the data will be copied into the *un*filled section of the buffer
///
/// If `quic` is false the data will be copied into the filled section of the buffer
fn copy(&mut self, from: &[u8], at: usize, quic: bool) {
let buf = if quic {
self.unfilled()
} else {
self.filled_mut()
};
let len = from.len();
let into = &mut buf[at..at + len];
into.copy_from_slice(from);
if quic {
self.advance(len);
}
}
/// Discard `taken` bytes from the start of our buffer.
@ -376,6 +415,37 @@ impl MessageDeframer {
self.used = 0;
}
}
fn advance(&mut self, new_bytes: usize) {
self.used += new_bytes;
}
fn filled_mut(&mut self) -> &mut [u8] {
&mut self.buf[..self.used]
}
fn unfilled(&mut self) -> &mut [u8] {
&mut self.buf[self.used..]
}
fn filled_get<I>(&self, index: I) -> &I::Output
where
I: SliceIndex<[u8]>,
{
self.filled().get(index).unwrap()
}
fn filled(&self) -> &[u8] {
&self.buf[..self.used]
}
fn is_empty(&self) -> bool {
self.len() == 0
}
fn len(&self) -> usize {
self.used
}
}
enum HandshakePayloadState {
@ -712,14 +782,14 @@ mod tests {
}
fn input_whole_incremental(d: &mut MessageDeframer, bytes: &[u8]) {
let before = d.used;
let before = d.buffer.len();
for i in 0..bytes.len() {
assert_len(1, input_bytes(d, &bytes[i..i + 1]));
assert!(d.has_pending());
}
assert_eq!(before + bytes.len(), d.used);
assert_eq!(before + bytes.len(), d.buffer.len());
}
fn assert_len(want: usize, got: io::Result<usize>) {