Reconstitute handshake messages

also delete quite a bit of printf tracing
This commit is contained in:
Joseph Birr-Pixton 2016-05-30 19:56:00 +01:00
parent 3b02040431
commit c207843eb4
6 changed files with 285 additions and 45 deletions

View File

@ -5,6 +5,7 @@ use suites::{SupportedCipherSuite, DEFAULT_CIPHERSUITES};
use msgs::handshake::{CertificatePayload, DigitallySignedStruct};
use msgs::enums::{ContentType, AlertDescription, AlertLevel};
use msgs::deframer::MessageDeframer;
use msgs::hsjoiner::HandshakeJoiner;
use msgs::fragmenter::{MessageFragmenter, MAX_FRAGMENT_LEN};
use msgs::message::{Message, MessagePayload};
use msgs::base::Payload;
@ -104,6 +105,7 @@ pub struct ClientSession {
read_seq: u64,
peer_eof: bool,
pub message_deframer: MessageDeframer,
pub handshake_joiner: HandshakeJoiner,
pub message_fragmenter: MessageFragmenter,
pub sendable_plaintext: Vec<u8>,
pub received_plaintext: Vec<u8>,
@ -123,6 +125,7 @@ impl ClientSession {
read_seq: 0,
peer_eof: false,
message_deframer: MessageDeframer::new(),
handshake_joiner: HandshakeJoiner::new(),
message_fragmenter: MessageFragmenter::new(MAX_FRAGMENT_LEN),
sendable_plaintext: Vec::new(),
received_plaintext: Vec::new(),
@ -186,14 +189,6 @@ impl ClientSession {
}
fn process_alert(&mut self, msg: &mut Message) -> Result<(), HandshakeError> {
/* Decrypt it if needed. */
if self.state.is_encrypted() {
let mut dm = try!(self.decrypt_incoming(msg)
.ok_or(HandshakeError::DecryptError));
dm.decode_payload();
*msg = dm;
}
/* Log it. */
println!("Alert received: {:?}", msg);
@ -217,14 +212,45 @@ impl ClientSession {
}
pub fn process_msg(&mut self, msg: &mut Message) -> Result<(), HandshakeError> {
if !self.state.is_encrypted() {
msg.decode_payload();
/* Decrypt if demanded by current state. */
if self.state.is_encrypted() {
let dm = try!(self.decrypt_incoming(msg)
.ok_or(HandshakeError::DecryptError));
*msg = dm;
}
/* For handshake messages, we need to join them before parsing
* and processing. */
if self.handshake_joiner.want_message(msg) {
self.handshake_joiner.take_message(msg);
return self.process_new_handshake_messages();
}
/* Now we can fully parse the message payload. */
msg.decode_payload();
/* For alerts, we have separate logic. */
if msg.is_content_type(ContentType::Alert) {
return self.process_alert(msg);
}
return self.process_main_protocol(msg);
}
fn process_new_handshake_messages(&mut self) -> Result<(), HandshakeError> {
loop {
match self.handshake_joiner.frames.pop_front() {
Some(mut msg) => try!(self.process_main_protocol(&mut msg)),
None => break
}
}
Ok(())
}
/// Process `msg`. First, we get the current `Handler`. Then we ask what
/// that Handler expects. Finally, we ask the handler to handle the message.
fn process_main_protocol(&mut self, msg: &mut Message) -> Result<(), HandshakeError> {
let handler = self.get_handler();
let expects = (handler.expect)();
try!(expects.check_message(msg));
@ -274,11 +300,8 @@ impl ClientSession {
let mut data = Vec::new();
let msg = msg_maybe.unwrap();
println!("writing {:?}", msg);
msg.encode(&mut data);
println!("write {:?}", data);
wr.write_all(&data)
}
@ -323,6 +346,13 @@ impl ClientSession {
pub fn take_received_plaintext(&mut self, bytes: Payload) {
self.received_plaintext.extend_from_slice(&bytes.body);
}
/// Are we done? ie, have we processed all received messages,
/// and received a close_notify to indicate that no new messages
/// will arrive?
fn connection_at_eof(&self) -> bool {
self.peer_eof && !self.message_deframer.has_pending()
}
}
impl io::Read for ClientSession {
@ -330,10 +360,10 @@ impl io::Read for ClientSession {
let len = try!(self.received_plaintext.as_slice().read(buf));
self.received_plaintext.drain(0..len);
if len == 0 && self.peer_eof {
if len == 0 && self.connection_at_eof() && self.received_plaintext.len() == 0 {
return Err(io::Error::new(io::ErrorKind::ConnectionAborted, "CloseNotify alert received"));
}
Ok(len)
}
}

View File

@ -81,8 +81,6 @@ fn expect_server_hello() -> Expectation {
fn handle_server_hello(sess: &mut ClientSession, m: &Message) -> Result<ConnState, HandshakeError> {
let server_hello = extract_handshake!(m, HandshakePayload::ServerHello).unwrap();
println!("we have server hello {:?}", server_hello);
if server_hello.server_version != ProtocolVersion::TLSv1_2 {
return Err(HandshakeError::General("server does not support TLSv1_2".to_string()));
}
@ -129,7 +127,6 @@ fn handle_certificate(sess: &mut ClientSession, m: &Message) -> Result<ConnState
let cert_chain = extract_handshake!(m, HandshakePayload::Certificate).unwrap();
sess.handshake_data.hash_message(m);
sess.handshake_data.server_cert_chain = cert_chain.clone();
println!("we have server cert {:?}", cert_chain);
Ok(ConnState::ExpectServerKX)
}
@ -155,7 +152,6 @@ fn handle_server_kx(sess: &mut ClientSession, m: &Message) -> Result<ConnState,
}
let decoded_kx = maybe_decoded_kx.unwrap();
println!("we have serverkx {:?}", decoded_kx);
/* Save the signature and signed parameters for later verification. */
sess.handshake_data.server_kx_sig = decoded_kx.get_sig();
@ -198,8 +194,6 @@ fn emit_clientkx(sess: &mut ClientSession, kxd: &suites::KeyExchangeResult) {
)
};
println!("sending ckx {:?}", ckx);
sess.handshake_data.hash_message(&ckx);
sess.tls_queue.push_back(ckx);
}
@ -238,7 +232,6 @@ fn emit_finished(sess: &mut ClientSession) {
}
fn handle_server_hello_done(sess: &mut ClientSession, m: &Message) -> Result<ConnState, HandshakeError> {
println!("we have serverhellodone");
sess.handshake_data.hash_message(m);
/* 1. Verify the cert chain.
@ -309,7 +302,6 @@ fn expect_ccs() -> Expectation {
fn handle_ccs(_sess: &mut ClientSession, _m: &Message) -> Result<ConnState, HandshakeError> {
/* nb. msgs layer validates trivial contents of CCS */
println!("got server CCS");
Ok(ConnState::ExpectFinished)
}
@ -327,11 +319,7 @@ fn expect_finished() -> Expectation {
}
fn handle_finished(sess: &mut ClientSession, m: &Message) -> Result<ConnState, HandshakeError> {
let mut dm = try!(sess.decrypt_incoming(m)
.ok_or(HandshakeError::DecryptError));
dm.decode_payload();
let finished = try!(extract_handshake!(dm, HandshakePayload::Finished)
let finished = try!(extract_handshake!(m, HandshakePayload::Finished)
.ok_or(HandshakeError::General("finished message missing".to_string()))
);
@ -346,7 +334,6 @@ fn handle_finished(sess: &mut ClientSession, m: &Message) -> Result<ConnState, H
.map_err(|_| HandshakeError::DecryptError)
.unwrap();
println!("got finished {:?}", finished);
Ok(ConnState::Traffic)
}
@ -364,9 +351,7 @@ fn expect_traffic() -> Expectation {
}
fn handle_traffic(sess: &mut ClientSession, m: &Message) -> Result<ConnState, HandshakeError> {
let dm = try!(sess.decrypt_incoming(m)
.ok_or(HandshakeError::DecryptError));
sess.take_received_plaintext(dm.get_opaque_payload().unwrap());
sess.take_received_plaintext(m.get_opaque_payload().unwrap());
Ok(ConnState::Traffic)
}

View File

@ -8,15 +8,6 @@ pub struct HandshakeHash {
ctx: digest::Context
}
fn dump(label: &str, bytes: &[u8]) {
print!("{}: ", label);
for b in bytes {
print!("{:02x}", b);
}
println!("");
}
impl HandshakeHash {
pub fn new(alg: &'static digest::Algorithm) -> HandshakeHash {
HandshakeHash { ctx: digest::Context::new(alg) }
@ -27,8 +18,6 @@ impl HandshakeHash {
MessagePayload::Handshake(ref hs) => {
let mut buf = Vec::new();
hs.encode(&mut buf);
println!("hash msg {:?} {} bytes", hs.typ, buf.len());
dump("hash", &buf);
self.ctx.update(&buf);
},
_ => unreachable!()
@ -37,8 +26,6 @@ impl HandshakeHash {
}
pub fn update_raw(&mut self, buf: &[u8]) -> &mut HandshakeHash {
println!("hash raw {} bytes", buf.len());
dump("hash init", buf);
self.ctx.update(buf);
self
}

View File

@ -7,9 +7,18 @@ use msgs::message::Message;
static HEADER_SIZE: usize = 1 + 2 + 2;
/// This deframer works to reconstruct TLS messages
/// from arbitrary-sized reads, buffering as neccessary.
/// The input is `read()`, the output is the `frames` deque.
pub struct MessageDeframer {
/// Completed frames for output.
pub frames: VecDeque<Message>,
/// A variable-size buffer containing the currently-
/// accumulating TLS message.
buf: Vec<u8>,
/// A buffer into which we read.
chunk: [u8; 2048]
}
@ -22,6 +31,9 @@ impl MessageDeframer {
}
}
/// Read some bytes from `rd`, and add them to our internal
/// buffer. If this means our internal buffer contains
/// full messages, decode them all.
pub fn read(&mut self, rd: &mut io::Read) -> io::Result<usize> {
let rc = rd.read(&mut self.chunk);
@ -39,11 +51,22 @@ impl MessageDeframer {
Ok(len)
}
/// Returns true if we have messages for the caller
/// to process, either whole messages in our output
/// queue or partial messages in our buffer.
pub fn has_pending(&self) -> bool {
self.frames.len() > 0 || self.buf.len() > 0
}
/// Does our `buf` contain a full message? It does if it is big enough to
/// contain a header, and that header has a length which falls within `buf`.
fn buf_contains_message(&self) -> bool {
self.buf.len() >= HEADER_SIZE &&
self.buf.len() >= (codec::decode_u16(&self.buf[3..5]).unwrap() as usize) + HEADER_SIZE
}
/// Take a TLS message off the front of `buf`, and put it onto the back
/// of our `frames` deque.
fn deframe_one(&mut self) {
let used = {
let mut rd = codec::Reader::init(&self.buf);
@ -100,6 +123,7 @@ mod tests {
for i in 0..bytes.len() {
assert_len(1, input_bytes(d, &bytes[i..i+1]));
assert_eq!(d.has_pending(), true);
if i < bytes.len() - 1 {
assert_eq!(frames_before, d.frames.len());
@ -132,36 +156,49 @@ mod tests {
#[test]
fn check_incremental() {
let mut d = MessageDeframer::new();
assert_eq!(d.has_pending(), false);
input_whole_incremental(&mut d, FIRST_MESSAGE);
assert_eq!(d.has_pending(), true);
assert_eq!(1, d.frames.len());
pop_first(&mut d);
assert_eq!(d.has_pending(), false);
}
#[test]
fn check_incremental_2() {
let mut d = MessageDeframer::new();
assert_eq!(d.has_pending(), false);
input_whole_incremental(&mut d, FIRST_MESSAGE);
assert_eq!(d.has_pending(), true);
input_whole_incremental(&mut d, SECOND_MESSAGE);
assert_eq!(d.has_pending(), true);
assert_eq!(2, d.frames.len());
pop_first(&mut d);
assert_eq!(d.has_pending(), true);
pop_second(&mut d);
assert_eq!(d.has_pending(), false);
}
#[test]
fn check_whole() {
let mut d = MessageDeframer::new();
assert_eq!(d.has_pending(), false);
assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
assert_eq!(d.has_pending(), true);
assert_eq!(d.frames.len(), 1);
pop_first(&mut d);
assert_eq!(d.has_pending(), false);
}
#[test]
fn check_whole_2() {
let mut d = MessageDeframer::new();
assert_eq!(d.has_pending(), false);
assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
assert_len(SECOND_MESSAGE.len(), input_bytes(&mut d, SECOND_MESSAGE));
assert_eq!(d.frames.len(), 2);
pop_first(&mut d);
pop_second(&mut d);
assert_eq!(d.has_pending(), false);
}
}

200
src/msgs/hsjoiner.rs Normal file
View File

@ -0,0 +1,200 @@
use std::collections::VecDeque;
use msgs::codec;
use msgs::codec::Codec;
use msgs::message::{Message, MessagePayload};
use msgs::enums::{ContentType, ProtocolVersion};
use msgs::handshake::HandshakeMessagePayload;
const HEADER_SIZE: usize = 1 + 3;
/// This works to reconstruct TLS handshake messages
/// from individual TLS messages. It's guaranteed that
/// TLS messages output from this layer contain precisely
/// one handshake payload.
pub struct HandshakeJoiner {
/// Completed handshake frames for output.
pub frames: VecDeque<Message>,
/// The message payload we're currently accumulating.
buf: Vec<u8>
}
impl HandshakeJoiner {
pub fn new() -> HandshakeJoiner{
HandshakeJoiner {
frames: VecDeque::new(),
buf: Vec::new()
}
}
/// Do we want to process this message?
pub fn want_message(&self, msg: &Message) -> bool {
msg.is_content_type(ContentType::Handshake)
}
/// Take the message, and join/split it as needed.
/// Return the number of new messages added to the
/// output deque as a result of this message.
pub fn take_message(&mut self, msg: &Message) -> usize {
// Input must be opaque, otherwise we might have already
// lost information!
let payload = msg.get_opaque_payload().unwrap();
self.buf.extend_from_slice(&payload.body[..]);
let mut count = 0;
while self.buf_contains_message() {
self.deframe_one();
count += 1;
}
count
}
/// Does our `buf` contain a full handshake payload? It does if it is big
/// enough to contain a header, and that header has a length which falls
/// within `buf`.
fn buf_contains_message(&self) -> bool {
self.buf.len() >= HEADER_SIZE &&
self.buf.len() >= (codec::decode_u24(&self.buf[1..4]).unwrap() as usize) + HEADER_SIZE
}
/// Take a TLS handshake payload off the front of `buf`, and put it onto
/// the back of our `frames` deque inside a normal `Message`.
fn deframe_one(&mut self) {
let used = {
let mut rd = codec::Reader::init(&self.buf);
let payload = HandshakeMessagePayload::read(&mut rd).unwrap();
let m = Message {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::Handshake(payload)
};
self.frames.push_back(m);
rd.used()
};
self.buf.drain(..used);
}
}
#[cfg(test)]
mod tests {
use super::HandshakeJoiner;
use msgs::enums::{ProtocolVersion, ContentType, HandshakeType};
use msgs::handshake::{HandshakeMessagePayload, HandshakePayload};
use msgs::message::{Message, MessagePayload};
use msgs::base::Payload;
#[test]
fn want() {
let hj = HandshakeJoiner::new();
let wanted = Message {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::opaque(b"hello world".to_vec())
};
let unwanted = Message {
typ: ContentType::Alert,
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::opaque(b"ponytown".to_vec())
};
assert_eq!(hj.want_message(&wanted), true);
assert_eq!(hj.want_message(&unwanted), false);
}
fn pop_eq(expect: &Message, hj: &mut HandshakeJoiner) {
let got = hj.frames.pop_front().unwrap();
assert_eq!(got.typ, expect.typ);
assert_eq!(got.version, expect.version);
let (mut left, mut right) = (Vec::new(), Vec::new());
got.payload.encode(&mut left);
expect.payload.encode(&mut right);
assert_eq!(left, right);
}
#[test]
fn split() {
/* Check we split two handshake messages within one PDU. */
let mut hj = HandshakeJoiner::new();
let msg = Message {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::opaque(b"\x00\x00\x00\x00\x00\x00\x00\x00".to_vec()) /* two HelloRequests. */
};
assert_eq!(hj.want_message(&msg), true);
assert_eq!(hj.take_message(&msg), 2);
let expect = Message {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::Handshake(
HandshakeMessagePayload {
typ: HandshakeType::HelloRequest,
payload: HandshakePayload::HelloRequest
})
};
pop_eq(&expect, &mut hj);
pop_eq(&expect, &mut hj);
}
#[test]
fn join() {
/* Check we join one handshake message split over two PDUs. */
let mut hj = HandshakeJoiner::new();
/* Introduce Finished of 16 bytes, providing 4. */
let mut msg = Message {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::opaque(b"\x14\x00\x00\x10\x00\x01\x02\x03\x04".to_vec())
};
assert_eq!(hj.want_message(&msg), true);
assert_eq!(hj.take_message(&msg), 0);
/* 11 more bytes. */
msg = Message {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::opaque(b"\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e".to_vec())
};
assert_eq!(hj.want_message(&msg), true);
assert_eq!(hj.take_message(&msg), 0);
/* Final 1 byte. */
msg = Message {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::opaque(b"\x0f".to_vec())
};
assert_eq!(hj.want_message(&msg), true);
assert_eq!(hj.take_message(&msg), 1);
let expect = Message {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::Handshake(
HandshakeMessagePayload {
typ: HandshakeType::Finished,
payload: HandshakePayload::Finished(
Payload { body: b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f".to_vec().into_boxed_slice() }
)
}
)
};
pop_eq(&expect, &mut hj);
}
}

View File

@ -12,6 +12,7 @@ pub mod ccs;
pub mod message;
pub mod deframer;
pub mod fragmenter;
pub mod hsjoiner;
#[cfg(test)]
mod test {