mirror of https://github.com/ctz/rustls
extract out MessageDeframer buffer
This commit is contained in:
parent
838304ad5e
commit
20b8daecca
|
@ -1,5 +1,6 @@
|
|||
use alloc::vec::Vec;
|
||||
use core::ops::Range;
|
||||
use core::slice::SliceIndex;
|
||||
use std::io;
|
||||
|
||||
use super::base::Payload;
|
||||
|
@ -22,16 +23,10 @@ pub struct MessageDeframer {
|
|||
/// the deframer cannot recover.
|
||||
last_error: Option<Error>,
|
||||
|
||||
/// Buffer of data read from the socket, in the process of being parsed into messages.
|
||||
///
|
||||
/// For buffer size management, checkout out the `read()` method.
|
||||
buf: Vec<u8>,
|
||||
|
||||
/// If we're in the middle of joining a handshake payload, this is the metadata.
|
||||
joining_hs: Option<HandshakePayloadMeta>,
|
||||
|
||||
/// What size prefix of `buf` is used.
|
||||
used: usize,
|
||||
buffer: DeframerVecBuffer,
|
||||
}
|
||||
|
||||
impl MessageDeframer {
|
||||
|
@ -47,7 +42,7 @@ impl MessageDeframer {
|
|||
) -> Result<Option<Deframed>, Error> {
|
||||
if let Some(last_err) = self.last_error.clone() {
|
||||
return Err(last_err);
|
||||
} else if self.used == 0 {
|
||||
} else if self.buffer.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
|
@ -72,7 +67,7 @@ impl MessageDeframer {
|
|||
// Does our `buf` contain a full message? It does if it is big enough to
|
||||
// contain a header, and that header has a length which falls within `buf`.
|
||||
// If so, deframe it and place the message onto the frames output queue.
|
||||
let mut rd = codec::Reader::init(&self.buf[start..self.used]);
|
||||
let mut rd = codec::Reader::init(self.buffer.filled_get(start..));
|
||||
let m = match OpaqueMessage::read(&mut rd) {
|
||||
Ok(m) => m,
|
||||
Err(msg_err) => {
|
||||
|
@ -116,7 +111,7 @@ impl MessageDeframer {
|
|||
};
|
||||
if self.joining_hs.is_none() && allowed_plaintext {
|
||||
// This is unencrypted. We check the contents later.
|
||||
self.discard(end);
|
||||
self.buffer.discard(end);
|
||||
return Ok(Some(Deframed {
|
||||
want_close_before_decrypt: false,
|
||||
aligned: true,
|
||||
|
@ -143,7 +138,7 @@ impl MessageDeframer {
|
|||
));
|
||||
}
|
||||
Ok(None) => {
|
||||
self.discard(end);
|
||||
self.buffer.discard(end);
|
||||
continue;
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
|
@ -160,7 +155,7 @@ impl MessageDeframer {
|
|||
// If it's not a handshake message, just return it -- no joining necessary.
|
||||
if msg.typ != ContentType::Handshake {
|
||||
let end = start + rd.used();
|
||||
self.discard(end);
|
||||
self.buffer.discard(end);
|
||||
return Ok(Some(Deframed {
|
||||
want_close_before_decrypt: false,
|
||||
aligned: true,
|
||||
|
@ -184,7 +179,10 @@ impl MessageDeframer {
|
|||
let message = PlainMessage {
|
||||
typ: ContentType::Handshake,
|
||||
version: meta.version,
|
||||
payload: Payload::new(&self.buf[meta.payload.start..meta.payload.start + expected_len]),
|
||||
payload: Payload::new(
|
||||
self.buffer
|
||||
.filled_get(meta.payload.start..meta.payload.start + expected_len),
|
||||
),
|
||||
};
|
||||
|
||||
// But before we return, update the `joining_hs` state to skip past this payload.
|
||||
|
@ -193,13 +191,16 @@ impl MessageDeframer {
|
|||
// the payload start to point past the payload we're about to yield, and update the
|
||||
// `expected_len` to match the state of that remaining payload.
|
||||
meta.payload.start += expected_len;
|
||||
meta.expected_len = payload_size(&self.buf[meta.payload.start..meta.payload.end])?;
|
||||
meta.expected_len = payload_size(
|
||||
self.buffer
|
||||
.filled_get(meta.payload.start..meta.payload.end),
|
||||
)?;
|
||||
} else {
|
||||
// Otherwise, we've yielded the last handshake payload in the buffer, so we can
|
||||
// discard all of the bytes that we're previously buffered as handshake data.
|
||||
let end = meta.message.end;
|
||||
self.joining_hs = None;
|
||||
self.discard(end);
|
||||
self.buffer.discard(end);
|
||||
}
|
||||
|
||||
Ok(Some(Deframed {
|
||||
|
@ -221,17 +222,19 @@ impl MessageDeframer {
|
|||
|
||||
/// Allow pushing handshake messages directly into the buffer.
|
||||
pub(crate) fn push(&mut self, version: ProtocolVersion, payload: &[u8]) -> Result<(), Error> {
|
||||
if self.used > 0 && self.joining_hs.is_none() {
|
||||
if !self.buffer.is_empty() && self.joining_hs.is_none() {
|
||||
return Err(Error::General(
|
||||
"cannot push QUIC messages into unrelated connection".into(),
|
||||
));
|
||||
} else if let Err(err) = self.prepare_read() {
|
||||
} else if let Err(err) = self
|
||||
.buffer
|
||||
.prepare_read(self.joining_hs.is_some())
|
||||
{
|
||||
return Err(Error::General(err.into()));
|
||||
}
|
||||
|
||||
let end = self.used + payload.len();
|
||||
let end = self.buffer.len() + payload.len();
|
||||
self.append_hs(version, payload, end, true)?;
|
||||
self.used = end;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -252,15 +255,17 @@ impl MessageDeframer {
|
|||
// We're joining a handshake message to the previous one here.
|
||||
// Write it into the buffer and update the metadata.
|
||||
|
||||
let dst = &mut self.buf[meta.payload.end..meta.payload.end + payload.len()];
|
||||
dst.copy_from_slice(payload);
|
||||
self.buffer
|
||||
.copy(payload, meta.payload.end, quic);
|
||||
meta.message.end = end;
|
||||
meta.payload.end += payload.len();
|
||||
|
||||
// If we haven't parsed the payload size yet, try to do so now.
|
||||
if meta.expected_len.is_none() {
|
||||
meta.expected_len =
|
||||
payload_size(&self.buf[meta.payload.start..meta.payload.end])?;
|
||||
meta.expected_len = payload_size(
|
||||
self.buffer
|
||||
.filled_get(meta.payload.start..meta.payload.end),
|
||||
)?;
|
||||
}
|
||||
|
||||
meta
|
||||
|
@ -270,8 +275,7 @@ impl MessageDeframer {
|
|||
// Write it into the buffer and create the metadata.
|
||||
|
||||
let expected_len = payload_size(payload)?;
|
||||
let dst = &mut self.buf[..payload.len()];
|
||||
dst.copy_from_slice(payload);
|
||||
self.buffer.copy(payload, 0, quic);
|
||||
self.joining_hs
|
||||
.insert(HandshakePayloadMeta {
|
||||
message: Range { start: 0, end },
|
||||
|
@ -288,7 +292,7 @@ impl MessageDeframer {
|
|||
|
||||
Ok(match meta.expected_len {
|
||||
Some(len) if len <= meta.payload.len() => HandshakePayloadState::Complete(len),
|
||||
_ => match self.used > meta.message.end {
|
||||
_ => match self.buffer.len() > meta.message.end {
|
||||
true => HandshakePayloadState::Continue,
|
||||
false => HandshakePayloadState::Blocked,
|
||||
},
|
||||
|
@ -298,7 +302,10 @@ impl MessageDeframer {
|
|||
/// Read some bytes from `rd`, and add them to our internal buffer.
|
||||
#[allow(clippy::comparison_chain)]
|
||||
pub fn read(&mut self, rd: &mut dyn io::Read) -> io::Result<usize> {
|
||||
if let Err(err) = self.prepare_read() {
|
||||
if let Err(err) = self
|
||||
.buffer
|
||||
.prepare_read(self.joining_hs.is_some())
|
||||
{
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, err));
|
||||
}
|
||||
|
||||
|
@ -306,13 +313,33 @@ impl MessageDeframer {
|
|||
// we get a message with a length field out of range here,
|
||||
// we do a zero length read. That looks like an EOF to
|
||||
// the next layer up, which is fine.
|
||||
let new_bytes = rd.read(&mut self.buf[self.used..])?;
|
||||
self.used += new_bytes;
|
||||
let new_bytes = rd.read(self.buffer.unfilled())?;
|
||||
self.buffer.advance(new_bytes);
|
||||
Ok(new_bytes)
|
||||
}
|
||||
|
||||
/// Returns true if we have messages for the caller
|
||||
/// to process, either whole messages in our output
|
||||
/// queue or partial messages in our buffer.
|
||||
pub fn has_pending(&self) -> bool {
|
||||
!self.buffer.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct DeframerVecBuffer {
|
||||
/// Buffer of data read from the socket, in the process of being parsed into messages.
|
||||
///
|
||||
/// For buffer size management, checkout out the [`DeframerVecBuffer::prepare_read()`] method.
|
||||
buf: Vec<u8>,
|
||||
|
||||
/// What size prefix of `buf` is used.
|
||||
used: usize,
|
||||
}
|
||||
|
||||
impl DeframerVecBuffer {
|
||||
/// Resize the internal `buf` if necessary for reading more bytes.
|
||||
fn prepare_read(&mut self) -> Result<(), &'static str> {
|
||||
fn prepare_read(&mut self, is_joining_hs: bool) -> Result<(), &'static str> {
|
||||
// We allow a maximum of 64k of buffered data for handshake messages only. Enforce this
|
||||
// by varying the maximum allowed buffer size here based on whether a prefix of a
|
||||
// handshake payload is currently being buffered. Given that the first read of such a
|
||||
|
@ -320,9 +347,9 @@ impl MessageDeframer {
|
|||
// larger buffer size. Once the large message and any following handshake messages in
|
||||
// the same flight have been consumed, `pop()` will call `discard()` to reset `used`.
|
||||
// At this point, the buffer resizing logic below should reduce the buffer size.
|
||||
let allow_max = match self.joining_hs {
|
||||
Some(_) => MAX_HANDSHAKE_SIZE as usize,
|
||||
None => OpaqueMessage::MAX_WIRE_SIZE,
|
||||
let allow_max = match is_joining_hs {
|
||||
true => MAX_HANDSHAKE_SIZE as usize,
|
||||
false => OpaqueMessage::MAX_WIRE_SIZE,
|
||||
};
|
||||
|
||||
if self.used >= allow_max {
|
||||
|
@ -345,11 +372,23 @@ impl MessageDeframer {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns true if we have messages for the caller
|
||||
/// to process, either whole messages in our output
|
||||
/// queue or partial messages in our buffer.
|
||||
pub fn has_pending(&self) -> bool {
|
||||
self.used > 0
|
||||
/// Copies from the `src` buffer into this buffer at the requested index
|
||||
///
|
||||
/// If `quic` is true the data will be copied into the *un*filled section of the buffer
|
||||
///
|
||||
/// If `quic` is false the data will be copied into the filled section of the buffer
|
||||
fn copy(&mut self, from: &[u8], at: usize, quic: bool) {
|
||||
let buf = if quic {
|
||||
self.unfilled()
|
||||
} else {
|
||||
self.filled_mut()
|
||||
};
|
||||
let len = from.len();
|
||||
let into = &mut buf[at..at + len];
|
||||
into.copy_from_slice(from);
|
||||
if quic {
|
||||
self.advance(len);
|
||||
}
|
||||
}
|
||||
|
||||
/// Discard `taken` bytes from the start of our buffer.
|
||||
|
@ -376,6 +415,37 @@ impl MessageDeframer {
|
|||
self.used = 0;
|
||||
}
|
||||
}
|
||||
|
||||
fn advance(&mut self, new_bytes: usize) {
|
||||
self.used += new_bytes;
|
||||
}
|
||||
|
||||
fn filled_mut(&mut self) -> &mut [u8] {
|
||||
&mut self.buf[..self.used]
|
||||
}
|
||||
|
||||
fn unfilled(&mut self) -> &mut [u8] {
|
||||
&mut self.buf[self.used..]
|
||||
}
|
||||
|
||||
fn filled_get<I>(&self, index: I) -> &I::Output
|
||||
where
|
||||
I: SliceIndex<[u8]>,
|
||||
{
|
||||
self.filled().get(index).unwrap()
|
||||
}
|
||||
|
||||
fn filled(&self) -> &[u8] {
|
||||
&self.buf[..self.used]
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.used
|
||||
}
|
||||
}
|
||||
|
||||
enum HandshakePayloadState {
|
||||
|
@ -712,14 +782,14 @@ mod tests {
|
|||
}
|
||||
|
||||
fn input_whole_incremental(d: &mut MessageDeframer, bytes: &[u8]) {
|
||||
let before = d.used;
|
||||
let before = d.buffer.len();
|
||||
|
||||
for i in 0..bytes.len() {
|
||||
assert_len(1, input_bytes(d, &bytes[i..i + 1]));
|
||||
assert!(d.has_pending());
|
||||
}
|
||||
|
||||
assert_eq!(before + bytes.len(), d.used);
|
||||
assert_eq!(before + bytes.len(), d.buffer.len());
|
||||
}
|
||||
|
||||
fn assert_len(want: usize, got: io::Result<usize>) {
|
||||
|
|
Loading…
Reference in New Issue