Add helper function to help with blocking IO

Add the simplest client example to demonstrate its use.
This commit is contained in:
Joseph Birr-Pixton 2017-06-10 10:28:26 +01:00
parent 436d9e924f
commit 7d284bfea2
4 changed files with 216 additions and 1 deletions

View File

@ -51,3 +51,7 @@ path = "examples/tlsclient.rs"
[[example]]
name = "tlsserver"
path = "examples/tlsserver.rs"
[[example]]
name = "simpleclient"
path = "examples/simpleclient.rs"

36
examples/simpleclient.rs Normal file
View File

@ -0,0 +1,36 @@
use std::sync::Arc;
use std::net::TcpStream;
use std::io::{Read, Write, stdout};
extern crate rustls;
extern crate webpki_roots;
use rustls::Session;
fn main() {
let mut config = rustls::ClientConfig::new();
config.root_store.add_trust_anchors(&webpki_roots::ROOTS);
let mut tls = rustls::ClientSession::new(&Arc::new(config), "google.com");
tls.write(concat!("GET / HTTP/1.1\r\n",
"Host: google.com\r\n",
"Connection: close\r\n",
"Accept-Encoding: identity\r\n",
"\r\n")
.as_bytes())
.unwrap();
let mut sock = TcpStream::connect("google.com:443").unwrap();
loop {
let (rl, wl) = tls.complete_io(&mut sock).unwrap();
if rl == 0 && wl == 0 {
println!("EOF");
break;
}
let mut plaintext = Vec::new();
tls.read_to_end(&mut plaintext).unwrap();
stdout().write_all(&plaintext).unwrap();
}
}

View File

@ -96,6 +96,57 @@ pub trait Session: Read + Write + Send {
///
/// This returns None until the version is agreed.
fn get_protocol_version(&self) -> Option<ProtocolVersion>;
/// This function uses `io` to complete any outstanding IO for
/// this session.
///
/// This is a convenience function which solely uses other parts
/// of the public API.
///
/// What this means depends on the session state:
///
/// - If the session `is_handshaking()`, then IO is performed until
/// the handshake is complete.
/// - Otherwise, if `wants_write` is true, `write_tls` is invoked
/// until it is all written.
/// - Otherwise, if `wants_read` is true, `read_tls` is invoked
/// once.
///
/// The return value is the number of bytes read from and written
/// to `io`, respectively.
///
/// This function will block if `io` blocks.
///
/// Errors from TLS record handling (ie, from `process_new_packets()`)
/// are wrapped in an `io::ErrorKind::InvalidData`-kind error.
fn complete_io<T>(&mut self, io: &mut T) -> Result<(usize, usize), io::Error>
where Self: Sized, T: Read + Write
{
let until_handshaked = self.is_handshaking();
let mut wrlen = 0;
let mut rdlen = 0;
loop {
while self.wants_write() {
wrlen += self.write_tls(io)?;
}
if !until_handshaked && wrlen > 0 {
return Ok((rdlen, wrlen));
}
rdlen += self.read_tls(io)?;
self.process_new_packets()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
match (until_handshaked, self.is_handshaking()) {
(true, false) => return Ok((rdlen, wrlen)),
(false, _) => return Ok((rdlen, wrlen)),
(_, _) => {}
};
}
}
}
#[derive(Clone, Debug)]

View File

@ -2,7 +2,7 @@
use std::sync::Arc;
use std::sync::atomic;
use std::fs;
use std::io::{self, Write};
use std::io::{self, Write, Read};
extern crate rustls;
use rustls::{ClientConfig, ClientSession, ResolvesClientCert};
@ -558,3 +558,127 @@ fn client_respects_buffer_limit_post_handshake() {
check_read(&mut server, b"01234567890123456789012345");
}
struct OtherSession<'a> {
sess: &'a mut Session,
pub reads: usize,
pub writes: usize,
}
impl<'a> OtherSession<'a> {
fn new(sess: &'a mut Session) -> OtherSession<'a> {
OtherSession { sess, reads: 0, writes: 0 }
}
}
impl<'a> io::Read for OtherSession<'a> {
fn read(&mut self, mut b: &mut [u8]) -> io::Result<usize> {
self.reads += 1;
self.sess.write_tls(b.by_ref())
}
}
impl<'a> io::Write for OtherSession<'a> {
fn write(&mut self, mut b: &[u8]) -> io::Result<usize> {
self.writes += 1;
let l = self.sess.read_tls(b.by_ref())?;
self.sess.process_new_packets().unwrap();
Ok(l)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
#[test]
fn client_complete_io_for_handshake() {
let mut client = ClientSession::new(&Arc::new(make_client_config()), "localhost");
let mut server = ServerSession::new(&Arc::new(make_server_config()));
assert_eq!(true, client.is_handshaking());
let (rdlen, wrlen) = client.complete_io(&mut OtherSession::new(&mut server)).unwrap();
assert!(rdlen > 0 && wrlen > 0);
assert_eq!(false, client.is_handshaking());
}
#[test]
fn client_complete_io_for_write() {
let mut client = ClientSession::new(&Arc::new(make_client_config()), "localhost");
let mut server = ServerSession::new(&Arc::new(make_server_config()));
do_handshake(&mut client, &mut server);
client.write(b"01234567890123456789").unwrap();
client.write(b"01234567890123456789").unwrap();
{
let mut pipe = OtherSession::new(&mut server);
let (rdlen, wrlen) = client.complete_io(&mut pipe).unwrap();
assert!(rdlen == 0 && wrlen > 0);
assert_eq!(pipe.writes, 2);
}
check_read(&mut server, b"0123456789012345678901234567890123456789");
}
#[test]
fn client_complete_io_for_read() {
let mut client = ClientSession::new(&Arc::new(make_client_config()), "localhost");
let mut server = ServerSession::new(&Arc::new(make_server_config()));
do_handshake(&mut client, &mut server);
server.write(b"01234567890123456789").unwrap();
{
let mut pipe = OtherSession::new(&mut server);
let (rdlen, wrlen) = client.complete_io(&mut pipe).unwrap();
assert!(rdlen > 0 && wrlen == 0);
assert_eq!(pipe.reads, 1);
}
check_read(&mut client, b"01234567890123456789");
}
#[test]
fn server_complete_io_for_handshake() {
let mut client = ClientSession::new(&Arc::new(make_client_config()), "localhost");
let mut server = ServerSession::new(&Arc::new(make_server_config()));
assert_eq!(true, server.is_handshaking());
let (rdlen, wrlen) = server.complete_io(&mut OtherSession::new(&mut client)).unwrap();
assert!(rdlen > 0 && wrlen > 0);
assert_eq!(false, server.is_handshaking());
}
#[test]
fn server_complete_io_for_write() {
let mut client = ClientSession::new(&Arc::new(make_client_config()), "localhost");
let mut server = ServerSession::new(&Arc::new(make_server_config()));
do_handshake(&mut client, &mut server);
server.write(b"01234567890123456789").unwrap();
server.write(b"01234567890123456789").unwrap();
{
let mut pipe = OtherSession::new(&mut client);
let (rdlen, wrlen) = server.complete_io(&mut pipe).unwrap();
assert!(rdlen == 0 && wrlen > 0);
assert_eq!(pipe.writes, 2);
}
check_read(&mut client, b"0123456789012345678901234567890123456789");
}
#[test]
fn server_complete_io_for_read() {
let mut client = ClientSession::new(&Arc::new(make_client_config()), "localhost");
let mut server = ServerSession::new(&Arc::new(make_server_config()));
do_handshake(&mut client, &mut server);
client.write(b"01234567890123456789").unwrap();
{
let mut pipe = OtherSession::new(&mut client);
let (rdlen, wrlen) = server.complete_io(&mut pipe).unwrap();
assert!(rdlen > 0 && wrlen == 0);
assert_eq!(pipe.reads, 1);
}
check_read(&mut server, b"01234567890123456789");
}