mirror of https://github.com/ctz/rustls
245 lines
7.7 KiB
Rust
245 lines
7.7 KiB
Rust
|
|
use std::collections::VecDeque;
|
|
|
|
use crate::msgs::codec;
|
|
use crate::msgs::message::{Message, MessagePayload};
|
|
use crate::msgs::enums::{ContentType, ProtocolVersion};
|
|
use crate::msgs::handshake::HandshakeMessagePayload;
|
|
|
|
/// Size of a handshake message header in bytes: a one-byte handshake type
/// followed by a three-byte (u24) body length.
const HEADER_SIZE: usize = 1 + 3;
|
|
|
|
/// Reconstructs TLS handshake messages from individual TLS messages
/// (records). A single handshake message may be fragmented across several
/// records, and one record may carry several handshake messages. It's
/// guaranteed that TLS messages output from this layer contain precisely
/// one handshake payload.
pub struct HandshakeJoiner {
    /// Completed handshake frames for output, oldest first: new frames are
    /// pushed onto the back, so consumers should pop from the front.
    pub frames: VecDeque<Message>,

    /// The raw handshake bytes we're currently accumulating. After a
    /// successful `take_message` call this holds at most one incomplete
    /// handshake message (possibly just part of its header).
    buf: Vec<u8>,
}
|
|
|
|
impl Default for HandshakeJoiner {
|
|
fn default() -> Self { Self::new() }
|
|
}
|
|
|
|
impl HandshakeJoiner {
|
|
/// Make a new HandshakeJoiner.
|
|
pub fn new() -> HandshakeJoiner {
|
|
HandshakeJoiner {
|
|
frames: VecDeque::new(),
|
|
buf: Vec::new(),
|
|
}
|
|
}
|
|
|
|
/// Do we want to process this message?
|
|
pub fn want_message(&self, msg: &Message) -> bool {
|
|
msg.is_content_type(ContentType::Handshake)
|
|
}
|
|
|
|
/// Do we have any buffered data?
|
|
pub fn is_empty(&self) -> bool {
|
|
self.buf.is_empty()
|
|
}
|
|
|
|
/// Take the message, and join/split it as needed.
|
|
/// Return the number of new messages added to the
|
|
/// output deque as a result of this message.
|
|
///
|
|
/// Returns None if msg or a preceding message was corrupt.
|
|
/// You cannot recover from this situation. Otherwise returns
|
|
/// a count of how many messages we queued.
|
|
pub fn take_message(&mut self, mut msg: Message) -> Option<usize> {
|
|
// Input must be opaque, otherwise we might have already
|
|
// lost information!
|
|
let payload = msg.take_opaque_payload().unwrap();
|
|
|
|
self.buf.extend_from_slice(&payload.0[..]);
|
|
|
|
let mut count = 0;
|
|
while self.buf_contains_message() {
|
|
if !self.deframe_one(msg.version) {
|
|
return None;
|
|
}
|
|
|
|
count += 1;
|
|
}
|
|
|
|
Some(count)
|
|
}
|
|
|
|
/// Does our `buf` contain a full handshake payload? It does if it is big
|
|
/// enough to contain a header, and that header has a length which falls
|
|
/// within `buf`.
|
|
fn buf_contains_message(&self) -> bool {
|
|
self.buf.len() >= HEADER_SIZE &&
|
|
self.buf.len() >= (codec::u24::decode(&self.buf[1..4]).unwrap().0 as usize) + HEADER_SIZE
|
|
}
|
|
|
|
/// Take a TLS handshake payload off the front of `buf`, and put it onto
|
|
/// the back of our `frames` deque inside a normal `Message`.
|
|
///
|
|
/// Returns false if the stream is desynchronised beyond repair.
|
|
fn deframe_one(&mut self, version: ProtocolVersion) -> bool {
|
|
let used = {
|
|
let mut rd = codec::Reader::init(&self.buf);
|
|
let payload = HandshakeMessagePayload::read_version(&mut rd, version);
|
|
|
|
if payload.is_none() {
|
|
return false;
|
|
}
|
|
|
|
let m = Message {
|
|
typ: ContentType::Handshake,
|
|
version,
|
|
payload: MessagePayload::Handshake(payload.unwrap()),
|
|
};
|
|
|
|
self.frames.push_back(m);
|
|
rd.used()
|
|
};
|
|
self.buf = self.buf.split_off(used);
|
|
true
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::HandshakeJoiner;
    use crate::msgs::enums::{ProtocolVersion, ContentType, HandshakeType};
    use crate::msgs::handshake::{HandshakeMessagePayload, HandshakePayload};
    use crate::msgs::message::{Message, MessagePayload};
    use crate::msgs::base::Payload;

    /// Builds a TLS 1.2 message of content type `typ` carrying `bytes` as
    /// an opaque (not yet decoded) payload, as the joiner expects as input.
    fn opaque(typ: ContentType, bytes: &[u8]) -> Message {
        Message {
            typ,
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::new_opaque(bytes.to_vec()),
        }
    }

    /// Builds the decoded TLS 1.2 handshake message we expect the joiner to
    /// emit for `payload`.
    fn handshake(payload: HandshakeMessagePayload) -> Message {
        Message {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::Handshake(payload),
        }
    }

    /// Pops the next completed frame and checks it matches `expect`,
    /// comparing payloads by their encoded bytes.
    fn pop_eq(expect: &Message, hj: &mut HandshakeJoiner) {
        let got = hj.frames.pop_front().unwrap();
        assert_eq!(got.typ, expect.typ);
        assert_eq!(got.version, expect.version);

        let (mut left, mut right) = (Vec::new(), Vec::new());
        got.payload.encode(&mut left);
        expect.payload.encode(&mut right);

        assert_eq!(left, right);
    }

    #[test]
    fn want() {
        let hj = HandshakeJoiner::new();
        assert!(hj.is_empty());

        let wanted = opaque(ContentType::Handshake, b"hello world");
        let unwanted = opaque(ContentType::Alert, b"ponytown");

        assert!(hj.want_message(&wanted));
        assert!(!hj.want_message(&unwanted));
    }

    #[test]
    fn split() {
        // Check we split two handshake messages within one PDU.
        let mut hj = HandshakeJoiner::new();

        // two HelloRequests
        let msg = opaque(ContentType::Handshake, b"\x00\x00\x00\x00\x00\x00\x00\x00");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(2));
        assert!(hj.is_empty());

        let expect = handshake(HandshakeMessagePayload {
            typ: HandshakeType::HelloRequest,
            payload: HandshakePayload::HelloRequest,
        });

        pop_eq(&expect, &mut hj);
        pop_eq(&expect, &mut hj);
        assert!(hj.frames.is_empty());
    }

    #[test]
    fn broken() {
        // Check obvious crap payloads are reported as errors, not panics.
        let mut hj = HandshakeJoiner::new();

        // short ClientHello
        let msg = opaque(ContentType::Handshake, b"\x01\x00\x00\x02\xff\xff");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), None);
    }

    #[test]
    fn join() {
        // Check we join one handshake message split over three PDUs.
        let mut hj = HandshakeJoiner::new();
        assert!(hj.is_empty());

        // Introduce a Finished with a 16-byte body, providing the first 5 bytes.
        let msg = opaque(ContentType::Handshake, b"\x14\x00\x00\x10\x00\x01\x02\x03\x04");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(0));
        assert!(!hj.is_empty());

        // 10 more bytes.
        let msg = opaque(ContentType::Handshake, b"\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(0));
        assert!(!hj.is_empty());

        // Final 1 byte.
        let msg = opaque(ContentType::Handshake, b"\x0f");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(1));
        assert!(hj.is_empty());

        let payload = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f".to_vec();
        let expect = handshake(HandshakeMessagePayload {
            typ: HandshakeType::Finished,
            payload: HandshakePayload::Finished(Payload::new(payload)),
        });

        pop_eq(&expect, &mut hj);
        assert!(hj.frames.is_empty());
    }

    #[test]
    fn split_then_join() {
        // One PDU holds a complete HelloRequest followed by the first two
        // header bytes of a Finished; the rest of the Finished (remaining
        // header bytes plus a 1-byte body) arrives in a second PDU.
        let mut hj = HandshakeJoiner::new();

        let msg = opaque(ContentType::Handshake, b"\x00\x00\x00\x00\x14\x00");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(1));
        assert!(!hj.is_empty());

        let msg = opaque(ContentType::Handshake, b"\x00\x01\xab");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(1));
        assert!(hj.is_empty());

        let hello_request = handshake(HandshakeMessagePayload {
            typ: HandshakeType::HelloRequest,
            payload: HandshakePayload::HelloRequest,
        });
        let finished = handshake(HandshakeMessagePayload {
            typ: HandshakeType::Finished,
            payload: HandshakePayload::Finished(Payload::new(vec![0xab])),
        });

        pop_eq(&hello_request, &mut hj);
        pop_eq(&finished, &mut hj);
        assert!(hj.frames.is_empty());
    }
}
|