hsjoiner: simplify awkward message type checking API

This commit is contained in:
Dirkjan Ochtman 2022-12-22 12:15:22 +01:00
parent 279f88fb26
commit 9e657be6b9
3 changed files with 70 additions and 81 deletions

View File

@ -15,8 +15,9 @@ fuzz_target!(|data: &[u8]| {
};
let mut jnr = hsjoiner::HandshakeJoiner::new();
if jnr.want_message(&msg) {
let _ = jnr.take_message(msg);
match jnr.push(msg) {
Ok(_) => {},
Err(_) => return,
}
while let Ok(Some(msg)) = jnr.pop() {

View File

@ -10,7 +10,7 @@ use crate::msgs::enums::HandshakeType;
use crate::msgs::enums::{AlertDescription, AlertLevel, ContentType};
use crate::msgs::fragmenter::MessageFragmenter;
use crate::msgs::handshake::Random;
use crate::msgs::hsjoiner::HandshakeJoiner;
use crate::msgs::hsjoiner::{HandshakeJoiner, JoinerError};
use crate::msgs::message::{
BorrowedPlainMessage, Message, MessagePayload, OpaqueMessage, PlainMessage,
};
@ -544,12 +544,8 @@ impl<Data> ConnectionCommon<Data> {
};
let msg = msg.into_plain_message();
if !self.handshake_joiner.want_message(&msg) {
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
}
self.handshake_joiner
.take_message(msg)
.push(msg)
.and_then(|aligned| {
self.common_state.aligned_handshake = aligned;
self.handshake_joiner.pop()
@ -614,25 +610,28 @@ impl<Data> ConnectionCommon<Data> {
false => msg.into_plain_message(),
};
// For handshake messages, we need to join them before parsing
// and processing.
if self.handshake_joiner.want_message(&msg) {
// First decryptable handshake message concludes trial decryption
self.common_state
.record_layer
.finish_trial_decryption();
// For handshake messages, we need to join them before parsing and processing.
let msg = match self.handshake_joiner.push(msg) {
// Handshake message, we handle these in another method.
Ok(aligned) => {
self.common_state.aligned_handshake = aligned;
self.common_state.aligned_handshake = self
.handshake_joiner
.take_message(msg)
.map_err(|_| {
self.common_state
.send_fatal_alert(AlertDescription::DecodeError);
Error::CorruptMessagePayload(ContentType::Handshake)
})?;
// First decryptable handshake message concludes trial decryption
self.common_state
.record_layer
.finish_trial_decryption();
return self.process_new_handshake_messages(state);
}
return self.process_new_handshake_messages(state);
}
// Not a handshake message, continue to handle it here.
Err(JoinerError::Unwanted(msg)) => msg,
// Decoding the handshake message failed, yield an error.
Err(JoinerError::Decode) => {
self.common_state
.send_fatal_alert(AlertDescription::DecodeError);
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
}
};
// Now we can fully parse the message payload.
let msg = Message::try_from(msg)?;
@ -819,11 +818,7 @@ impl<Data> ConnectionCommon<Data> {
payload: Payload::new(plaintext.to_vec()),
};
if self
.handshake_joiner
.take_message(msg)
.is_err()
{
if self.handshake_joiner.push(msg).is_err() {
self.common_state.quic.alert = Some(AlertDescription::DecodeError);
return Err(Error::CorruptMessage);
}

View File

@ -42,17 +42,18 @@ impl HandshakeJoiner {
}
}
/// Do we want to process this message?
pub fn want_message(&self, msg: &PlainMessage) -> bool {
msg.typ == ContentType::Handshake
}
/// Take the message, and join/split it as needed.
///
/// Returns an `Err` if a received payload has an advertised size larger than we accept,
/// or a `bool` to indicate whether the handshake is "aligned": if the buffer currently
/// Returns `Err(JoinerError::Unwanted(msg))` if `msg`'s type is not `ContentType::Handshake` or
/// `JoinerError::Decode` if a received payload has an advertised size larger than we accept.
///
/// Otherwise, yields a `bool` to indicate whether the handshake is "aligned": if the buffer currently
/// only contains complete payloads (that is, no incomplete message in the suffix).
pub fn take_message(&mut self, msg: PlainMessage) -> Result<bool, JoinerError> {
pub fn push(&mut self, msg: PlainMessage) -> Result<bool, JoinerError> {
if msg.typ != ContentType::Handshake {
return Err(JoinerError::Unwanted(msg));
}
// The vast majority of the time `self.buf` will be empty since most
// handshake messages arrive in a single fragment. Avoid allocating and
// copying in that common case.
@ -137,6 +138,7 @@ fn payload_size(buf: &[u8]) -> Result<Option<usize>, JoinerError> {
#[derive(Debug)]
pub enum JoinerError {
Unwanted(PlainMessage),
Decode,
}
@ -152,12 +154,11 @@ mod tests {
#[test]
fn want() {
let hj = HandshakeJoiner::new();
let mut hj = HandshakeJoiner::new();
let wanted = PlainMessage {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(b"hello world".to_vec()),
payload: Payload::new(b"\x00\x00\x00\x00".to_vec()),
};
let unwanted = PlainMessage {
@ -166,8 +167,8 @@ mod tests {
payload: Payload::new(b"ponytown".to_vec()),
};
assert!(hj.want_message(&wanted));
assert!(!hj.want_message(&unwanted));
hj.push(wanted).unwrap();
hj.push(unwanted).unwrap_err();
}
fn pop_eq(expect: &PlainMessage, hj: &mut HandshakeJoiner) {
@ -188,14 +189,13 @@ mod tests {
let mut hj = HandshakeJoiner::new();
// two HelloRequests
let msg = PlainMessage {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(b"\x00\x00\x00\x00\x00\x00\x00\x00".to_vec()),
};
assert!(hj.want_message(&msg));
assert!(hj.take_message(msg).unwrap());
assert!(hj
.push(PlainMessage {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(b"\x00\x00\x00\x00\x00\x00\x00\x00".to_vec()),
})
.unwrap());
let expect = Message {
version: ProtocolVersion::TLSv1_2,
@ -216,14 +216,13 @@ mod tests {
let mut hj = HandshakeJoiner::new();
// short ClientHello
let msg = PlainMessage {
hj.push(PlainMessage {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(b"\x01\x00\x00\x02\xff\xff".to_vec()),
};
})
.unwrap();
assert!(hj.want_message(&msg));
hj.take_message(msg).unwrap();
hj.pop().unwrap_err();
}
@ -233,34 +232,30 @@ mod tests {
let mut hj = HandshakeJoiner::new();
// Introduce Finished of 16 bytes, providing 4.
let mut msg = PlainMessage {
hj.push(PlainMessage {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(b"\x14\x00\x00\x10\x00\x01\x02\x03\x04".to_vec()),
};
assert!(hj.want_message(&msg));
hj.take_message(msg).unwrap();
})
.unwrap();
// 11 more bytes.
msg = PlainMessage {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(b"\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e".to_vec()),
};
assert!(hj.want_message(&msg));
assert!(!hj.take_message(msg).unwrap());
assert!(!hj
.push(PlainMessage {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(b"\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e".to_vec()),
})
.unwrap());
// Final 1 byte.
msg = PlainMessage {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(b"\x0f".to_vec()),
};
assert!(hj.want_message(&msg));
assert!(hj.take_message(msg).unwrap());
assert!(hj
.push(PlainMessage {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(b"\x0f".to_vec()),
})
.unwrap());
let payload = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f".to_vec();
let expect = Message {
@ -278,13 +273,11 @@ mod tests {
#[test]
fn test_rejects_giant_certs() {
let mut hj = HandshakeJoiner::new();
let msg = PlainMessage {
hj.push(PlainMessage {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(b"\x0b\x01\x00\x04\x01\x00\x01\x00\xff\xfe".to_vec()),
};
assert!(hj.want_message(&msg));
hj.take_message(msg).unwrap_err();
})
.unwrap_err();
}
}