we can pass application data :)

This commit is contained in:
Joseph Birr-Pixton 2016-05-27 21:47:13 +01:00
parent 5d9ae0287d
commit ed0ec5441f
6 changed files with 147 additions and 22 deletions

View File

@ -4,6 +4,10 @@ use std::process;
extern crate mio;
use mio::tcp::TcpStream;
use std::str;
use std::io;
use std::io::{Read, Write};
extern crate rustls;
const CLIENT: mio::Token = mio::Token(0);
@ -48,6 +52,22 @@ impl mio::Handler for TlsClient {
}
}
impl io::Write for TlsClient {
fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
self.tls_session.write(bytes)
}
fn flush(&mut self) -> io::Result<()> {
self.tls_session.flush()
}
}
impl io::Read for TlsClient {
fn read(&mut self, bytes: &mut [u8]) -> io::Result<usize> {
self.tls_session.read(bytes)
}
}
fn read_file(filename: &str) -> Vec<u8> {
use std::io::Read;
@ -92,6 +112,13 @@ impl TlsClient {
self.closing = true;
return;
}
/* We might have new plaintext as a result. */
let mut plaintext = Vec::new();
self.tls_session.read_to_end(&mut plaintext).unwrap();
if plaintext.len() > 0 {
println!("got {}", str::from_utf8(&plaintext).unwrap());
}
}
fn do_write(&mut self) {
@ -138,6 +165,7 @@ fn main() {
let sock = TcpStream::connect(&addr).unwrap();
let mut event_loop = mio::EventLoop::new().unwrap();
let mut tlsclient = TlsClient::new(sock);
tlsclient.write(b"GET / HTTP/1.0\r\n\r\n").unwrap();
tlsclient.register(&mut event_loop);
event_loop.run(&mut tlsclient).unwrap();
}

View File

@ -6,6 +6,7 @@ use msgs::handshake::{CertificatePayload, ClientExtension, DigitallySignedStruct
use msgs::deframer::MessageDeframer;
use msgs::fragmenter::{MessageFragmenter, MAX_FRAGMENT_LEN};
use msgs::message::Message;
use msgs::base::Payload;
use client_hs;
use hash_hs;
use verify;
@ -72,6 +73,7 @@ impl ClientHandshakeData {
}
}
#[derive(PartialEq)]
pub enum ConnState {
ExpectServerHello,
ExpectCertificate,
@ -101,7 +103,8 @@ pub struct ClientSession {
read_seq: u64,
pub message_deframer: MessageDeframer,
pub message_fragmenter: MessageFragmenter,
pub plain_buf: Vec<u8>,
pub sendable_plaintext: Vec<u8>,
pub received_plaintext: Vec<u8>,
pub tls_queue: VecDeque<Message>,
pub state: ConnState
}
@ -118,7 +121,8 @@ impl ClientSession {
read_seq: 0,
message_deframer: MessageDeframer::new(),
message_fragmenter: MessageFragmenter::new(MAX_FRAGMENT_LEN),
plain_buf: Vec::new(),
sendable_plaintext: Vec::new(),
received_plaintext: Vec::new(),
tls_queue: VecDeque::new(),
state: ConnState::ExpectServerHello
};
@ -192,6 +196,11 @@ impl ClientSession {
let new_state = try!((handler.handle)(self, msg));
self.state = new_state;
/* Once we're connected, start flushing sendable_plaintext. */
if self.state == ConnState::Traffic {
self.flush_plaintext();
}
Ok(())
}
@ -203,7 +212,7 @@ impl ClientSession {
ConnState::ExpectServerHelloDone => &client_hs::EXPECT_SERVER_HELLO_DONE,
ConnState::ExpectCCS => &client_hs::EXPECT_CCS,
ConnState::ExpectFinished => &client_hs::EXPECT_FINISHED,
_ => &client_hs::INVALID_STATE
ConnState::Traffic => &client_hs::TRAFFIC
}
}
@ -239,11 +248,65 @@ impl ClientSession {
}
pub fn send_plain(&mut self, data: &[u8]) {
use msgs::enums::{ContentType, ProtocolVersion};
use msgs::message::MessagePayload;
if self.state != ConnState::Traffic {
/* If we haven't completed handshaking, buffer
* plaintext to send once we do. */
self.sendable_plaintext.extend_from_slice(data);
return;
}
assert!(self.state.is_encrypted());
/* Make one giant message, then have the fragmenter chop
* it into bits. Then encrypt and queue those bits. */
let m = Message {
typ: ContentType::ApplicationData,
version: ProtocolVersion::TLSv1_2,
payload: MessagePayload::opaque(data.to_vec())
};
let mut plain_messages = VecDeque::new();
self.message_fragmenter.fragment(&m, &mut plain_messages);
for m in plain_messages {
let em = self.encrypt_outgoing(&m);
self.tls_queue.push_back(em);
}
}
pub fn flush(&mut self) {
let buf = mem::replace(&mut self.plain_buf, Vec::new());
pub fn flush_plaintext(&mut self) {
if self.state != ConnState::Traffic {
return;
}
let buf = mem::replace(&mut self.sendable_plaintext, Vec::new());
self.send_plain(&buf);
}
pub fn take_received_plaintext(&mut self, bytes: Payload) {
self.received_plaintext.extend_from_slice(&bytes.body);
}
}
impl io::Read for ClientSession {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let len = try!(self.received_plaintext.as_slice().read(buf));
self.received_plaintext.drain(0..len);
Ok(len)
}
}
impl io::Write for ClientSession {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.send_plain(buf);
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
self.flush_plaintext();
Ok(())
}
}

View File

@ -331,9 +331,22 @@ fn handle_finished(sess: &mut ClientSession, m: &Message) -> Result<ConnState, H
.ok_or(HandshakeError::DecryptError));
dm.decode_payload();
let finished = extract_handshake!(dm, HandshakePayload::Finished);
let finished = try!(extract_handshake!(dm, HandshakePayload::Finished)
.ok_or(HandshakeError::General("finished message missing".to_string()))
);
/* Work out what verify_data we expect. */
let vh = sess.handshake_data.get_verify_hash();
let expect_verify_data = sess.secrets_current.server_verify_data(&vh);
/* Constant-time verification of this is relatively unimportant: they only
* get one chance. But it can't hurt. */
use ring;
ring::constant_time::verify_slices_are_equal(&expect_verify_data, &finished.body)
.map_err(|_| HandshakeError::DecryptError)
.unwrap();
println!("got finished {:?}", finished);
sess.flush();
Ok(ConnState::Traffic)
}
@ -342,6 +355,26 @@ pub static EXPECT_FINISHED: Handler = Handler {
handle: handle_finished
};
/* -- Traffic transit state -- */
fn expect_traffic() -> Expectation {
Expectation {
content_types: vec![ContentType::ApplicationData],
handshake_types: Vec::new()
}
}
fn handle_traffic(sess: &mut ClientSession, m: &Message) -> Result<ConnState, HandshakeError> {
let dm = try!(sess.decrypt_incoming(m)
.ok_or(HandshakeError::DecryptError));
sess.take_received_plaintext(dm.get_opaque_payload().unwrap());
Ok(ConnState::Traffic)
}
pub static TRAFFIC: Handler = Handler {
expect: expect_traffic,
handle: handle_traffic
};
/* -- Generic invalid state -- */
fn expect_invalid() -> Expectation {
Expectation {

View File

@ -46,10 +46,7 @@ mod tests {
assert_eq!(&m.typ, typ);
assert_eq!(&m.version, version);
match m.payload {
MessagePayload::Unknown(ref pl) => assert_eq!(pl.body.to_vec(), bytes.to_vec()),
_ => unreachable!()
};
assert_eq!(m.get_opaque_payload().unwrap().body.to_vec(), bytes.to_vec());
}
#[test]

View File

@ -11,7 +11,7 @@ pub enum MessagePayload {
Alert(AlertMessagePayload),
Handshake(HandshakeMessagePayload),
ChangeCipherSpec(ChangeCipherSpecPayload),
Unknown(Payload)
Opaque(Payload)
}
impl MessagePayload {
@ -20,12 +20,12 @@ impl MessagePayload {
MessagePayload::Alert(ref x) => x.encode(bytes),
MessagePayload::Handshake(ref x) => x.encode(bytes),
MessagePayload::ChangeCipherSpec(ref x) => x.encode(bytes),
MessagePayload::Unknown(ref x) => x.encode(bytes)
MessagePayload::Opaque(ref x) => x.encode(bytes)
}
}
pub fn decode_given_type(&self, typ: &ContentType) -> Option<MessagePayload> {
if let MessagePayload::Unknown(ref payload) = *self {
if let MessagePayload::Opaque(ref payload) = *self {
let mut r = Reader::init(&payload.body);
match *typ {
ContentType::Alert =>
@ -43,7 +43,7 @@ impl MessagePayload {
}
pub fn opaque(data: Vec<u8>) -> MessagePayload {
MessagePayload::Unknown(Payload { body: data.into_boxed_slice() })
MessagePayload::Opaque(Payload { body: data.into_boxed_slice() })
}
}
@ -64,7 +64,7 @@ impl Message {
let mut sub = try_ret!(r.sub(len as usize));
let payload = try_ret!(Payload::read(&mut sub));
Some(Message { typ: typ, version: version, payload: MessagePayload::Unknown(payload) })
Some(Message { typ: typ, version: version, payload: MessagePayload::Opaque(payload) })
}
pub fn encode(&self, bytes: &mut Vec<u8>) {
@ -86,4 +86,12 @@ impl Message {
self.payload = x;
}
}
pub fn get_opaque_payload(&self) -> Option<Payload> {
if let MessagePayload::Opaque(ref op) = self.payload {
Some(op.clone())
} else {
None
}
}
}

View File

@ -164,12 +164,8 @@ fn dumphex(why: &str, buf: &[u8]) {
impl MessageCipher for GCMMessageCipher {
fn decrypt(&self, msg: &Message, seq: u64) -> Result<Message, ()> {
let mut buf = try!({
match msg.payload {
MessagePayload::Unknown(ref payload) => Ok(payload.body.to_vec()),
_ => Err(())
}
});
let payload = try!(msg.get_opaque_payload().ok_or(()));
let mut buf = payload.body.to_vec();
if buf.len() < GCM_OVERHEAD {
return Err(());