rustls/rustls-mio/examples/tlsserver.rs

639 lines
19 KiB
Rust

use std::sync::Arc;
use mio;
use mio::net::{TcpListener, TcpStream};
#[macro_use]
extern crate log;
use std::fs;
use std::io;
use std::net;
use std::io::{Write, Read, BufReader};
use std::collections::HashMap;
#[macro_use]
extern crate serde_derive;
use docopt::Docopt;
use env_logger;
use rustls;
use rustls::{RootCertStore, Session, NoClientAuth, AllowAnyAuthenticatedClient,
AllowAnyAnonymousOrAuthenticatedClient};
/// Token for our listening socket. `main` routes every event carrying this
/// token to `TlsServer::accept`; all other tokens belong to connections.
const LISTENER: mio::Token = mio::Token(0);
/// Which mode the server operates in. Each accepted connection gets its own
/// clone of the mode and dispatches received plaintext on it.
#[derive(Clone)]
enum ServerMode {
    /// Write back received bytes
    Echo,

    /// Do one read, then write a bodged HTTP response and
    /// cleanly close the connection.
    Http,

    /// Forward traffic to/from given port on localhost (127.0.0.1).
    Forward(u16),
}
/// This binds together a TCP listening socket, some outstanding
/// connections, and a TLS server configuration.
struct TlsServer {
    /// Listening socket; `main` registers it under the `LISTENER` token, and
    /// `accept` drains it until it reports `WouldBlock`.
    server: TcpListener,
    /// Live connections keyed by their mio token. A connection is removed
    /// as soon as it reports itself closed.
    connections: HashMap<mio::Token, Connection>,
    /// Token value for the next accepted connection. Starts at 2 and only
    /// ever increments, so it never reuses `LISTENER`'s token 0.
    next_id: usize,
    /// Shared TLS configuration used to create each connection's session.
    tls_config: Arc<rustls::ServerConfig>,
    /// Mode cloned into every newly accepted connection.
    mode: ServerMode,
}
impl TlsServer {
    /// Wraps an already-bound listener. Connection tokens are handed out
    /// starting at 2, clear of the listener's token 0.
    fn new(server: TcpListener, mode: ServerMode, cfg: Arc<rustls::ServerConfig>) -> TlsServer {
        TlsServer {
            server,
            connections: HashMap::new(),
            next_id: 2,
            tls_config: cfg,
            mode,
        }
    }

    /// Accepts every connection currently pending on the listener.
    ///
    /// Each accepted socket gets a fresh `ServerSession`, a unique token and
    /// a registration with `registry`, and is then stored in `connections`.
    ///
    /// Returns `Ok(())` once the listener reports `WouldBlock`. Any other
    /// accept error is printed and returned to the caller.
    fn accept(&mut self, registry: &mio::Registry) -> Result<(), io::Error> {
        loop {
            match self.server.accept() {
                Ok((socket, addr)) => {
                    debug!("Accepting new connection from {:?}", addr);

                    let tls_session = rustls::ServerSession::new(&self.tls_config);
                    let mode = self.mode.clone();

                    let token = mio::Token(self.next_id);
                    self.next_id += 1;

                    let mut connection = Connection::new(socket, token, mode, tls_session);
                    connection.register(registry);
                    self.connections.insert(token, connection);
                }
                Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Ok(()),
                Err(err) => {
                    println!("encountered error while accepting connection; err={:?}", err);
                    return Err(err);
                }
            }
        }
    }

    /// Dispatches a readiness event to the connection that owns its token.
    ///
    /// The connection is dropped from the table once it reports itself
    /// closed. Events whose token is not (or no longer) in the table are
    /// ignored.
    fn conn_event(&mut self, registry: &mio::Registry, event: &mio::event::Event) {
        let token = event.token();

        // Single lookup instead of contains_key + get_mut + index.
        if let Some(connection) = self.connections.get_mut(&token) {
            connection.ready(registry, event);
            if connection.is_closed() {
                self.connections.remove(&token);
            }
        }
    }
}
/// This is a connection which has been accepted by the server,
/// and is currently being served.
///
/// It has a TCP-level stream, a TLS-level session, and some
/// other state/metadata.
struct Connection {
    /// Client-facing TCP stream carrying TLS records.
    socket: TcpStream,
    /// mio token under which `socket` (and `back`, when present) is registered.
    token: mio::Token,
    /// Set when a read/write error, EOF, or TLS processing error means the
    /// connection should be torn down at the end of the current event.
    closing: bool,
    /// Set once teardown has happened; `TlsServer` then removes the connection.
    closed: bool,
    /// Decides what happens to received plaintext.
    mode: ServerMode,
    /// rustls server-side session state for this client.
    tls_session: rustls::ServerSession,
    /// Plaintext backend stream; `Some` only in `Forward` mode, until closed.
    back: Option<TcpStream>,
    /// Whether the `Http`-mode response has already been queued.
    sent_http_response: bool,
}
/// Open a plaintext TCP-level connection for forwarded connections.
///
/// Only `ServerMode::Forward(port)` has a backend. For it, this calls mio's
/// `TcpStream::connect` on 127.0.0.1:`port` and returns the stream. Every
/// other mode yields `None`.
///
/// Panics if `TcpStream::connect` returns an error.
fn open_back(mode: &ServerMode) -> Option<TcpStream> {
    if let ServerMode::Forward(port) = *mode {
        let backend = net::SocketAddr::V4(net::SocketAddrV4::new(net::Ipv4Addr::LOCALHOST, port));
        Some(TcpStream::connect(backend).unwrap())
    } else {
        None
    }
}
/// Reinterprets the result of a read on a non-blocking socket.
///
/// - A successful read of `len` bytes becomes `Ok(Some(len))`; `Some(0)` still means EOF.
/// - A `WouldBlock` error becomes `Ok(None)`, meaning "nothing to read yet".
/// - Any other error is passed through unchanged.
///
/// Older mio versions provided this helper themselves.
fn try_read(r: io::Result<usize>) -> io::Result<Option<usize>> {
    match r {
        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(None),
        other => other.map(Some),
    }
}
impl Connection {
    /// Creates the state for a freshly accepted client connection.
    ///
    /// In `Forward` mode this also opens the backend connection via
    /// `open_back`; in every other mode `back` is `None`.
    fn new(socket: TcpStream,
           token: mio::Token,
           mode: ServerMode,
           tls_session: rustls::ServerSession)
           -> Connection {
        let back = open_back(&mode);
        Connection {
            socket,
            token,
            closing: false,
            closed: false,
            mode,
            tls_session,
            back,
            sent_http_response: false,
        }
    }

    /// Handles one readiness event for this connection.
    ///
    /// - On readable: read TLS from the client, then any plaintext that
    ///   produced, then whatever the backend has ready.
    /// - On writable: flush pending TLS to the client.
    ///
    /// If any step flagged the connection as closing, it is deregistered,
    /// both sockets are shut down, and it is marked closed. Otherwise its
    /// interest is updated to match what the TLS session now wants.
    fn ready(&mut self, registry: &mio::Registry, ev: &mio::event::Event) {
        if ev.is_readable() {
            self.do_tls_read();
            self.try_plain_read();
            self.try_back_read();
        }

        if ev.is_writable() {
            self.do_tls_write_and_handle_error();
        }

        if self.closing {
            // Deregister first: `close_back` drops the backend stream, and
            // `deregister` can only deregister it while `self.back` is `Some`.
            self.deregister(registry);
            let _ = self.socket.shutdown(net::Shutdown::Both);
            self.close_back();
            self.closed = true;
        } else {
            self.reregister(registry);
        }
    }

    /// Shuts down and drops the backend connection, if there is one.
    ///
    /// Shutdown errors are ignored, as they are for the client socket in
    /// `ready`. The backend may already have closed or reset its end, and
    /// that must not panic and take the whole server down.
    fn close_back(&mut self) {
        if let Some(back) = self.back.take() {
            let _ = back.shutdown(net::Shutdown::Both);
        }
    }

    /// Reads TLS records from the client socket and processes them.
    ///
    /// - `WouldBlock` is not an error and just returns.
    /// - Any other read error, or EOF, marks the connection as closing.
    /// - If the new records fail to process, one last write is attempted, so
    ///   any alert the session queued can go out, before closing.
    fn do_tls_read(&mut self) {
        match self.tls_session.read_tls(&mut self.socket) {
            Ok(0) => {
                debug!("eof");
                self.closing = true;
                return;
            }
            Ok(_) => {}
            Err(err) => {
                if err.kind() != io::ErrorKind::WouldBlock {
                    error!("read error {:?}", err);
                    self.closing = true;
                }
                return;
            }
        }

        // Process newly-received TLS messages.
        let processed = self.tls_session.process_new_packets();
        if processed.is_err() {
            error!("cannot process packet: {:?}", processed);

            // last gasp write to send any alerts
            self.do_tls_write_and_handle_error();

            self.closing = true;
        }
    }

    /// Takes all plaintext the TLS session has decrypted so far and hands it
    /// to `incoming_plaintext`. A read failure marks the connection as closing.
    fn try_plain_read(&mut self) {
        let mut buf = Vec::new();

        let rc = self.tls_session.read_to_end(&mut buf);
        if rc.is_err() {
            error!("plaintext read failed: {:?}", rc);
            self.closing = true;
            return;
        }

        if !buf.is_empty() {
            debug!("plaintext read {:?}", buf.len());
            self.incoming_plaintext(&buf);
        }
    }

    /// Moves everything the backend has ready into the TLS session.
    ///
    /// Reads in 1 KiB chunks until the backend would block. The backend stream
    /// is registered once and never reregistered, and mio only guarantees a
    /// further readiness event after a read has returned `WouldBlock`. Stopping
    /// after one chunk could therefore leave data stranded.
    ///
    /// A read error or EOF marks the connection as closing.
    fn try_back_read(&mut self) {
        let back = match self.back.as_mut() {
            Some(back) => back,
            None => return,
        };

        let mut buf = [0u8; 1024];
        loop {
            let rc = try_read(back.read(&mut buf));
            match rc {
                // A successful but empty read is EOF.
                Ok(Some(0)) => {
                    debug!("back eof");
                    self.closing = true;
                    return;
                }
                // Otherwise, shove the data into the TLS session.
                Ok(Some(len)) => {
                    self.tls_session.write_all(&buf[..len]).unwrap();
                }
                // Drained for now; wait for the next readiness event.
                Ok(None) => return,
                Err(_) => {
                    error!("backend read failed: {:?}", rc);
                    self.closing = true;
                    return;
                }
            }
        }
    }

    /// Process some amount of received plaintext according to the mode.
    fn incoming_plaintext(&mut self, buf: &[u8]) {
        match self.mode {
            ServerMode::Echo => {
                self.tls_session.write_all(buf).unwrap();
            }
            ServerMode::Http => {
                self.send_http_response_once();
            }
            ServerMode::Forward(_) => {
                self.back.as_mut().unwrap().write_all(buf).unwrap();
            }
        }
    }

    /// Queues the fixed HTTP response followed by a close_notify alert, at
    /// most once per connection.
    fn send_http_response_once(&mut self) {
        if self.sent_http_response {
            return;
        }

        let response = b"HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello world from rustls tlsserver\r\n";
        self.tls_session.write_all(response).unwrap();
        self.sent_http_response = true;
        self.tls_session.send_close_notify();
    }

    /// Writes pending TLS data from the session to the client socket.
    fn tls_write(&mut self) -> io::Result<usize> {
        self.tls_session.write_tls(&mut self.socket)
    }

    /// Like `tls_write`, but a write failure is logged and marks the
    /// connection as closing instead of being returned.
    fn do_tls_write_and_handle_error(&mut self) {
        let rc = self.tls_write();
        if rc.is_err() {
            error!("write failed {:?}", rc);
            self.closing = true;
        }
    }

    /// Registers the client socket (interest from `event_set`) and, if
    /// present, the backend (readable only), both under `self.token`.
    fn register(&mut self, registry: &mio::Registry) {
        let interest = self.event_set();
        registry.register(&mut self.socket, self.token, interest)
            .unwrap();

        if let Some(back) = self.back.as_mut() {
            registry.register(back, self.token, mio::Interest::READABLE)
                .unwrap();
        }
    }

    /// Updates the client socket's interest to match the TLS session's needs.
    fn reregister(&mut self, registry: &mio::Registry) {
        let interest = self.event_set();
        registry.reregister(&mut self.socket, self.token, interest)
            .unwrap();
    }

    /// Removes the client socket and, if still present, the backend stream
    /// from `registry`.
    fn deregister(&mut self, registry: &mio::Registry) {
        registry.deregister(&mut self.socket)
            .unwrap();

        if let Some(back) = self.back.as_mut() {
            registry.deregister(back)
                .unwrap();
        }
    }

    /// What IO events we're currently waiting for,
    /// based on wants_read/wants_write. Readable is the fallback when the
    /// session wants neither.
    fn event_set(&self) -> mio::Interest {
        match (self.tls_session.wants_read(), self.tls_session.wants_write()) {
            (true, true) => mio::Interest::READABLE | mio::Interest::WRITABLE,
            (_, true) => mio::Interest::WRITABLE,
            _ => mio::Interest::READABLE,
        }
    }

    /// Whether this connection has been torn down and can be dropped.
    fn is_closed(&self) -> bool {
        self.closed
    }
}
// Docopt usage text. `main` parses the command line against this string with
// `Docopt::new(USAGE)`, and `Args` mirrors the names declared here
// (`cmd_echo` for `echo`, `flag_port` for `--port`, `arg_fport` for `<fport>`).
// The `[default: 443]` in the Options section supplies the port default.
// NOTE(review): the intro says `--key` provides "the RSA private key", but the
// `--key` option text and `load_private_key` also accept PKCS8 keys.
const USAGE: &'static str =
"
Runs a TLS server on :PORT. The default PORT is 443.
`echo' mode means the server echoes received data on each connection.
`http' mode means the server blindly sends a HTTP response on each
connection.
`forward' means the server forwards plaintext to a connection made to
localhost:fport.
`--certs' names the full certificate chain, `--key' provides the
RSA private key.
Usage:
tlsserver --certs CERTFILE --key KEYFILE [--suite SUITE ...] \
[--proto PROTO ...] [options] echo
tlsserver --certs CERTFILE --key KEYFILE [--suite SUITE ...] \
[--proto PROTO ...] [options] http
tlsserver --certs CERTFILE --key KEYFILE [--suite SUITE ...] \
[--proto PROTO ...] [options] forward <fport>
tlsserver (--version | -v)
tlsserver (--help | -h)
Options:
-p, --port PORT Listen on PORT [default: 443].
--certs CERTFILE Read server certificates from CERTFILE.
This should contain PEM-format certificates
in the right order (the first certificate should
certify KEYFILE, the last should be a root CA).
--key KEYFILE Read private key from KEYFILE. This should be a RSA
private key or PKCS8-encoded private key, in PEM format.
--ocsp OCSPFILE Read DER-encoded OCSP response from OCSPFILE and staple
to certificate. Optional.
--auth CERTFILE Enable client authentication, and accept certificates
signed by those roots provided in CERTFILE.
--require-auth Send a fatal alert if the client does not complete client
authentication.
--resumption Support session resumption.
--tickets Support tickets.
--protover VERSION Disable default TLS version list, and use
VERSION instead. May be used multiple times.
--suite SUITE Disable default cipher suite list, and use
SUITE instead. May be used multiple times.
--proto PROTOCOL Negotiate PROTOCOL using ALPN.
May be used multiple times.
--verbose Emit log output.
--version, -v Show tool version.
--help, -h Show this screen.
";
/// Parsed command line, deserialized by docopt from `USAGE`.
#[derive(Debug, Deserialize)]
struct Args {
    /// `echo` command given.
    cmd_echo: bool,
    /// `http` command given.
    cmd_http: bool,
    /// `forward` command given. `main` does not read this; it selects
    /// `Forward` whenever neither `echo` nor `http` was given.
    cmd_forward: bool,
    /// `--port`; `main` falls back to 443 if this is `None`.
    flag_port: Option<u16>,
    /// `--verbose`: turns on trace-level `env_logger` output.
    flag_verbose: bool,
    /// `--protover` values; `lookup_versions` accepts "1.2" and "1.3".
    flag_protover: Vec<String>,
    /// `--suite` names, matched case-insensitively by `find_suite`.
    flag_suite: Vec<String>,
    /// `--proto` ALPN protocol names, offered in the order given.
    flag_proto: Vec<String>,
    /// `--certs` PEM certificate chain; `make_config` panics if absent.
    flag_certs: Option<String>,
    /// `--key` PEM private key; `make_config` panics if absent.
    flag_key: Option<String>,
    /// `--ocsp` DER-encoded OCSP response to staple (optional).
    flag_ocsp: Option<String>,
    /// `--auth` PEM roots used to verify client certificates.
    flag_auth: Option<String>,
    /// `--require-auth`: reject clients without a certificate (only
    /// meaningful together with `--auth`).
    flag_require_auth: bool,
    /// `--resumption`: enable an in-memory session cache of 256 entries.
    flag_resumption: bool,
    /// `--tickets`: enable session tickets.
    flag_tickets: bool,
    /// `<fport>`: backend port on localhost for `forward` mode.
    arg_fport: Option<u16>,
}
/// Finds the rustls ciphersuite whose name matches `name`, ignoring case.
///
/// A suite's name is the `Debug` rendering of its `suite` identifier.
/// Returns `None` when no entry of `rustls::ALL_CIPHERSUITES` matches.
fn find_suite(name: &str) -> Option<&'static rustls::SupportedCipherSuite> {
    // The requested name does not change inside the loop: lower-case it once
    // instead of re-allocating it for every candidate suite.
    let wanted = name.to_lowercase();

    for suite in &rustls::ALL_CIPHERSUITES {
        let sname = format!("{:?}", suite.suite).to_lowercase();
        if sname == wanted {
            return Some(suite);
        }
    }

    None
}
/// Maps each requested ciphersuite name to its rustls definition, in order.
///
/// Panics on the first name that `find_suite` does not recognise.
fn lookup_suites(suites: &[String]) -> Vec<&'static rustls::SupportedCipherSuite> {
    suites
        .iter()
        .map(|csname| match find_suite(csname) {
            Some(suite) => suite,
            None => panic!("cannot look up ciphersuite '{}'", csname),
        })
        .collect()
}
/// Make a vector of protocol versions named in `versions`.
///
/// Accepts "1.2" and "1.3" and keeps the input order. Panics on any other
/// name.
fn lookup_versions(versions: &[String]) -> Vec<rustls::ProtocolVersion> {
    versions
        .iter()
        .map(|vname| match vname.as_str() {
            "1.2" => rustls::ProtocolVersion::TLSv1_2,
            "1.3" => rustls::ProtocolVersion::TLSv1_3,
            unknown => panic!("cannot look up version '{}', valid are '1.2' and '1.3'", unknown),
        })
        .collect()
}
/// Loads the PEM-encoded certificates contained in `filename`.
///
/// Panics with a descriptive message if the file cannot be opened or if
/// `pemfile::certs` reports an error for its contents. A bare `unwrap` would
/// only report `Err(())`.
fn load_certs(filename: &str) -> Vec<rustls::Certificate> {
    let certfile = fs::File::open(filename).expect("cannot open certificate file");
    let mut reader = BufReader::new(certfile);
    rustls::internal::pemfile::certs(&mut reader)
        .expect("file contains invalid certificate")
}
/// Loads the server's private key from the PEM file `filename`.
///
/// The file is scanned for both RSA and PKCS8 keys. The first PKCS8 key is
/// preferred; otherwise the first RSA key is used.
///
/// Panics if the file cannot be opened, if either parser rejects its
/// contents, or if it contains no key of either kind.
fn load_private_key(filename: &str) -> rustls::PrivateKey {
    let rsa_keys = {
        let keyfile = fs::File::open(filename)
            .expect("cannot open private key file");
        let mut reader = BufReader::new(keyfile);
        rustls::internal::pemfile::rsa_private_keys(&mut reader)
            .expect("file contains invalid rsa private key")
    };

    let pkcs8_keys = {
        let keyfile = fs::File::open(filename)
            .expect("cannot open private key file");
        let mut reader = BufReader::new(keyfile);
        rustls::internal::pemfile::pkcs8_private_keys(&mut reader)
            .expect("file contains invalid pkcs8 private key (encrypted keys not supported)")
    };

    // Prefer a PKCS8 key; fall back to an RSA key. A file with neither now
    // gets an explanatory panic instead of a bare assertion failure.
    pkcs8_keys
        .into_iter()
        .next()
        .or_else(|| rsa_keys.into_iter().next())
        .expect("no PKCS8 or RSA private key found in key file")
}
/// Reads the optional DER-encoded OCSP response to staple.
///
/// Returns the raw file bytes, or an empty vector when no file was named.
/// Panics if the named file cannot be opened or read.
fn load_ocsp(filename: &Option<String>) -> Vec<u8> {
    match filename {
        None => Vec::new(),
        Some(name) => {
            let mut der = Vec::new();
            let mut file = fs::File::open(name).expect("cannot open ocsp file");
            file.read_to_end(&mut der).unwrap();
            der
        }
    }
}
/// Builds the shared TLS server configuration from the command line.
///
/// - Client auth: with `--auth`, client certificates are verified against the
///   roots in that file (mandatory with `--require-auth`, optional otherwise);
///   without it, no client auth.
/// - A `rustls::KeyLogFile` is installed as the key log.
/// - The certificate chain, private key and optional OCSP response come from
///   `--certs`, `--key` and `--ocsp`.
/// - Ciphersuites, protocol versions, session resumption, tickets and ALPN
///   protocols are applied when requested.
///
/// Panics if `--certs`/`--key` are missing or any input cannot be loaded or
/// used.
fn make_config(args: &Args) -> Arc<rustls::ServerConfig> {
    let client_auth = match args.flag_auth.as_ref() {
        Some(auth_file) => {
            let mut client_auth_roots = RootCertStore::empty();
            for root in load_certs(auth_file) {
                client_auth_roots.add(&root).unwrap();
            }

            if args.flag_require_auth {
                AllowAnyAuthenticatedClient::new(client_auth_roots)
            } else {
                AllowAnyAnonymousOrAuthenticatedClient::new(client_auth_roots)
            }
        }
        None => NoClientAuth::new(),
    };

    let mut config = rustls::ServerConfig::new(client_auth);
    config.key_log = Arc::new(rustls::KeyLogFile::new());

    let certs = load_certs(args.flag_certs.as_ref().expect("--certs option missing"));
    let privkey = load_private_key(args.flag_key.as_ref().expect("--key option missing"));
    let ocsp = load_ocsp(&args.flag_ocsp);
    config
        .set_single_cert_with_ocsp_and_sct(certs, privkey, ocsp, vec![])
        .expect("bad certificates/private key");

    if !args.flag_suite.is_empty() {
        config.ciphersuites = lookup_suites(&args.flag_suite);
    }

    if !args.flag_protover.is_empty() {
        config.versions = lookup_versions(&args.flag_protover);
    }

    if args.flag_resumption {
        config.set_persistence(rustls::ServerSessionMemoryCache::new(256));
    }

    if args.flag_tickets {
        config.ticketer = rustls::Ticketer::new();
    }

    let alpn_protocols: Vec<Vec<u8>> = args
        .flag_proto
        .iter()
        .map(|proto| proto.as_bytes().to_vec())
        .collect();
    config.set_protocols(&alpn_protocols);

    Arc::new(config)
}
fn main() {
let version = env!("CARGO_PKG_NAME").to_string() + ", version: " + env!("CARGO_PKG_VERSION");
let args: Args = Docopt::new(USAGE)
.and_then(|d| Ok(d.help(true)))
.and_then(|d| Ok(d.version(Some(version))))
.and_then(|d| d.deserialize())
.unwrap_or_else(|e| e.exit());
if args.flag_verbose {
env_logger::Builder::new()
.parse_filters("trace")
.init();
}
let mut addr: net::SocketAddr = "0.0.0.0:443".parse().unwrap();
addr.set_port(args.flag_port.unwrap_or(443));
let config = make_config(&args);
let mut listener = TcpListener::bind(addr).expect("cannot listen on port");
let mut poll = mio::Poll::new()
.unwrap();
poll.registry().register(&mut listener,
LISTENER,
mio::Interest::READABLE)
.unwrap();
let mode = if args.cmd_echo {
ServerMode::Echo
} else if args.cmd_http {
ServerMode::Http
} else {
ServerMode::Forward(args.arg_fport.expect("fport required"))
};
let mut tlsserv = TlsServer::new(listener, mode, config);
let mut events = mio::Events::with_capacity(256);
loop {
poll.poll(&mut events, None)
.unwrap();
for event in events.iter() {
match event.token() {
LISTENER => {
tlsserv.accept(poll.registry())
.expect("error accepting socket");
}
_ => tlsserv.conn_event(poll.registry(), &event)
}
}
}
}