diff --git a/Cargo.toml b/Cargo.toml index 5919aa2f..8d456375 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,3 +51,7 @@ path = "examples/tlsclient.rs" [[example]] name = "tlsserver" path = "examples/tlsserver.rs" + +[[example]] +name = "simpleclient" +path = "examples/simpleclient.rs" diff --git a/examples/simpleclient.rs b/examples/simpleclient.rs new file mode 100644 index 00000000..4068603a --- /dev/null +++ b/examples/simpleclient.rs @@ -0,0 +1,36 @@ +use std::sync::Arc; + +use std::net::TcpStream; +use std::io::{Read, Write, stdout}; + +extern crate rustls; +extern crate webpki_roots; + +use rustls::Session; + +fn main() { + let mut config = rustls::ClientConfig::new(); + config.root_store.add_trust_anchors(&webpki_roots::ROOTS); + + let mut tls = rustls::ClientSession::new(&Arc::new(config), "google.com"); + tls.write(concat!("GET / HTTP/1.1\r\n", + "Host: google.com\r\n", + "Connection: close\r\n", + "Accept-Encoding: identity\r\n", + "\r\n") + .as_bytes()) + .unwrap(); + + let mut sock = TcpStream::connect("google.com:443").unwrap(); + loop { + let (rl, wl) = tls.complete_io(&mut sock).unwrap(); + if rl == 0 && wl == 0 { + println!("EOF"); + break; + } + + let mut plaintext = Vec::new(); + tls.read_to_end(&mut plaintext).unwrap(); + stdout().write_all(&plaintext).unwrap(); + } +} diff --git a/src/session.rs b/src/session.rs index 1089ed24..7a057077 100644 --- a/src/session.rs +++ b/src/session.rs @@ -96,6 +96,57 @@ pub trait Session: Read + Write + Send { /// /// This returns None until the version is agreed. fn get_protocol_version(&self) -> Option; + + /// This function uses `io` to complete any outstanding IO for + /// this session. + /// + /// This is a convenience function which solely uses other parts + /// of the public API. + /// + /// What this means depends on the session state: + /// + /// - If the session `is_handshaking()`, then IO is performed until + /// the handshake is complete. + /// - Otherwise, if `wants_write` is true, `write_tls` is invoked + /// until it is all written. + /// - Otherwise, if `wants_read` is true, `read_tls` is invoked + /// once. + /// + /// The return value is the number of bytes read from and written + /// to `io`, respectively. + /// + /// This function will block if `io` blocks. + /// + /// Errors from TLS record handling (ie, from `process_new_packets()`) + /// are wrapped in an `io::ErrorKind::InvalidData`-kind error. + fn complete_io(&mut self, io: &mut T) -> Result<(usize, usize), io::Error> + where Self: Sized, T: Read + Write + { + let until_handshaked = self.is_handshaking(); + let mut wrlen = 0; + let mut rdlen = 0; + + loop { + while self.wants_write() { + wrlen += self.write_tls(io)?; + } + + if !until_handshaked && wrlen > 0 { + return Ok((rdlen, wrlen)); + } + + rdlen += self.read_tls(io)?; + + self.process_new_packets() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + match (until_handshaked, self.is_handshaking()) { + (true, false) => return Ok((rdlen, wrlen)), + (false, _) => return Ok((rdlen, wrlen)), + (_, _) => {} + }; + } + } } #[derive(Clone, Debug)] diff --git a/tests/api.rs b/tests/api.rs index 541f8ebe..a238a8b8 100644 --- a/tests/api.rs +++ b/tests/api.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use std::sync::atomic; use std::fs; -use std::io::{self, Write}; +use std::io::{self, Write, Read}; extern crate rustls; use rustls::{ClientConfig, ClientSession, ResolvesClientCert}; @@ -558,3 +558,127 @@ fn client_respects_buffer_limit_post_handshake() { check_read(&mut server, b"01234567890123456789012345"); } + +struct OtherSession<'a> { + sess: &'a mut Session, + pub reads: usize, + pub writes: usize, +} + +impl<'a> OtherSession<'a> { + fn new(sess: &'a mut Session) -> OtherSession<'a> { + OtherSession { sess, reads: 0, writes: 0 } + } +} + +impl<'a> io::Read for OtherSession<'a> { + fn read(&mut self, mut b: &mut [u8]) -> io::Result { + self.reads += 1; + self.sess.write_tls(b.by_ref()) + } +} + +impl<'a> io::Write for OtherSession<'a> { + fn write(&mut self, mut b: &[u8]) -> io::Result { + self.writes += 1; + let l = self.sess.read_tls(b.by_ref())?; + self.sess.process_new_packets().unwrap(); + Ok(l) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +#[test] +fn client_complete_io_for_handshake() { + let mut client = ClientSession::new(&Arc::new(make_client_config()), "localhost"); + let mut server = ServerSession::new(&Arc::new(make_server_config())); + + assert_eq!(true, client.is_handshaking()); + let (rdlen, wrlen) = client.complete_io(&mut OtherSession::new(&mut server)).unwrap(); + assert!(rdlen > 0 && wrlen > 0); + assert_eq!(false, client.is_handshaking()); +} + +#[test] +fn client_complete_io_for_write() { + let mut client = ClientSession::new(&Arc::new(make_client_config()), "localhost"); + let mut server = ServerSession::new(&Arc::new(make_server_config())); + + do_handshake(&mut client, &mut server); + + client.write(b"01234567890123456789").unwrap(); + client.write(b"01234567890123456789").unwrap(); + { + let mut pipe = OtherSession::new(&mut server); + let (rdlen, wrlen) = client.complete_io(&mut pipe).unwrap(); + assert!(rdlen == 0 && wrlen > 0); + assert_eq!(pipe.writes, 2); + } + check_read(&mut server, b"0123456789012345678901234567890123456789"); +} + +#[test] +fn client_complete_io_for_read() { + let mut client = ClientSession::new(&Arc::new(make_client_config()), "localhost"); + let mut server = ServerSession::new(&Arc::new(make_server_config())); + + do_handshake(&mut client, &mut server); + + server.write(b"01234567890123456789").unwrap(); + { + let mut pipe = OtherSession::new(&mut server); + let (rdlen, wrlen) = client.complete_io(&mut pipe).unwrap(); + assert!(rdlen > 0 && wrlen == 0); + assert_eq!(pipe.reads, 1); + } + check_read(&mut client, b"01234567890123456789"); +} + +#[test] +fn server_complete_io_for_handshake() { + let mut client = ClientSession::new(&Arc::new(make_client_config()), "localhost"); + let mut server = ServerSession::new(&Arc::new(make_server_config())); + + assert_eq!(true, server.is_handshaking()); + let (rdlen, wrlen) = server.complete_io(&mut OtherSession::new(&mut client)).unwrap(); + assert!(rdlen > 0 && wrlen > 0); + assert_eq!(false, server.is_handshaking()); +} + +#[test] +fn server_complete_io_for_write() { + let mut client = ClientSession::new(&Arc::new(make_client_config()), "localhost"); + let mut server = ServerSession::new(&Arc::new(make_server_config())); + + do_handshake(&mut client, &mut server); + + server.write(b"01234567890123456789").unwrap(); + server.write(b"01234567890123456789").unwrap(); + { + let mut pipe = OtherSession::new(&mut client); + let (rdlen, wrlen) = server.complete_io(&mut pipe).unwrap(); + assert!(rdlen == 0 && wrlen > 0); + assert_eq!(pipe.writes, 2); + } + check_read(&mut client, b"0123456789012345678901234567890123456789"); +} + +#[test] +fn server_complete_io_for_read() { + let mut client = ClientSession::new(&Arc::new(make_client_config()), "localhost"); + let mut server = ServerSession::new(&Arc::new(make_server_config())); + + do_handshake(&mut client, &mut server); + + client.write(b"01234567890123456789").unwrap(); + { + let mut pipe = OtherSession::new(&mut client); + let (rdlen, wrlen) = server.complete_io(&mut pipe).unwrap(); + assert!(rdlen > 0 && wrlen == 0); + assert_eq!(pipe.reads, 1); + } + check_read(&mut server, b"01234567890123456789"); +}