diff --git a/examples/internal/bogo_shim.rs b/examples/internal/bogo_shim.rs index 2a865628..d8d535a2 100644 --- a/examples/internal/bogo_shim.rs +++ b/examples/internal/bogo_shim.rs @@ -147,7 +147,7 @@ fn make_server_cfg(opts: &Options) -> Arc { if opts.offer_no_client_cas || opts.require_any_client_cert { cfg.client_auth_offer = true; cfg.dangerous() - .set_certificate_verifier(Box::new(NoVerification {})); + .set_certificate_verifier(Arc::new(NoVerification {})); } if opts.require_any_client_cert { diff --git a/examples/tlsclient.rs b/examples/tlsclient.rs index 159b45c3..73316a0a 100644 --- a/examples/tlsclient.rs +++ b/examples/tlsclient.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::process; extern crate mio; @@ -187,7 +187,7 @@ impl TlsClient { /// Note that the contents of such a file are extremely sensitive. /// Don't write this stuff to disk in production code. struct PersistCache { - cache: collections::HashMap, Vec>, + cache: Mutex, Vec>>, filename: Option, } @@ -195,8 +195,8 @@ impl PersistCache { /// Make a new cache. If filename is Some, load the cache /// from it and flush changes back to that file. fn new(filename: &Option) -> PersistCache { - let mut cache = PersistCache { - cache: collections::HashMap::new(), + let cache = PersistCache { + cache: Mutex::new(collections::HashMap::new()), filename: filename.clone(), }; if cache.filename.is_some() { @@ -206,7 +206,7 @@ impl PersistCache { } /// If we have a filename, save the cache contents to it. - fn save(&mut self) { + fn save(&self) { use rustls::internal::msgs::codec::Codec; use rustls::internal::msgs::base::PayloadU16; @@ -214,9 +214,10 @@ impl PersistCache { return; } - let mut file = fs::File::create(self.filename.as_ref().unwrap()).unwrap(); + let mut file = fs::File::create(self.filename.as_ref().unwrap()) + .expect("cannot open cache file"); - for (key, val) in &self.cache { + for (key, val) in self.cache.lock().unwrap().iter() { let mut item = Vec::new(); let key_pl = PayloadU16::new(key.clone()); let val_pl = PayloadU16::new(val.clone()); @@ -227,7 +228,7 @@ impl PersistCache { } /// We have a filename, so replace the cache contents from it. - fn load(&mut self) { + fn load(&self) { use rustls::internal::msgs::codec::{Codec, Reader}; use rustls::internal::msgs::base::PayloadU16; @@ -238,28 +239,34 @@ impl PersistCache { let mut data = Vec::new(); file.read_to_end(&mut data).unwrap(); - self.cache.clear(); + let mut cache = self.cache.lock() + .unwrap(); + cache.clear(); let mut rd = Reader::init(&data); while rd.any_left() { let key_pl = PayloadU16::read(&mut rd).unwrap(); let val_pl = PayloadU16::read(&mut rd).unwrap(); - self.cache.insert(key_pl.0, val_pl.0); + cache.insert(key_pl.0, val_pl.0); } } } impl rustls::StoresClientSessions for PersistCache { /// put: insert into in-memory cache, and perhaps persist to disk. - fn put(&mut self, key: Vec, value: Vec) -> bool { - self.cache.insert(key, value); + fn put(&self, key: Vec, value: Vec) -> bool { + self.cache.lock() + .unwrap() + .insert(key, value); self.save(); true } /// get: from in-memory cache - fn get(&mut self, key: &[u8]) -> Option> { - self.cache.get(key).cloned() + fn get(&self, key: &[u8]) -> Option> { + self.cache.lock() + .unwrap() + .get(key).cloned() } } @@ -435,7 +442,7 @@ fn make_config(args: &Args) -> Arc { config.enable_tickets = false; } - let persist = Box::new(PersistCache::new(&args.flag_cache)); + let persist = Arc::new(PersistCache::new(&args.flag_cache)); config.set_protocols(&args.flag_proto); config.set_persistence(persist); diff --git a/src/client.rs b/src/client.rs index 70c209da..4f0931ca 100644 --- a/src/client.rs +++ b/src/client.rs @@ -27,25 +27,30 @@ use std::fmt; /// Both the keys and values should be treated as /// **highly sensitive data**, containing enough key material /// to break all security of the corresponding session. +/// +/// `put` is a mutating operation; this isn't expressed +/// in the type system to allow implementations freedom in +/// how to achieve interior mutability. `Mutex` is a common +/// choice. pub trait StoresClientSessions : Send + Sync { /// Stores a new `value` for `key`. Returns `true` /// if the value was stored. - fn put(&mut self, key: Vec, value: Vec) -> bool; + fn put(&self, key: Vec, value: Vec) -> bool; /// Returns the latest value for `key`. Returns `None` /// if there's no such value. - fn get(&mut self, key: &[u8]) -> Option>; + fn get(&self, key: &[u8]) -> Option>; } /// An implementor of `StoresClientSessions` which does nothing. struct NoSessionStorage {} impl StoresClientSessions for NoSessionStorage { - fn put(&mut self, _key: Vec, _value: Vec) -> bool { + fn put(&self, _key: Vec, _value: Vec) -> bool { false } - fn get(&mut self, _key: &[u8]) -> Option> { + fn get(&self, _key: &[u8]) -> Option> { None } } @@ -54,38 +59,43 @@ impl StoresClientSessions for NoSessionStorage { /// in memory. It enforces a limit on the number of sessions /// to bound memory usage. pub struct ClientSessionMemoryCache { - cache: collections::HashMap, Vec>, + cache: Mutex, Vec>>, max_entries: usize, } impl ClientSessionMemoryCache { /// Make a new ClientSessionMemoryCache. `size` is the /// maximum number of stored sessions. - pub fn new(size: usize) -> Box { + pub fn new(size: usize) -> Arc { debug_assert!(size > 0); - Box::new(ClientSessionMemoryCache { - cache: collections::HashMap::new(), + Arc::new(ClientSessionMemoryCache { + cache: Mutex::new(collections::HashMap::new()), max_entries: size, }) } - fn limit_size(&mut self) { - while self.cache.len() > self.max_entries { - let k = self.cache.keys().next().unwrap().clone(); - self.cache.remove(&k); + fn limit_size(&self) { + let mut cache = self.cache.lock().unwrap(); + while cache.len() > self.max_entries { + let k = cache.keys().next().unwrap().clone(); + cache.remove(&k); } } } impl StoresClientSessions for ClientSessionMemoryCache { - fn put(&mut self, key: Vec, value: Vec) -> bool { - self.cache.insert(key, value); + fn put(&self, key: Vec, value: Vec) -> bool { + self.cache.lock() + .unwrap() + .insert(key, value); self.limit_size(); true } - fn get(&mut self, key: &[u8]) -> Option> { - self.cache.get(key).cloned() + fn get(&self, key: &[u8]) -> Option> { + self.cache.lock() + .unwrap() + .get(key).cloned() } } @@ -175,7 +185,7 @@ pub struct ClientConfig { pub alpn_protocols: Vec, /// How we store session data or tickets. - pub session_persistence: Arc>>, + pub session_persistence: Arc, /// Our MTU. If None, we don't limit TLS message sizes. pub mtu: Option, @@ -207,7 +217,7 @@ impl ClientConfig { ciphersuites: ALL_CIPHERSUITES.to_vec(), root_store: anchors::RootCertStore::empty(), alpn_protocols: Vec::new(), - session_persistence: Arc::new(Mutex::new(Box::new(NoSessionStorage {}))), + session_persistence: Arc::new(NoSessionStorage {}), mtu: None, client_auth_cert_resolver: Arc::new(FailResolveClientCert {}), enable_tickets: true, @@ -231,8 +241,8 @@ impl ClientConfig { } /// Sets persistence layer to `persist`. - pub fn set_persistence(&mut self, persist: Box) { - self.session_persistence = Arc::new(Mutex::new(persist)); + pub fn set_persistence(&mut self, persist: Arc) { + self.session_persistence = persist; } /// Sets MTU to `mtu`. If None, the default is used. diff --git a/src/client_hs.rs b/src/client_hs.rs index caa5c615..6c948f7f 100644 --- a/src/client_hs.rs +++ b/src/client_hs.rs @@ -81,8 +81,7 @@ fn find_session(sess: &mut ClientSessionImpl) -> Option Option { let key = persist::ClientSessionKey::hint_for_dns_name(&sess.handshake_data.dns_name); let key_buf = key.get_encoding(); - let mut persist = sess.config.session_persistence.lock().unwrap(); - let maybe_value = persist.get(&key_buf); + let maybe_value = sess.config.session_persistence.get(&key_buf); maybe_value.and_then(|enc| NamedGroup::read_bytes(&enc)) } fn save_kx_hint(sess: &mut ClientSessionImpl, group: NamedGroup) { let key = persist::ClientSessionKey::hint_for_dns_name(&sess.handshake_data.dns_name); - let mut persist = sess.config.session_persistence.lock().unwrap(); - persist.put(key.get_encoding(), group.get_encoding()); + sess.config.session_persistence.put(key.get_encoding(), group.get_encoding()); } /// If we have a ticket, we use the sessionid as a signal that we're @@ -1273,8 +1270,8 @@ fn save_session(sess: &mut ClientSessionImpl) { value.set_extended_ms_used(); } - let mut persist = sess.config.session_persistence.lock().unwrap(); - let worked = persist.put(key.get_encoding(), value.get_encoding()); + let worked = sess.config.session_persistence.put(key.get_encoding(), + value.get_encoding()); if worked { info!("Session saved"); @@ -1530,8 +1527,8 @@ fn handle_new_ticket_tls13(sess: &mut ClientSessionImpl, m: Message) -> Result<( let key = persist::ClientSessionKey::session_for_dns_name(&sess.handshake_data.dns_name); - let mut persist = sess.config.session_persistence.lock().unwrap(); - let worked = persist.put(key.get_encoding(), value.get_encoding()); + let worked = sess.config.session_persistence.put(key.get_encoding(), + value.get_encoding()); if worked { info!("Ticket saved"); diff --git a/src/server.rs b/src/server.rs index 7e6c9764..bd5240b4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -25,6 +25,11 @@ use std::fmt; /// Both the keys and values should be treated as /// **highly sensitive data**, containing enough key material /// to break all security of the corresponding session. +/// +/// `put` and `del` are mutating operations; this isn't expressed +/// in the type system to allow implementations freedom in +/// how to achieve interior mutability. `Mutex` is a common +/// choice. pub trait StoresServerSessions : Send + Sync { /// Generate a session ID. fn generate(&self) -> SessionID; @@ -32,7 +37,7 @@ pub trait StoresServerSessions : Send + Sync { /// Store session secrets encoded in `value` against key `id`, /// overwrites any existing value against `id`. Returns `true` /// if the value was stored. - fn put(&mut self, id: &SessionID, value: Vec) -> bool; + fn put(&self, id: &SessionID, value: Vec) -> bool; /// Find a session with the given `id`. Return it, or None /// if it doesn't exist. @@ -40,7 +45,7 @@ pub trait StoresServerSessions : Send + Sync { /// Erase a session with the given `id`. Return true if /// `id` existed and was removed. - fn del(&mut self, id: &SessionID) -> bool; + fn del(&self, id: &SessionID) -> bool; } /// A trait for the ability to encrypt and decrypt tickets. @@ -105,7 +110,7 @@ pub struct ServerConfig { pub ignore_client_order: bool, /// How to store client sessions. - pub session_storage: Arc>>, + pub session_storage: Arc, /// How to produce tickets. pub ticketer: Arc, @@ -142,13 +147,13 @@ impl StoresServerSessions for NoSessionStorage { fn generate(&self) -> SessionID { SessionID::empty() } - fn put(&mut self, _id: &SessionID, _sec: Vec) -> bool { + fn put(&self, _id: &SessionID, _sec: Vec) -> bool { false } fn get(&self, _id: &SessionID) -> Option> { None } - fn del(&mut self, _id: &SessionID) -> bool { + fn del(&self, _id: &SessionID) -> bool { false } } @@ -157,25 +162,26 @@ impl StoresServerSessions for NoSessionStorage { /// in memory. If enforces a limit on the number of stored sessions /// to bound memory usage. pub struct ServerSessionMemoryCache { - cache: collections::HashMap, Vec>, + cache: Mutex, Vec>>, max_entries: usize, } impl ServerSessionMemoryCache { /// Make a new ServerSessionMemoryCache. `size` is the maximum /// number of stored sessions. - pub fn new(size: usize) -> Box { + pub fn new(size: usize) -> Arc { debug_assert!(size > 0); - Box::new(ServerSessionMemoryCache { - cache: collections::HashMap::new(), + Arc::new(ServerSessionMemoryCache { + cache: Mutex::new(collections::HashMap::new()), max_entries: size, }) } - fn limit_size(&mut self) { - while self.cache.len() > self.max_entries { - let k = self.cache.keys().next().unwrap().clone(); - self.cache.remove(&k); + fn limit_size(&self) { + let mut cache = self.cache.lock().unwrap(); + while cache.len() > self.max_entries { + let k = cache.keys().next().unwrap().clone(); + cache.remove(&k); } } } @@ -187,18 +193,24 @@ impl StoresServerSessions for ServerSessionMemoryCache { SessionID::new(&v) } - fn put(&mut self, id: &SessionID, sec: Vec) -> bool { - self.cache.insert(id.get_encoding(), sec); + fn put(&self, id: &SessionID, sec: Vec) -> bool { + self.cache.lock() + .unwrap() + .insert(id.get_encoding(), sec); self.limit_size(); true } fn get(&self, id: &SessionID) -> Option> { - self.cache.get(&id.get_encoding()).cloned() + self.cache.lock() + .unwrap() + .get(&id.get_encoding()).cloned() } - fn del(&mut self, id: &SessionID) -> bool { - self.cache.remove(&id.get_encoding()).is_some() + fn del(&self, id: &SessionID) -> bool { + self.cache.lock() + .unwrap() + .remove(&id.get_encoding()).is_some() } } @@ -266,7 +278,7 @@ impl ServerConfig { ServerConfig { ciphersuites: ALL_CIPHERSUITES.to_vec(), ignore_client_order: false, - session_storage: Arc::new(Mutex::new(Box::new(NoSessionStorage {}))), + session_storage: Arc::new(NoSessionStorage {}), ticketer: Arc::new(NeverProducesTickets {}), alpn_protocols: Vec::new(), cert_resolver: Arc::new(FailResolveChain {}), @@ -284,8 +296,8 @@ impl ServerConfig { } /// Sets the session persistence layer to `persist`. - pub fn set_persistence(&mut self, persist: Box) { - self.session_storage = Arc::new(Mutex::new(persist)); + pub fn set_persistence(&mut self, persist: Arc) { + self.session_storage = persist; } /// Sets a single certificate chain and matching private key. This @@ -337,6 +349,7 @@ impl ServerConfig { pub mod danger { use super::ServerConfig; use super::verify::ClientCertVerifier; + use super::Arc; /// Accessor for dangerous configuration options. pub struct DangerousServerConfig<'a> { @@ -347,7 +360,7 @@ pub mod danger { impl<'a> DangerousServerConfig<'a> { /// Overrides the default `ClientCertVerifier` with something else. pub fn set_certificate_verifier(&mut self, - verifier: Box) { + verifier: Arc) { self.cfg.verifier = verifier; } } diff --git a/src/server_hs.rs b/src/server_hs.rs index 602731fc..61b4b741 100644 --- a/src/server_hs.rs +++ b/src/server_hs.rs @@ -125,8 +125,6 @@ fn emit_server_hello(sess: &mut ServerSessionImpl, if sess.handshake_data.session_id.is_empty() { let sessid = sess.config .session_storage - .lock() - .unwrap() .generate(); sess.handshake_data.session_id = sessid; } @@ -845,10 +843,8 @@ fn handle_client_hello(sess: &mut ServerSessionImpl, m: Message) -> StateResult // Perhaps resume? If we received a ticket, the sessionid // does not correspond to a real session. if !client_hello.session_id.is_empty() && !ticket_received { - let maybe_resume = { - let persist = sess.config.session_storage.lock().unwrap(); - persist.get(&client_hello.session_id) - } + let maybe_resume = sess.config.session_storage + .get(&client_hello.session_id) .and_then(|x| persist::ServerSessionValue::read_bytes(&x)); if can_resume(sess, &maybe_resume) { @@ -1197,8 +1193,9 @@ fn handle_finished(sess: &mut ServerSessionImpl, m: Message) -> StateResult { if !sess.handshake_data.doing_resume && !sess.handshake_data.session_id.is_empty() { let value = get_server_session_value(sess); - let mut persist = sess.config.session_storage.lock().unwrap(); - if persist.put(&sess.handshake_data.session_id, value.get_encoding()) { + let worked = sess.config.session_storage + .put(&sess.handshake_data.session_id, value.get_encoding()); + if worked { info!("Session saved"); } else { info!("Session not saved");