mirror of https://github.com/ctz/rustls
Optimise read_tls error path
This commit is contained in:
parent
6675b2f8eb
commit
b29daabd34
|
@ -25,9 +25,12 @@ pub struct MessageDeframer {
|
|||
/// the deframer cannot recover.
|
||||
pub desynced: bool,
|
||||
|
||||
/// A variable-size buffer containing the currently-
|
||||
/// accumulating TLS message.
|
||||
buf: Vec<u8>,
|
||||
/// A fixed-size buffer containing the currently-accumulating
|
||||
/// TLS message.
|
||||
buf: [u8; MAX_MESSAGE],
|
||||
|
||||
/// What size prefix of `buf` is used.
|
||||
used: usize,
|
||||
}
|
||||
|
||||
enum BufferContents {
|
||||
|
@ -47,7 +50,8 @@ impl MessageDeframer {
|
|||
MessageDeframer {
|
||||
frames: VecDeque::new(),
|
||||
desynced: false,
|
||||
buf: Vec::with_capacity(MAX_MESSAGE),
|
||||
buf: [0u8; MAX_MESSAGE],
|
||||
used: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -59,18 +63,10 @@ impl MessageDeframer {
|
|||
// we get a message with a length field out of range here,
|
||||
// we do a zero length read. That looks like an EOF to
|
||||
// the next layer up, which is fine.
|
||||
let used = self.buf.len();
|
||||
self.buf.resize(MAX_MESSAGE, 0u8);
|
||||
let rc = rd.read(&mut self.buf[used..MAX_MESSAGE]);
|
||||
debug_assert!(self.used <= MAX_MESSAGE);
|
||||
let new_bytes = rd.read(&mut self.buf[self.used..])?;
|
||||
|
||||
if rc.is_err() {
|
||||
// Discard indeterminate bytes.
|
||||
self.buf.truncate(used);
|
||||
return rc;
|
||||
}
|
||||
|
||||
let new_bytes = rc.unwrap();
|
||||
self.buf.truncate(used + new_bytes);
|
||||
self.used += new_bytes;
|
||||
|
||||
loop {
|
||||
match self.buf_contains_message() {
|
||||
|
@ -92,17 +88,17 @@ impl MessageDeframer {
|
|||
/// to process, either whole messages in our output
|
||||
/// queue or partial messages in our buffer.
|
||||
pub fn has_pending(&self) -> bool {
|
||||
!self.frames.is_empty() || !self.buf.is_empty()
|
||||
!self.frames.is_empty() || self.used > 0
|
||||
}
|
||||
|
||||
/// Does our `buf` contain a full message? It does if it is big enough to
|
||||
/// contain a header, and that header has a length which falls within `buf`.
|
||||
fn buf_contains_message(&self) -> BufferContents {
|
||||
if self.buf.len() < HEADER_SIZE {
|
||||
if self.used < HEADER_SIZE {
|
||||
return BufferContents::Partial;
|
||||
}
|
||||
|
||||
let len_maybe = Message::check_header(&self.buf);
|
||||
let len_maybe = Message::check_header(&self.buf[..self.used]);
|
||||
|
||||
// Header damaged.
|
||||
if len_maybe == None {
|
||||
|
@ -116,7 +112,7 @@ impl MessageDeframer {
|
|||
return BufferContents::Invalid;
|
||||
}
|
||||
|
||||
let full_message = self.buf.len() >= len + HEADER_SIZE;
|
||||
let full_message = self.used >= len + HEADER_SIZE;
|
||||
if full_message { BufferContents::Valid } else { BufferContents::Partial }
|
||||
}
|
||||
|
||||
|
@ -124,12 +120,38 @@ impl MessageDeframer {
|
|||
/// of our `frames` deque.
|
||||
fn deframe_one(&mut self) {
|
||||
let used = {
|
||||
let mut rd = codec::Reader::init(&self.buf);
|
||||
let mut rd = codec::Reader::init(&self.buf[..self.used]);
|
||||
let m = Message::read(&mut rd).unwrap();
|
||||
self.frames.push_back(m);
|
||||
rd.used()
|
||||
};
|
||||
self.buf = self.buf.split_off(used);
|
||||
self.buf_consume(used);
|
||||
}
|
||||
|
||||
fn buf_consume(&mut self, taken: usize) {
|
||||
if taken < self.used {
|
||||
/* Before:
|
||||
* +----------+----------+----------+
|
||||
* | taken | pending |xxxxxxxxxx|
|
||||
* +----------+----------+----------+
|
||||
* 0 ^ taken ^ self.used
|
||||
*
|
||||
* After:
|
||||
* +----------+----------+----------+
|
||||
* | pending |xxxxxxxxxxxxxxxxxxxxx|
|
||||
* +----------+----------+----------+
|
||||
* 0 ^ self.used
|
||||
*/
|
||||
let used_after = self.used - taken;
|
||||
|
||||
for i in 0..used_after {
|
||||
self.buf[i] = self.buf[i + taken];
|
||||
}
|
||||
|
||||
self.used = used_after;
|
||||
} else if taken == self.used {
|
||||
self.used = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -176,6 +198,14 @@ mod tests {
|
|||
d.read(&mut rd)
|
||||
}
|
||||
|
||||
fn input_bytes_concat(d: &mut MessageDeframer, bytes1: &[u8], bytes2: &[u8]) -> io::Result<usize> {
|
||||
let mut bytes = vec![0u8; bytes1.len() + bytes2.len()];
|
||||
bytes[..bytes1.len()].clone_from_slice(bytes1);
|
||||
bytes[bytes1.len()..].clone_from_slice(bytes2);
|
||||
let mut rd = ByteRead::new(&bytes);
|
||||
d.read(&mut rd)
|
||||
}
|
||||
|
||||
struct ErrorRead {
|
||||
error: Option<io::Error>,
|
||||
}
|
||||
|
@ -289,6 +319,30 @@ mod tests {
|
|||
assert_eq!(d.has_pending(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_two_in_one_read() {
|
||||
let mut d = MessageDeframer::new();
|
||||
assert_eq!(d.has_pending(), false);
|
||||
assert_len(FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
|
||||
input_bytes_concat(&mut d, FIRST_MESSAGE, SECOND_MESSAGE));
|
||||
assert_eq!(d.frames.len(), 2);
|
||||
pop_first(&mut d);
|
||||
pop_second(&mut d);
|
||||
assert_eq!(d.has_pending(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_two_in_one_read_shortest_first() {
|
||||
let mut d = MessageDeframer::new();
|
||||
assert_eq!(d.has_pending(), false);
|
||||
assert_len(FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
|
||||
input_bytes_concat(&mut d, SECOND_MESSAGE, FIRST_MESSAGE));
|
||||
assert_eq!(d.frames.len(), 2);
|
||||
pop_second(&mut d);
|
||||
pop_first(&mut d);
|
||||
assert_eq!(d.has_pending(), false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_incremental_with_nonfatal_read_error() {
|
||||
let mut d = MessageDeframer::new();
|
||||
|
|
Loading…
Reference in New Issue