mirror of https://github.com/ctz/rustls
hsjoiner: clarify the return type for take_message()
This commit is contained in:
parent
6087246dbf
commit
dbbb4eafc0
|
@ -16,7 +16,7 @@ fuzz_target!(|data: &[u8]| {
|
|||
|
||||
let mut jnr = hsjoiner::HandshakeJoiner::new();
|
||||
if jnr.want_message(&msg) {
|
||||
jnr.take_message(msg);
|
||||
let _ = jnr.take_message(msg);
|
||||
}
|
||||
|
||||
for msg in jnr.frames {
|
||||
|
|
|
@ -551,7 +551,7 @@ impl<Data> ConnectionCommon<Data> {
|
|||
if self
|
||||
.handshake_joiner
|
||||
.take_message(msg)
|
||||
.is_none()
|
||||
.is_err()
|
||||
{
|
||||
self.common_state
|
||||
.send_fatal_alert(AlertDescription::DecodeError);
|
||||
|
@ -625,7 +625,7 @@ impl<Data> ConnectionCommon<Data> {
|
|||
|
||||
self.handshake_joiner
|
||||
.take_message(msg)
|
||||
.ok_or_else(|| {
|
||||
.map_err(|_| {
|
||||
self.common_state
|
||||
.send_fatal_alert(AlertDescription::DecodeError);
|
||||
Error::CorruptMessagePayload(ContentType::Handshake)
|
||||
|
@ -806,7 +806,7 @@ impl<Data> ConnectionCommon<Data> {
|
|||
if self
|
||||
.handshake_joiner
|
||||
.take_message(msg)
|
||||
.is_none()
|
||||
.is_err()
|
||||
{
|
||||
self.common_state.quic.alert = Some(AlertDescription::DecodeError);
|
||||
return Err(Error::CorruptMessage);
|
||||
|
|
|
@ -63,13 +63,10 @@ impl HandshakeJoiner {
|
|||
}
|
||||
|
||||
/// Take the message, and join/split it as needed.
|
||||
/// Return the number of new messages added to the
|
||||
/// output deque as a result of this message.
|
||||
///
|
||||
/// Returns None if msg or a preceding message was corrupt.
|
||||
/// You cannot recover from this situation. Otherwise returns
|
||||
/// a count of how many messages we queued.
|
||||
pub fn take_message(&mut self, msg: PlainMessage) -> Option<usize> {
|
||||
/// Returns a `JoinerError` if `msg` or a preceding message was corrupt.
|
||||
/// You cannot recover from this situation.
|
||||
pub fn take_message(&mut self, msg: PlainMessage) -> Result<(), JoinerError> {
|
||||
// The vast majority of the time `self.buf` will be empty since most
|
||||
// handshake messages arrive in a single fragment. Avoid allocating and
|
||||
// copying in that common case.
|
||||
|
@ -80,22 +77,19 @@ impl HandshakeJoiner {
|
|||
.extend_from_slice(&msg.payload.0[..]);
|
||||
}
|
||||
|
||||
let mut count = 0;
|
||||
loop {
|
||||
match self.buf_contains_message() {
|
||||
BufferState::MessageTooLarge => return None,
|
||||
BufferState::MessageTooLarge => return Err(JoinerError::Decode),
|
||||
BufferState::NeedsMoreData => break,
|
||||
BufferState::OneMessage => {
|
||||
if !self.deframe_one(msg.version) {
|
||||
return None;
|
||||
return Err(JoinerError::Decode);
|
||||
}
|
||||
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(count)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Does our `buf` contain a full handshake payload? It does if it is big
|
||||
|
@ -142,6 +136,11 @@ impl HandshakeJoiner {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum JoinerError {
|
||||
Decode,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::HandshakeJoiner;
|
||||
|
@ -198,7 +197,7 @@ mod tests {
|
|||
};
|
||||
|
||||
assert!(hj.want_message(&msg));
|
||||
assert_eq!(hj.take_message(msg), Some(2));
|
||||
hj.take_message(msg).unwrap();
|
||||
assert!(hj.is_empty());
|
||||
|
||||
let expect = Message {
|
||||
|
@ -227,7 +226,7 @@ mod tests {
|
|||
};
|
||||
|
||||
assert!(hj.want_message(&msg));
|
||||
assert_eq!(hj.take_message(msg), None);
|
||||
hj.take_message(msg).unwrap_err();
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -244,7 +243,7 @@ mod tests {
|
|||
};
|
||||
|
||||
assert!(hj.want_message(&msg));
|
||||
assert_eq!(hj.take_message(msg), Some(0));
|
||||
hj.take_message(msg).unwrap();
|
||||
assert!(!hj.is_empty());
|
||||
|
||||
// 11 more bytes.
|
||||
|
@ -255,7 +254,7 @@ mod tests {
|
|||
};
|
||||
|
||||
assert!(hj.want_message(&msg));
|
||||
assert_eq!(hj.take_message(msg), Some(0));
|
||||
hj.take_message(msg).unwrap();
|
||||
assert!(!hj.is_empty());
|
||||
|
||||
// Final 1 byte.
|
||||
|
@ -266,7 +265,7 @@ mod tests {
|
|||
};
|
||||
|
||||
assert!(hj.want_message(&msg));
|
||||
assert_eq!(hj.take_message(msg), Some(1));
|
||||
hj.take_message(msg).unwrap();
|
||||
assert!(hj.is_empty());
|
||||
|
||||
let payload = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f".to_vec();
|
||||
|
@ -292,7 +291,7 @@ mod tests {
|
|||
};
|
||||
|
||||
assert!(hj.want_message(&msg));
|
||||
assert_eq!(hj.take_message(msg), None);
|
||||
hj.take_message(msg).unwrap_err();
|
||||
assert!(!hj.is_empty());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue