Require interior mutability from persistence impls

This commit is contained in:
Joseph Birr-Pixton 2017-06-17 12:05:25 +01:00
parent b5de72ecd4
commit 3d874b17aa
6 changed files with 100 additions and 76 deletions

View File

@ -147,7 +147,7 @@ fn make_server_cfg(opts: &Options) -> Arc<rustls::ServerConfig> {
if opts.offer_no_client_cas || opts.require_any_client_cert {
cfg.client_auth_offer = true;
cfg.dangerous()
.set_certificate_verifier(Box::new(NoVerification {}));
.set_certificate_verifier(Arc::new(NoVerification {}));
}
if opts.require_any_client_cert {

View File

@ -1,4 +1,4 @@
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::process;
extern crate mio;
@ -187,7 +187,7 @@ impl TlsClient {
/// Note that the contents of such a file are extremely sensitive.
/// Don't write this stuff to disk in production code.
struct PersistCache {
cache: collections::HashMap<Vec<u8>, Vec<u8>>,
cache: Mutex<collections::HashMap<Vec<u8>, Vec<u8>>>,
filename: Option<String>,
}
@ -195,8 +195,8 @@ impl PersistCache {
/// Make a new cache. If filename is Some, load the cache
/// from it and flush changes back to that file.
fn new(filename: &Option<String>) -> PersistCache {
let mut cache = PersistCache {
cache: collections::HashMap::new(),
let cache = PersistCache {
cache: Mutex::new(collections::HashMap::new()),
filename: filename.clone(),
};
if cache.filename.is_some() {
@ -206,7 +206,7 @@ impl PersistCache {
}
/// If we have a filename, save the cache contents to it.
fn save(&mut self) {
fn save(&self) {
use rustls::internal::msgs::codec::Codec;
use rustls::internal::msgs::base::PayloadU16;
@ -214,9 +214,10 @@ impl PersistCache {
return;
}
let mut file = fs::File::create(self.filename.as_ref().unwrap()).unwrap();
let mut file = fs::File::create(self.filename.as_ref().unwrap())
.expect("cannot open cache file");
for (key, val) in &self.cache {
for (key, val) in self.cache.lock().unwrap().iter() {
let mut item = Vec::new();
let key_pl = PayloadU16::new(key.clone());
let val_pl = PayloadU16::new(val.clone());
@ -227,7 +228,7 @@ impl PersistCache {
}
/// We have a filename, so replace the cache contents from it.
fn load(&mut self) {
fn load(&self) {
use rustls::internal::msgs::codec::{Codec, Reader};
use rustls::internal::msgs::base::PayloadU16;
@ -238,28 +239,34 @@ impl PersistCache {
let mut data = Vec::new();
file.read_to_end(&mut data).unwrap();
self.cache.clear();
let mut cache = self.cache.lock()
.unwrap();
cache.clear();
let mut rd = Reader::init(&data);
while rd.any_left() {
let key_pl = PayloadU16::read(&mut rd).unwrap();
let val_pl = PayloadU16::read(&mut rd).unwrap();
self.cache.insert(key_pl.0, val_pl.0);
cache.insert(key_pl.0, val_pl.0);
}
}
}
impl rustls::StoresClientSessions for PersistCache {
/// put: insert into in-memory cache, and perhaps persist to disk.
fn put(&mut self, key: Vec<u8>, value: Vec<u8>) -> bool {
self.cache.insert(key, value);
fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
self.cache.lock()
.unwrap()
.insert(key, value);
self.save();
true
}
/// get: from in-memory cache
fn get(&mut self, key: &[u8]) -> Option<Vec<u8>> {
self.cache.get(key).cloned()
fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
self.cache.lock()
.unwrap()
.get(key).cloned()
}
}
@ -435,7 +442,7 @@ fn make_config(args: &Args) -> Arc<rustls::ClientConfig> {
config.enable_tickets = false;
}
let persist = Box::new(PersistCache::new(&args.flag_cache));
let persist = Arc::new(PersistCache::new(&args.flag_cache));
config.set_protocols(&args.flag_proto);
config.set_persistence(persist);

View File

@ -27,25 +27,30 @@ use std::fmt;
/// Both the keys and values should be treated as
/// **highly sensitive data**, containing enough key material
/// to break all security of the corresponding session.
///
/// `put` is a mutating operation; this isn't expressed
/// in the type system to allow implementations freedom in
/// how to achieve interior mutability. `Mutex` is a common
/// choice.
pub trait StoresClientSessions : Send + Sync {
/// Stores a new `value` for `key`. Returns `true`
/// if the value was stored.
fn put(&mut self, key: Vec<u8>, value: Vec<u8>) -> bool;
fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool;
/// Returns the latest value for `key`. Returns `None`
/// if there's no such value.
fn get(&mut self, key: &[u8]) -> Option<Vec<u8>>;
fn get(&self, key: &[u8]) -> Option<Vec<u8>>;
}
/// An implementor of `StoresClientSessions` which does nothing.
struct NoSessionStorage {}
impl StoresClientSessions for NoSessionStorage {
fn put(&mut self, _key: Vec<u8>, _value: Vec<u8>) -> bool {
fn put(&self, _key: Vec<u8>, _value: Vec<u8>) -> bool {
false
}
fn get(&mut self, _key: &[u8]) -> Option<Vec<u8>> {
fn get(&self, _key: &[u8]) -> Option<Vec<u8>> {
None
}
}
@ -54,38 +59,43 @@ impl StoresClientSessions for NoSessionStorage {
/// in memory. It enforces a limit on the number of sessions
/// to bound memory usage.
pub struct ClientSessionMemoryCache {
cache: collections::HashMap<Vec<u8>, Vec<u8>>,
cache: Mutex<collections::HashMap<Vec<u8>, Vec<u8>>>,
max_entries: usize,
}
impl ClientSessionMemoryCache {
/// Make a new ClientSessionMemoryCache. `size` is the
/// maximum number of stored sessions.
pub fn new(size: usize) -> Box<ClientSessionMemoryCache> {
pub fn new(size: usize) -> Arc<ClientSessionMemoryCache> {
debug_assert!(size > 0);
Box::new(ClientSessionMemoryCache {
cache: collections::HashMap::new(),
Arc::new(ClientSessionMemoryCache {
cache: Mutex::new(collections::HashMap::new()),
max_entries: size,
})
}
fn limit_size(&mut self) {
while self.cache.len() > self.max_entries {
let k = self.cache.keys().next().unwrap().clone();
self.cache.remove(&k);
fn limit_size(&self) {
let mut cache = self.cache.lock().unwrap();
while cache.len() > self.max_entries {
let k = cache.keys().next().unwrap().clone();
cache.remove(&k);
}
}
}
impl StoresClientSessions for ClientSessionMemoryCache {
fn put(&mut self, key: Vec<u8>, value: Vec<u8>) -> bool {
self.cache.insert(key, value);
fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
self.cache.lock()
.unwrap()
.insert(key, value);
self.limit_size();
true
}
fn get(&mut self, key: &[u8]) -> Option<Vec<u8>> {
self.cache.get(key).cloned()
fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
self.cache.lock()
.unwrap()
.get(key).cloned()
}
}
@ -175,7 +185,7 @@ pub struct ClientConfig {
pub alpn_protocols: Vec<String>,
/// How we store session data or tickets.
pub session_persistence: Arc<Mutex<Box<StoresClientSessions>>>,
pub session_persistence: Arc<StoresClientSessions>,
/// Our MTU. If None, we don't limit TLS message sizes.
pub mtu: Option<usize>,
@ -207,7 +217,7 @@ impl ClientConfig {
ciphersuites: ALL_CIPHERSUITES.to_vec(),
root_store: anchors::RootCertStore::empty(),
alpn_protocols: Vec::new(),
session_persistence: Arc::new(Mutex::new(Box::new(NoSessionStorage {}))),
session_persistence: Arc::new(NoSessionStorage {}),
mtu: None,
client_auth_cert_resolver: Arc::new(FailResolveClientCert {}),
enable_tickets: true,
@ -231,8 +241,8 @@ impl ClientConfig {
}
/// Sets persistence layer to `persist`.
pub fn set_persistence(&mut self, persist: Box<StoresClientSessions>) {
self.session_persistence = Arc::new(Mutex::new(persist));
pub fn set_persistence(&mut self, persist: Arc<StoresClientSessions>) {
self.session_persistence = persist;
}
/// Sets MTU to `mtu`. If None, the default is used.

View File

@ -81,8 +81,7 @@ fn find_session(sess: &mut ClientSessionImpl) -> Option<persist::ClientSessionVa
let key = persist::ClientSessionKey::session_for_dns_name(&sess.handshake_data.dns_name);
let key_buf = key.get_encoding();
let mut persist = sess.config.session_persistence.lock().unwrap();
let maybe_value = persist.get(&key_buf);
let maybe_value = sess.config.session_persistence.get(&key_buf);
if maybe_value.is_none() {
info!("No cached session for {:?}", sess.handshake_data.dns_name);
@ -105,16 +104,14 @@ fn find_kx_hint(sess: &mut ClientSessionImpl) -> Option<NamedGroup> {
let key = persist::ClientSessionKey::hint_for_dns_name(&sess.handshake_data.dns_name);
let key_buf = key.get_encoding();
let mut persist = sess.config.session_persistence.lock().unwrap();
let maybe_value = persist.get(&key_buf);
let maybe_value = sess.config.session_persistence.get(&key_buf);
maybe_value.and_then(|enc| NamedGroup::read_bytes(&enc))
}
fn save_kx_hint(sess: &mut ClientSessionImpl, group: NamedGroup) {
let key = persist::ClientSessionKey::hint_for_dns_name(&sess.handshake_data.dns_name);
let mut persist = sess.config.session_persistence.lock().unwrap();
persist.put(key.get_encoding(), group.get_encoding());
sess.config.session_persistence.put(key.get_encoding(), group.get_encoding());
}
/// If we have a ticket, we use the sessionid as a signal that we're
@ -1273,8 +1270,8 @@ fn save_session(sess: &mut ClientSessionImpl) {
value.set_extended_ms_used();
}
let mut persist = sess.config.session_persistence.lock().unwrap();
let worked = persist.put(key.get_encoding(), value.get_encoding());
let worked = sess.config.session_persistence.put(key.get_encoding(),
value.get_encoding());
if worked {
info!("Session saved");
@ -1530,8 +1527,8 @@ fn handle_new_ticket_tls13(sess: &mut ClientSessionImpl, m: Message) -> Result<(
let key = persist::ClientSessionKey::session_for_dns_name(&sess.handshake_data.dns_name);
let mut persist = sess.config.session_persistence.lock().unwrap();
let worked = persist.put(key.get_encoding(), value.get_encoding());
let worked = sess.config.session_persistence.put(key.get_encoding(),
value.get_encoding());
if worked {
info!("Ticket saved");

View File

@ -25,6 +25,11 @@ use std::fmt;
/// Both the keys and values should be treated as
/// **highly sensitive data**, containing enough key material
/// to break all security of the corresponding session.
///
/// `put` and `del` are mutating operations; this isn't expressed
/// in the type system to allow implementations freedom in
/// how to achieve interior mutability. `Mutex` is a common
/// choice.
pub trait StoresServerSessions : Send + Sync {
/// Generate a session ID.
fn generate(&self) -> SessionID;
@ -32,7 +37,7 @@ pub trait StoresServerSessions : Send + Sync {
/// Store session secrets encoded in `value` against key `id`,
/// overwrites any existing value against `id`. Returns `true`
/// if the value was stored.
fn put(&mut self, id: &SessionID, value: Vec<u8>) -> bool;
fn put(&self, id: &SessionID, value: Vec<u8>) -> bool;
/// Find a session with the given `id`. Return it, or None
/// if it doesn't exist.
@ -40,7 +45,7 @@ pub trait StoresServerSessions : Send + Sync {
/// Erase a session with the given `id`. Return true if
/// `id` existed and was removed.
fn del(&mut self, id: &SessionID) -> bool;
fn del(&self, id: &SessionID) -> bool;
}
/// A trait for the ability to encrypt and decrypt tickets.
@ -105,7 +110,7 @@ pub struct ServerConfig {
pub ignore_client_order: bool,
/// How to store client sessions.
pub session_storage: Arc<Mutex<Box<StoresServerSessions + Send>>>,
pub session_storage: Arc<StoresServerSessions + Send>,
/// How to produce tickets.
pub ticketer: Arc<ProducesTickets>,
@ -142,13 +147,13 @@ impl StoresServerSessions for NoSessionStorage {
fn generate(&self) -> SessionID {
SessionID::empty()
}
fn put(&mut self, _id: &SessionID, _sec: Vec<u8>) -> bool {
fn put(&self, _id: &SessionID, _sec: Vec<u8>) -> bool {
false
}
fn get(&self, _id: &SessionID) -> Option<Vec<u8>> {
None
}
fn del(&mut self, _id: &SessionID) -> bool {
fn del(&self, _id: &SessionID) -> bool {
false
}
}
@ -157,25 +162,26 @@ impl StoresServerSessions for NoSessionStorage {
/// in memory. If enforces a limit on the number of stored sessions
/// to bound memory usage.
pub struct ServerSessionMemoryCache {
cache: collections::HashMap<Vec<u8>, Vec<u8>>,
cache: Mutex<collections::HashMap<Vec<u8>, Vec<u8>>>,
max_entries: usize,
}
impl ServerSessionMemoryCache {
/// Make a new ServerSessionMemoryCache. `size` is the maximum
/// number of stored sessions.
pub fn new(size: usize) -> Box<ServerSessionMemoryCache> {
pub fn new(size: usize) -> Arc<ServerSessionMemoryCache> {
debug_assert!(size > 0);
Box::new(ServerSessionMemoryCache {
cache: collections::HashMap::new(),
Arc::new(ServerSessionMemoryCache {
cache: Mutex::new(collections::HashMap::new()),
max_entries: size,
})
}
fn limit_size(&mut self) {
while self.cache.len() > self.max_entries {
let k = self.cache.keys().next().unwrap().clone();
self.cache.remove(&k);
fn limit_size(&self) {
let mut cache = self.cache.lock().unwrap();
while cache.len() > self.max_entries {
let k = cache.keys().next().unwrap().clone();
cache.remove(&k);
}
}
}
@ -187,18 +193,24 @@ impl StoresServerSessions for ServerSessionMemoryCache {
SessionID::new(&v)
}
fn put(&mut self, id: &SessionID, sec: Vec<u8>) -> bool {
self.cache.insert(id.get_encoding(), sec);
fn put(&self, id: &SessionID, sec: Vec<u8>) -> bool {
self.cache.lock()
.unwrap()
.insert(id.get_encoding(), sec);
self.limit_size();
true
}
fn get(&self, id: &SessionID) -> Option<Vec<u8>> {
self.cache.get(&id.get_encoding()).cloned()
self.cache.lock()
.unwrap()
.get(&id.get_encoding()).cloned()
}
fn del(&mut self, id: &SessionID) -> bool {
self.cache.remove(&id.get_encoding()).is_some()
fn del(&self, id: &SessionID) -> bool {
self.cache.lock()
.unwrap()
.remove(&id.get_encoding()).is_some()
}
}
@ -266,7 +278,7 @@ impl ServerConfig {
ServerConfig {
ciphersuites: ALL_CIPHERSUITES.to_vec(),
ignore_client_order: false,
session_storage: Arc::new(Mutex::new(Box::new(NoSessionStorage {}))),
session_storage: Arc::new(NoSessionStorage {}),
ticketer: Arc::new(NeverProducesTickets {}),
alpn_protocols: Vec::new(),
cert_resolver: Arc::new(FailResolveChain {}),
@ -284,8 +296,8 @@ impl ServerConfig {
}
/// Sets the session persistence layer to `persist`.
pub fn set_persistence(&mut self, persist: Box<StoresServerSessions + Send>) {
self.session_storage = Arc::new(Mutex::new(persist));
pub fn set_persistence(&mut self, persist: Arc<StoresServerSessions + Send>) {
self.session_storage = persist;
}
/// Sets a single certificate chain and matching private key. This
@ -337,6 +349,7 @@ impl ServerConfig {
pub mod danger {
use super::ServerConfig;
use super::verify::ClientCertVerifier;
use super::Arc;
/// Accessor for dangerous configuration options.
pub struct DangerousServerConfig<'a> {
@ -347,7 +360,7 @@ pub mod danger {
impl<'a> DangerousServerConfig<'a> {
/// Overrides the default `ClientCertVerifier` with something else.
pub fn set_certificate_verifier(&mut self,
verifier: Box<ClientCertVerifier>) {
verifier: Arc<ClientCertVerifier>) {
self.cfg.verifier = verifier;
}
}

View File

@ -125,8 +125,6 @@ fn emit_server_hello(sess: &mut ServerSessionImpl,
if sess.handshake_data.session_id.is_empty() {
let sessid = sess.config
.session_storage
.lock()
.unwrap()
.generate();
sess.handshake_data.session_id = sessid;
}
@ -845,10 +843,8 @@ fn handle_client_hello(sess: &mut ServerSessionImpl, m: Message) -> StateResult
// Perhaps resume? If we received a ticket, the sessionid
// does not correspond to a real session.
if !client_hello.session_id.is_empty() && !ticket_received {
let maybe_resume = {
let persist = sess.config.session_storage.lock().unwrap();
persist.get(&client_hello.session_id)
}
let maybe_resume = sess.config.session_storage
.get(&client_hello.session_id)
.and_then(|x| persist::ServerSessionValue::read_bytes(&x));
if can_resume(sess, &maybe_resume) {
@ -1197,8 +1193,9 @@ fn handle_finished(sess: &mut ServerSessionImpl, m: Message) -> StateResult {
if !sess.handshake_data.doing_resume && !sess.handshake_data.session_id.is_empty() {
let value = get_server_session_value(sess);
let mut persist = sess.config.session_storage.lock().unwrap();
if persist.put(&sess.handshake_data.session_id, value.get_encoding()) {
let worked = sess.config.session_storage
.put(&sess.handshake_data.session_id, value.get_encoding());
if worked {
info!("Session saved");
} else {
info!("Session not saved");