mirror of https://github.com/ctz/rustls
hsjoiner: simplify awkward message type checking API
This commit is contained in:
parent
279f88fb26
commit
9e657be6b9
|
@ -15,8 +15,9 @@ fuzz_target!(|data: &[u8]| {
|
|||
};
|
||||
|
||||
let mut jnr = hsjoiner::HandshakeJoiner::new();
|
||||
if jnr.want_message(&msg) {
|
||||
let _ = jnr.take_message(msg);
|
||||
match jnr.push(msg) {
|
||||
Ok(_) => {},
|
||||
Err(_) => return,
|
||||
}
|
||||
|
||||
while let Ok(Some(msg)) = jnr.pop() {
|
||||
|
|
|
@ -10,7 +10,7 @@ use crate::msgs::enums::HandshakeType;
|
|||
use crate::msgs::enums::{AlertDescription, AlertLevel, ContentType};
|
||||
use crate::msgs::fragmenter::MessageFragmenter;
|
||||
use crate::msgs::handshake::Random;
|
||||
use crate::msgs::hsjoiner::HandshakeJoiner;
|
||||
use crate::msgs::hsjoiner::{HandshakeJoiner, JoinerError};
|
||||
use crate::msgs::message::{
|
||||
BorrowedPlainMessage, Message, MessagePayload, OpaqueMessage, PlainMessage,
|
||||
};
|
||||
|
@ -544,12 +544,8 @@ impl<Data> ConnectionCommon<Data> {
|
|||
};
|
||||
|
||||
let msg = msg.into_plain_message();
|
||||
if !self.handshake_joiner.want_message(&msg) {
|
||||
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
|
||||
}
|
||||
|
||||
self.handshake_joiner
|
||||
.take_message(msg)
|
||||
.push(msg)
|
||||
.and_then(|aligned| {
|
||||
self.common_state.aligned_handshake = aligned;
|
||||
self.handshake_joiner.pop()
|
||||
|
@ -614,25 +610,28 @@ impl<Data> ConnectionCommon<Data> {
|
|||
false => msg.into_plain_message(),
|
||||
};
|
||||
|
||||
// For handshake messages, we need to join them before parsing
|
||||
// and processing.
|
||||
if self.handshake_joiner.want_message(&msg) {
|
||||
// First decryptable handshake message concludes trial decryption
|
||||
self.common_state
|
||||
.record_layer
|
||||
.finish_trial_decryption();
|
||||
// For handshake messages, we need to join them before parsing and processing.
|
||||
let msg = match self.handshake_joiner.push(msg) {
|
||||
// Handshake message, we handle these in another method.
|
||||
Ok(aligned) => {
|
||||
self.common_state.aligned_handshake = aligned;
|
||||
|
||||
self.common_state.aligned_handshake = self
|
||||
.handshake_joiner
|
||||
.take_message(msg)
|
||||
.map_err(|_| {
|
||||
self.common_state
|
||||
.send_fatal_alert(AlertDescription::DecodeError);
|
||||
Error::CorruptMessagePayload(ContentType::Handshake)
|
||||
})?;
|
||||
// First decryptable handshake message concludes trial decryption
|
||||
self.common_state
|
||||
.record_layer
|
||||
.finish_trial_decryption();
|
||||
|
||||
return self.process_new_handshake_messages(state);
|
||||
}
|
||||
return self.process_new_handshake_messages(state);
|
||||
}
|
||||
// Not a handshake message, continue to handle it here.
|
||||
Err(JoinerError::Unwanted(msg)) => msg,
|
||||
// Decoding the handshake message failed, yield an error.
|
||||
Err(JoinerError::Decode) => {
|
||||
self.common_state
|
||||
.send_fatal_alert(AlertDescription::DecodeError);
|
||||
return Err(Error::CorruptMessagePayload(ContentType::Handshake));
|
||||
}
|
||||
};
|
||||
|
||||
// Now we can fully parse the message payload.
|
||||
let msg = Message::try_from(msg)?;
|
||||
|
@ -819,11 +818,7 @@ impl<Data> ConnectionCommon<Data> {
|
|||
payload: Payload::new(plaintext.to_vec()),
|
||||
};
|
||||
|
||||
if self
|
||||
.handshake_joiner
|
||||
.take_message(msg)
|
||||
.is_err()
|
||||
{
|
||||
if self.handshake_joiner.push(msg).is_err() {
|
||||
self.common_state.quic.alert = Some(AlertDescription::DecodeError);
|
||||
return Err(Error::CorruptMessage);
|
||||
}
|
||||
|
|
|
@ -42,17 +42,18 @@ impl HandshakeJoiner {
|
|||
}
|
||||
}
|
||||
|
||||
/// Do we want to process this message?
|
||||
pub fn want_message(&self, msg: &PlainMessage) -> bool {
|
||||
msg.typ == ContentType::Handshake
|
||||
}
|
||||
|
||||
/// Take the message, and join/split it as needed.
|
||||
///
|
||||
/// Returns an `Err` if a received payload has an advertised size larger than we accept,
|
||||
/// or a `bool` to indicate whether the handshake is "aligned": if the buffer currently
|
||||
/// Returns `Err(JoinerError::Unwanted(msg))` if `msg`'s type is not `ContentType::Handshake` or
|
||||
/// `JoinerError::Decode` if a received payload has an advertised size larger than we accept.
|
||||
///
|
||||
/// Otherwise, yields a `bool` to indicate whether the handshake is "aligned": if the buffer currently
|
||||
/// only contains complete payloads (that is, no incomplete message in the suffix).
|
||||
pub fn take_message(&mut self, msg: PlainMessage) -> Result<bool, JoinerError> {
|
||||
pub fn push(&mut self, msg: PlainMessage) -> Result<bool, JoinerError> {
|
||||
if msg.typ != ContentType::Handshake {
|
||||
return Err(JoinerError::Unwanted(msg));
|
||||
}
|
||||
|
||||
// The vast majority of the time `self.buf` will be empty since most
|
||||
// handshake messages arrive in a single fragment. Avoid allocating and
|
||||
// copying in that common case.
|
||||
|
@ -137,6 +138,7 @@ fn payload_size(buf: &[u8]) -> Result<Option<usize>, JoinerError> {
|
|||
|
||||
#[derive(Debug)]
|
||||
pub enum JoinerError {
|
||||
Unwanted(PlainMessage),
|
||||
Decode,
|
||||
}
|
||||
|
||||
|
@ -152,12 +154,11 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn want() {
|
||||
let hj = HandshakeJoiner::new();
|
||||
|
||||
let mut hj = HandshakeJoiner::new();
|
||||
let wanted = PlainMessage {
|
||||
typ: ContentType::Handshake,
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: Payload::new(b"hello world".to_vec()),
|
||||
payload: Payload::new(b"\x00\x00\x00\x00".to_vec()),
|
||||
};
|
||||
|
||||
let unwanted = PlainMessage {
|
||||
|
@ -166,8 +167,8 @@ mod tests {
|
|||
payload: Payload::new(b"ponytown".to_vec()),
|
||||
};
|
||||
|
||||
assert!(hj.want_message(&wanted));
|
||||
assert!(!hj.want_message(&unwanted));
|
||||
hj.push(wanted).unwrap();
|
||||
hj.push(unwanted).unwrap_err();
|
||||
}
|
||||
|
||||
fn pop_eq(expect: &PlainMessage, hj: &mut HandshakeJoiner) {
|
||||
|
@ -188,14 +189,13 @@ mod tests {
|
|||
let mut hj = HandshakeJoiner::new();
|
||||
|
||||
// two HelloRequests
|
||||
let msg = PlainMessage {
|
||||
typ: ContentType::Handshake,
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: Payload::new(b"\x00\x00\x00\x00\x00\x00\x00\x00".to_vec()),
|
||||
};
|
||||
|
||||
assert!(hj.want_message(&msg));
|
||||
assert!(hj.take_message(msg).unwrap());
|
||||
assert!(hj
|
||||
.push(PlainMessage {
|
||||
typ: ContentType::Handshake,
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: Payload::new(b"\x00\x00\x00\x00\x00\x00\x00\x00".to_vec()),
|
||||
})
|
||||
.unwrap());
|
||||
|
||||
let expect = Message {
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
|
@ -216,14 +216,13 @@ mod tests {
|
|||
let mut hj = HandshakeJoiner::new();
|
||||
|
||||
// short ClientHello
|
||||
let msg = PlainMessage {
|
||||
hj.push(PlainMessage {
|
||||
typ: ContentType::Handshake,
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: Payload::new(b"\x01\x00\x00\x02\xff\xff".to_vec()),
|
||||
};
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
assert!(hj.want_message(&msg));
|
||||
hj.take_message(msg).unwrap();
|
||||
hj.pop().unwrap_err();
|
||||
}
|
||||
|
||||
|
@ -233,34 +232,30 @@ mod tests {
|
|||
let mut hj = HandshakeJoiner::new();
|
||||
|
||||
// Introduce Finished of 16 bytes, providing 4.
|
||||
let mut msg = PlainMessage {
|
||||
hj.push(PlainMessage {
|
||||
typ: ContentType::Handshake,
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: Payload::new(b"\x14\x00\x00\x10\x00\x01\x02\x03\x04".to_vec()),
|
||||
};
|
||||
|
||||
assert!(hj.want_message(&msg));
|
||||
hj.take_message(msg).unwrap();
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// 11 more bytes.
|
||||
msg = PlainMessage {
|
||||
typ: ContentType::Handshake,
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: Payload::new(b"\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e".to_vec()),
|
||||
};
|
||||
|
||||
assert!(hj.want_message(&msg));
|
||||
assert!(!hj.take_message(msg).unwrap());
|
||||
assert!(!hj
|
||||
.push(PlainMessage {
|
||||
typ: ContentType::Handshake,
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: Payload::new(b"\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e".to_vec()),
|
||||
})
|
||||
.unwrap());
|
||||
|
||||
// Final 1 byte.
|
||||
msg = PlainMessage {
|
||||
typ: ContentType::Handshake,
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: Payload::new(b"\x0f".to_vec()),
|
||||
};
|
||||
|
||||
assert!(hj.want_message(&msg));
|
||||
assert!(hj.take_message(msg).unwrap());
|
||||
assert!(hj
|
||||
.push(PlainMessage {
|
||||
typ: ContentType::Handshake,
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: Payload::new(b"\x0f".to_vec()),
|
||||
})
|
||||
.unwrap());
|
||||
|
||||
let payload = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f".to_vec();
|
||||
let expect = Message {
|
||||
|
@ -278,13 +273,11 @@ mod tests {
|
|||
#[test]
|
||||
fn test_rejects_giant_certs() {
|
||||
let mut hj = HandshakeJoiner::new();
|
||||
let msg = PlainMessage {
|
||||
hj.push(PlainMessage {
|
||||
typ: ContentType::Handshake,
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: Payload::new(b"\x0b\x01\x00\x04\x01\x00\x01\x00\xff\xfe".to_vec()),
|
||||
};
|
||||
|
||||
assert!(hj.want_message(&msg));
|
||||
hj.take_message(msg).unwrap_err();
|
||||
})
|
||||
.unwrap_err();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue