mirror of https://github.com/ctz/rustls
we can pass application data :)
This commit is contained in:
parent
5d9ae0287d
commit
ed0ec5441f
|
@ -4,6 +4,10 @@ use std::process;
|
|||
extern crate mio;
|
||||
use mio::tcp::TcpStream;
|
||||
|
||||
use std::str;
|
||||
use std::io;
|
||||
use std::io::{Read, Write};
|
||||
|
||||
extern crate rustls;
|
||||
|
||||
const CLIENT: mio::Token = mio::Token(0);
|
||||
|
@ -48,6 +52,22 @@ impl mio::Handler for TlsClient {
|
|||
}
|
||||
}
|
||||
|
||||
impl io::Write for TlsClient {
|
||||
fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
|
||||
self.tls_session.write(bytes)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
self.tls_session.flush()
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Read for TlsClient {
|
||||
fn read(&mut self, bytes: &mut [u8]) -> io::Result<usize> {
|
||||
self.tls_session.read(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
fn read_file(filename: &str) -> Vec<u8> {
|
||||
use std::io::Read;
|
||||
|
||||
|
@ -92,6 +112,13 @@ impl TlsClient {
|
|||
self.closing = true;
|
||||
return;
|
||||
}
|
||||
|
||||
/* We might have new plaintext as a result. */
|
||||
let mut plaintext = Vec::new();
|
||||
self.tls_session.read_to_end(&mut plaintext).unwrap();
|
||||
if plaintext.len() > 0 {
|
||||
println!("got {}", str::from_utf8(&plaintext).unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
fn do_write(&mut self) {
|
||||
|
@ -138,6 +165,7 @@ fn main() {
|
|||
let sock = TcpStream::connect(&addr).unwrap();
|
||||
let mut event_loop = mio::EventLoop::new().unwrap();
|
||||
let mut tlsclient = TlsClient::new(sock);
|
||||
tlsclient.write(b"GET / HTTP/1.0\r\n\r\n").unwrap();
|
||||
tlsclient.register(&mut event_loop);
|
||||
event_loop.run(&mut tlsclient).unwrap();
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ use msgs::handshake::{CertificatePayload, ClientExtension, DigitallySignedStruct
|
|||
use msgs::deframer::MessageDeframer;
|
||||
use msgs::fragmenter::{MessageFragmenter, MAX_FRAGMENT_LEN};
|
||||
use msgs::message::Message;
|
||||
use msgs::base::Payload;
|
||||
use client_hs;
|
||||
use hash_hs;
|
||||
use verify;
|
||||
|
@ -72,6 +73,7 @@ impl ClientHandshakeData {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq)]
|
||||
pub enum ConnState {
|
||||
ExpectServerHello,
|
||||
ExpectCertificate,
|
||||
|
@ -101,7 +103,8 @@ pub struct ClientSession {
|
|||
read_seq: u64,
|
||||
pub message_deframer: MessageDeframer,
|
||||
pub message_fragmenter: MessageFragmenter,
|
||||
pub plain_buf: Vec<u8>,
|
||||
pub sendable_plaintext: Vec<u8>,
|
||||
pub received_plaintext: Vec<u8>,
|
||||
pub tls_queue: VecDeque<Message>,
|
||||
pub state: ConnState
|
||||
}
|
||||
|
@ -118,7 +121,8 @@ impl ClientSession {
|
|||
read_seq: 0,
|
||||
message_deframer: MessageDeframer::new(),
|
||||
message_fragmenter: MessageFragmenter::new(MAX_FRAGMENT_LEN),
|
||||
plain_buf: Vec::new(),
|
||||
sendable_plaintext: Vec::new(),
|
||||
received_plaintext: Vec::new(),
|
||||
tls_queue: VecDeque::new(),
|
||||
state: ConnState::ExpectServerHello
|
||||
};
|
||||
|
@ -192,6 +196,11 @@ impl ClientSession {
|
|||
let new_state = try!((handler.handle)(self, msg));
|
||||
self.state = new_state;
|
||||
|
||||
/* Once we're connected, start flushing sendable_plaintext. */
|
||||
if self.state == ConnState::Traffic {
|
||||
self.flush_plaintext();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -203,7 +212,7 @@ impl ClientSession {
|
|||
ConnState::ExpectServerHelloDone => &client_hs::EXPECT_SERVER_HELLO_DONE,
|
||||
ConnState::ExpectCCS => &client_hs::EXPECT_CCS,
|
||||
ConnState::ExpectFinished => &client_hs::EXPECT_FINISHED,
|
||||
_ => &client_hs::INVALID_STATE
|
||||
ConnState::Traffic => &client_hs::TRAFFIC
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -239,11 +248,65 @@ impl ClientSession {
|
|||
}
|
||||
|
||||
pub fn send_plain(&mut self, data: &[u8]) {
|
||||
use msgs::enums::{ContentType, ProtocolVersion};
|
||||
use msgs::message::MessagePayload;
|
||||
|
||||
if self.state != ConnState::Traffic {
|
||||
/* If we haven't completed handshaking, buffer
|
||||
* plaintext to send once we do. */
|
||||
self.sendable_plaintext.extend_from_slice(data);
|
||||
return;
|
||||
}
|
||||
|
||||
assert!(self.state.is_encrypted());
|
||||
|
||||
/* Make one giant message, then have the fragmenter chop
|
||||
* it into bits. Then encrypt and queue those bits. */
|
||||
let m = Message {
|
||||
typ: ContentType::ApplicationData,
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: MessagePayload::opaque(data.to_vec())
|
||||
};
|
||||
|
||||
let mut plain_messages = VecDeque::new();
|
||||
self.message_fragmenter.fragment(&m, &mut plain_messages);
|
||||
|
||||
for m in plain_messages {
|
||||
let em = self.encrypt_outgoing(&m);
|
||||
self.tls_queue.push_back(em);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flush(&mut self) {
|
||||
let buf = mem::replace(&mut self.plain_buf, Vec::new());
|
||||
pub fn flush_plaintext(&mut self) {
|
||||
if self.state != ConnState::Traffic {
|
||||
return;
|
||||
}
|
||||
|
||||
let buf = mem::replace(&mut self.sendable_plaintext, Vec::new());
|
||||
self.send_plain(&buf);
|
||||
}
|
||||
|
||||
pub fn take_received_plaintext(&mut self, bytes: Payload) {
|
||||
self.received_plaintext.extend_from_slice(&bytes.body);
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Read for ClientSession {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
let len = try!(self.received_plaintext.as_slice().read(buf));
|
||||
self.received_plaintext.drain(0..len);
|
||||
Ok(len)
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Write for ClientSession {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
self.send_plain(buf);
|
||||
Ok(buf.len())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
self.flush_plaintext();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -331,9 +331,22 @@ fn handle_finished(sess: &mut ClientSession, m: &Message) -> Result<ConnState, H
|
|||
.ok_or(HandshakeError::DecryptError));
|
||||
dm.decode_payload();
|
||||
|
||||
let finished = extract_handshake!(dm, HandshakePayload::Finished);
|
||||
let finished = try!(extract_handshake!(dm, HandshakePayload::Finished)
|
||||
.ok_or(HandshakeError::General("finished message missing".to_string()))
|
||||
);
|
||||
|
||||
/* Work out what verify_data we expect. */
|
||||
let vh = sess.handshake_data.get_verify_hash();
|
||||
let expect_verify_data = sess.secrets_current.server_verify_data(&vh);
|
||||
|
||||
/* Constant-time verification of this is relatively unimportant: they only
|
||||
* get one chance. But it can't hurt. */
|
||||
use ring;
|
||||
ring::constant_time::verify_slices_are_equal(&expect_verify_data, &finished.body)
|
||||
.map_err(|_| HandshakeError::DecryptError)
|
||||
.unwrap();
|
||||
|
||||
println!("got finished {:?}", finished);
|
||||
sess.flush();
|
||||
Ok(ConnState::Traffic)
|
||||
}
|
||||
|
||||
|
@ -342,6 +355,26 @@ pub static EXPECT_FINISHED: Handler = Handler {
|
|||
handle: handle_finished
|
||||
};
|
||||
|
||||
/* -- Traffic transit state -- */
|
||||
fn expect_traffic() -> Expectation {
|
||||
Expectation {
|
||||
content_types: vec![ContentType::ApplicationData],
|
||||
handshake_types: Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_traffic(sess: &mut ClientSession, m: &Message) -> Result<ConnState, HandshakeError> {
|
||||
let dm = try!(sess.decrypt_incoming(m)
|
||||
.ok_or(HandshakeError::DecryptError));
|
||||
sess.take_received_plaintext(dm.get_opaque_payload().unwrap());
|
||||
Ok(ConnState::Traffic)
|
||||
}
|
||||
|
||||
pub static TRAFFIC: Handler = Handler {
|
||||
expect: expect_traffic,
|
||||
handle: handle_traffic
|
||||
};
|
||||
|
||||
/* -- Generic invalid state -- */
|
||||
fn expect_invalid() -> Expectation {
|
||||
Expectation {
|
||||
|
|
|
@ -46,10 +46,7 @@ mod tests {
|
|||
|
||||
assert_eq!(&m.typ, typ);
|
||||
assert_eq!(&m.version, version);
|
||||
match m.payload {
|
||||
MessagePayload::Unknown(ref pl) => assert_eq!(pl.body.to_vec(), bytes.to_vec()),
|
||||
_ => unreachable!()
|
||||
};
|
||||
assert_eq!(m.get_opaque_payload().unwrap().body.to_vec(), bytes.to_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -11,7 +11,7 @@ pub enum MessagePayload {
|
|||
Alert(AlertMessagePayload),
|
||||
Handshake(HandshakeMessagePayload),
|
||||
ChangeCipherSpec(ChangeCipherSpecPayload),
|
||||
Unknown(Payload)
|
||||
Opaque(Payload)
|
||||
}
|
||||
|
||||
impl MessagePayload {
|
||||
|
@ -20,12 +20,12 @@ impl MessagePayload {
|
|||
MessagePayload::Alert(ref x) => x.encode(bytes),
|
||||
MessagePayload::Handshake(ref x) => x.encode(bytes),
|
||||
MessagePayload::ChangeCipherSpec(ref x) => x.encode(bytes),
|
||||
MessagePayload::Unknown(ref x) => x.encode(bytes)
|
||||
MessagePayload::Opaque(ref x) => x.encode(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode_given_type(&self, typ: &ContentType) -> Option<MessagePayload> {
|
||||
if let MessagePayload::Unknown(ref payload) = *self {
|
||||
if let MessagePayload::Opaque(ref payload) = *self {
|
||||
let mut r = Reader::init(&payload.body);
|
||||
match *typ {
|
||||
ContentType::Alert =>
|
||||
|
@ -43,7 +43,7 @@ impl MessagePayload {
|
|||
}
|
||||
|
||||
pub fn opaque(data: Vec<u8>) -> MessagePayload {
|
||||
MessagePayload::Unknown(Payload { body: data.into_boxed_slice() })
|
||||
MessagePayload::Opaque(Payload { body: data.into_boxed_slice() })
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -64,7 +64,7 @@ impl Message {
|
|||
let mut sub = try_ret!(r.sub(len as usize));
|
||||
let payload = try_ret!(Payload::read(&mut sub));
|
||||
|
||||
Some(Message { typ: typ, version: version, payload: MessagePayload::Unknown(payload) })
|
||||
Some(Message { typ: typ, version: version, payload: MessagePayload::Opaque(payload) })
|
||||
}
|
||||
|
||||
pub fn encode(&self, bytes: &mut Vec<u8>) {
|
||||
|
@ -86,4 +86,12 @@ impl Message {
|
|||
self.payload = x;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_opaque_payload(&self) -> Option<Payload> {
|
||||
if let MessagePayload::Opaque(ref op) = self.payload {
|
||||
Some(op.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -164,12 +164,8 @@ fn dumphex(why: &str, buf: &[u8]) {
|
|||
|
||||
impl MessageCipher for GCMMessageCipher {
|
||||
fn decrypt(&self, msg: &Message, seq: u64) -> Result<Message, ()> {
|
||||
let mut buf = try!({
|
||||
match msg.payload {
|
||||
MessagePayload::Unknown(ref payload) => Ok(payload.body.to_vec()),
|
||||
_ => Err(())
|
||||
}
|
||||
});
|
||||
let payload = try!(msg.get_opaque_payload().ok_or(()));
|
||||
let mut buf = payload.body.to_vec();
|
||||
|
||||
if buf.len() < GCM_OVERHEAD {
|
||||
return Err(());
|
||||
|
|
Loading…
Reference in New Issue