hsjoiner: clarify the return type for take_message()

This commit is contained in:
Dirkjan Ochtman 2021-05-22 14:01:48 +02:00
parent 6087246dbf
commit dbbb4eafc0
3 changed files with 21 additions and 22 deletions

View File

@ -16,7 +16,7 @@ fuzz_target!(|data: &[u8]| {
let mut jnr = hsjoiner::HandshakeJoiner::new();
if jnr.want_message(&msg) {
jnr.take_message(msg);
let _ = jnr.take_message(msg);
}
for msg in jnr.frames {

View File

@ -551,7 +551,7 @@ impl<Data> ConnectionCommon<Data> {
if self
.handshake_joiner
.take_message(msg)
.is_none()
.is_err()
{
self.common_state
.send_fatal_alert(AlertDescription::DecodeError);
@ -625,7 +625,7 @@ impl<Data> ConnectionCommon<Data> {
self.handshake_joiner
.take_message(msg)
.ok_or_else(|| {
.map_err(|_| {
self.common_state
.send_fatal_alert(AlertDescription::DecodeError);
Error::CorruptMessagePayload(ContentType::Handshake)
@ -806,7 +806,7 @@ impl<Data> ConnectionCommon<Data> {
if self
.handshake_joiner
.take_message(msg)
.is_none()
.is_err()
{
self.common_state.quic.alert = Some(AlertDescription::DecodeError);
return Err(Error::CorruptMessage);

View File

@ -63,13 +63,10 @@ impl HandshakeJoiner {
}
/// Take the message, and join/split it as needed.
/// Return the number of new messages added to the
/// output deque as a result of this message.
///
/// Returns None if msg or a preceding message was corrupt.
/// You cannot recover from this situation. Otherwise returns
/// a count of how many messages we queued.
pub fn take_message(&mut self, msg: PlainMessage) -> Option<usize> {
/// Returns a `JoinerError` if `msg` or a preceding message was corrupt.
/// You cannot recover from this situation.
pub fn take_message(&mut self, msg: PlainMessage) -> Result<(), JoinerError> {
// The vast majority of the time `self.buf` will be empty since most
// handshake messages arrive in a single fragment. Avoid allocating and
// copying in that common case.
@ -80,22 +77,19 @@ impl HandshakeJoiner {
.extend_from_slice(&msg.payload.0[..]);
}
let mut count = 0;
loop {
match self.buf_contains_message() {
BufferState::MessageTooLarge => return None,
BufferState::MessageTooLarge => return Err(JoinerError::Decode),
BufferState::NeedsMoreData => break,
BufferState::OneMessage => {
if !self.deframe_one(msg.version) {
return None;
return Err(JoinerError::Decode);
}
count += 1;
}
}
}
Some(count)
Ok(())
}
/// Does our `buf` contain a full handshake payload? It does if it is big
@ -142,6 +136,11 @@ impl HandshakeJoiner {
}
}
#[derive(Debug)]
pub enum JoinerError {
Decode,
}
#[cfg(test)]
mod tests {
use super::HandshakeJoiner;
@ -198,7 +197,7 @@ mod tests {
};
assert!(hj.want_message(&msg));
assert_eq!(hj.take_message(msg), Some(2));
hj.take_message(msg).unwrap();
assert!(hj.is_empty());
let expect = Message {
@ -227,7 +226,7 @@ mod tests {
};
assert!(hj.want_message(&msg));
assert_eq!(hj.take_message(msg), None);
hj.take_message(msg).unwrap_err();
}
#[test]
@ -244,7 +243,7 @@ mod tests {
};
assert!(hj.want_message(&msg));
assert_eq!(hj.take_message(msg), Some(0));
hj.take_message(msg).unwrap();
assert!(!hj.is_empty());
// 11 more bytes.
@ -255,7 +254,7 @@ mod tests {
};
assert!(hj.want_message(&msg));
assert_eq!(hj.take_message(msg), Some(0));
hj.take_message(msg).unwrap();
assert!(!hj.is_empty());
// Final 1 byte.
@ -266,7 +265,7 @@ mod tests {
};
assert!(hj.want_message(&msg));
assert_eq!(hj.take_message(msg), Some(1));
hj.take_message(msg).unwrap();
assert!(hj.is_empty());
let payload = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f".to_vec();
@ -292,7 +291,7 @@ mod tests {
};
assert!(hj.want_message(&msg));
assert_eq!(hj.take_message(msg), None);
hj.take_message(msg).unwrap_err();
assert!(!hj.is_empty());
}
}