diff --git a/fuzz/fuzzers/persist.rs b/fuzz/fuzzers/persist.rs index c740863c..ff73d499 100644 --- a/fuzz/fuzzers/persist.rs +++ b/fuzz/fuzzers/persist.rs @@ -10,22 +10,6 @@ fn try_type(data: &[u8]) where T: Codec { T::read(&mut rdr); } -fn try_tls12clientsession(data: &[u8]) { - let mut rdr = Reader::init(data); - persist::ClientSessionValue::read(&mut rdr, - rustls::CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - &rustls::ALL_CIPHER_SUITES); -} - -fn try_tls13clientsession(data: &[u8]) { - let mut rdr = Reader::init(data); - persist::ClientSessionValue::read(&mut rdr, - rustls::CipherSuite::TLS13_AES_256_GCM_SHA384, - &rustls::ALL_CIPHER_SUITES); -} - fuzz_target!(|data: &[u8]| { - try_tls12clientsession(data); - try_tls13clientsession(data); try_type::(data); }); diff --git a/rustls/examples/internal/bogo_shim.rs b/rustls/examples/internal/bogo_shim.rs index ba9c0e16..71bab79c 100644 --- a/rustls/examples/internal/bogo_shim.rs +++ b/rustls/examples/internal/bogo_shim.rs @@ -8,11 +8,11 @@ use base64::prelude::{Engine, BASE64_STANDARD}; use env_logger; use rustls; -use rustls::internal::msgs::codec::{Codec, Reader}; +use rustls::internal::msgs::codec::Codec; use rustls::internal::msgs::persist; use rustls::quic::{self, ClientQuicExt, QuicExt, ServerQuicExt}; use rustls::server::ClientHello; -use rustls::{CipherSuite, ProtocolVersion}; +use rustls::ProtocolVersion; use rustls::{ClientConnection, Connection, ServerConnection, Side}; use std::env; @@ -478,34 +478,45 @@ impl ClientCacheWithoutKxHints { } impl rustls::client::StoresClientSessions for ClientCacheWithoutKxHints { - fn put(&self, key: Vec, value: Vec) -> bool { - if key.len() > 2 && key[0] == b'k' && key[1] == b'x' { - return true; - } - - let mut reader = Reader::init(&value[2..]); - let csv = CipherSuite::read_bytes(&value[..2]) - .and_then(|suite| { - persist::ClientSessionValue::read(&mut reader, suite, &rustls::ALL_CIPHER_SUITES) - }) - .unwrap(); - - let value = match csv { - persist::ClientSessionValue::Tls13(mut tls13) => { - tls13.common.rewind_epoch(self.delay); - tls13.get_encoding() - } - persist::ClientSessionValue::Tls12(mut tls12) => { - tls12.common.rewind_epoch(self.delay); - tls12.get_encoding() - } - }; - - self.storage.put(key, value) + fn put_kx_hint(&self, _: &rustls::ServerName, _: rustls::NamedGroup) {} + fn get_kx_hint(&self, _: &rustls::ServerName) -> Option { + None } - fn get(&self, key: &[u8]) -> Option> { - self.storage.get(key) + fn put_tls12_session( + &self, + server_name: &rustls::ServerName, + mut value: rustls::client::Tls12ClientSessionValue, + ) { + value.common.rewind_epoch(self.delay); + self.storage + .put_tls12_session(server_name, value); + } + + fn get_tls12_session( + &self, + server_name: &rustls::ServerName, + ) -> Option { + self.storage + .get_tls12_session(server_name) + } + + fn add_tls13_ticket( + &self, + server_name: &rustls::ServerName, + mut value: rustls::client::Tls13ClientSessionValue, + ) { + value.common.rewind_epoch(self.delay); + self.storage + .add_tls13_ticket(server_name, value); + } + + fn take_tls13_ticket( + &self, + server_name: &rustls::ServerName, + ) -> Option { + self.storage + .take_tls13_ticket(server_name) } } diff --git a/rustls/src/client/client_conn.rs b/rustls/src/client/client_conn.rs index b87cf335..0f8db5ce 100644 --- a/rustls/src/client/client_conn.rs +++ b/rustls/src/client/client_conn.rs @@ -5,7 +5,7 @@ use crate::error::Error; use crate::kx::SupportedKxGroup; #[cfg(feature = "logging")] use crate::log::trace; -use crate::msgs::codec::Codec; + #[cfg(feature = "quic")] use crate::msgs::enums::AlertDescription; use crate::msgs::enums::NamedGroup; @@ -30,81 +30,52 @@ use std::ops::{Deref, DerefMut}; use std::sync::Arc; use std::{fmt, io, mem}; -/// A trait for the ability to store client session data. -/// The keys and values are opaque. +/// A trait for the ability to store client session data, so that future +/// sessions can be resumed. /// /// Both the keys and values should be treated as /// **highly sensitive data**, containing enough key material /// to break all security of the corresponding session. /// -/// `put` is a mutating operation; this isn't expressed -/// in the type system to allow implementations freedom in -/// how to achieve interior mutability. `Mutex` is a common -/// choice. +/// `put_`, `add_` and `take_` operations are mutating; this isn't +/// expressed in the type system to allow implementations freedom in +/// how to achieve interior mutability. `Mutex` is a common choice. pub trait StoresClientSessions: Send + Sync { - /// Stores a new `value` for `key`. Returns `true` - /// if the value was stored. - fn put(&self, key: Vec, value: Vec) -> bool; - - /// Returns the latest value for `key`. Returns `None` - /// if there's no such value. - fn get(&self, key: &[u8]) -> Option>; - - /// Provide a best-effort guess for which `NamedGroup` the given server - /// might prefer. If `None` is returned, the caller chooses the first - /// configured group. - fn get_kx_hint_for_server(&self, server_name: &ServerName) -> Option { - let key = persist::ClientSessionKey::hint_for_server_name(server_name); - let key_buf = key.get_encoding(); - - self.get(&key_buf) - .and_then(|enc| NamedGroup::read_bytes(&enc)) - } - /// Remember what `NamedGroup` the given server chose. - fn put_kx_hint_for_server(&self, server_name: &ServerName, group: NamedGroup) { - let key = persist::ClientSessionKey::hint_for_server_name(server_name); - self.put(key.get_encoding(), group.get_encoding()); - } + fn put_kx_hint(&self, server_name: &ServerName, group: NamedGroup); - /// Remember a TLS1.2 session (at most one of these can be remembered at a time). - fn put_tls12_session( - &self, - server_name: &ServerName, - value: &persist::Tls12ClientSessionValue, - ) { - let key = persist::ClientSessionKey::session_for_server_name(server_name); - self.put(key.get_encoding(), value.get_encoding()); - } + /// This should return the value most recently passed to `put_kx_hint` + /// for the given `server_name`. + /// + /// If `None` is returned, the caller chooses the first configured group, + /// and an extra round trip might happen if that choice is unsatisfactory + /// to the server. + fn get_kx_hint(&self, server_name: &ServerName) -> Option; + + /// Remember a TLS1.2 session. At most one of these can be remembered at a time, per + /// `server_name`. + fn put_tls12_session(&self, server_name: &ServerName, value: persist::Tls12ClientSessionValue); /// Get the most recently saved TLS1.2 session for `server_name` provided to `put_tls12_session`. - fn get_tls12_session(&self, server_name: &ServerName) -> Option> { - let key = persist::ClientSessionKey::session_for_server_name(server_name); - self.get(&key.get_encoding()) - } + fn get_tls12_session( + &self, + server_name: &ServerName, + ) -> Option; /// Remember a TLS1.3 ticket that might be retrieved later from `take_tls13_ticket`, allowing /// resumption of this session. This can be called multiple times for a given session, allowing /// multiple independent tickets to be valid at once. The number of times this is called /// is controlled by the server, so implementations of this trait should apply a reasonable bound /// of how many items are stored simultaneously. - fn add_tls13_ticket(&self, server_name: &ServerName, value: &persist::Tls13ClientSessionValue) { - let key = persist::ClientSessionKey::session_for_server_name(server_name); - self.put(key.get_encoding(), value.get_encoding()); - } + fn add_tls13_ticket(&self, server_name: &ServerName, value: persist::Tls13ClientSessionValue); /// Return a TLS1.3 ticket previously provided to `add_tls13_ticket`. /// /// Implementations of this trait must return each value provided to `add_tls13_ticket` _at most once_. - fn take_tls13_ticket(&self, server_name: &ServerName) -> Option> { - let key = persist::ClientSessionKey::session_for_server_name(server_name).get_encoding(); - - let value = self.get(&key); - if value.is_some() { - self.put(key, Vec::new()); - } - value - } + fn take_tls13_ticket( + &self, + server_name: &ServerName, + ) -> Option; } /// A trait for the ability to choose a certificate chain and @@ -304,38 +275,6 @@ impl ServerName { Self::IpAddress(_) => None, } } - - /// Return a prefix-free, unique encoding for the name. - pub(crate) fn encode(&self) -> Vec { - enum UniqueTypeCode { - DnsName = 0x01, - IpAddr = 0x02, - } - - match self { - Self::DnsName(dns_name) => { - let bytes = dns_name.0.as_ref(); - - let mut r = Vec::with_capacity(2 + bytes.as_ref().len()); - r.push(UniqueTypeCode::DnsName as u8); - r.push(bytes.as_ref().len() as u8); - r.extend_from_slice(bytes.as_ref()); - - r - } - Self::IpAddress(address) => { - let string = address.to_string(); - let bytes = string.as_bytes(); - - let mut r = Vec::with_capacity(2 + bytes.len()); - r.push(UniqueTypeCode::IpAddr as u8); - r.push(bytes.len() as u8); - r.extend_from_slice(bytes); - - r - } - } - } } /// Attempt to make a ServerName from a string by parsing diff --git a/rustls/src/client/handy.rs b/rustls/src/client/handy.rs index 804887ac..30645a51 100644 --- a/rustls/src/client/handy.rs +++ b/rustls/src/client/handy.rs @@ -3,56 +3,133 @@ use crate::enums::SignatureScheme; use crate::error::Error; use crate::key; use crate::limited_cache; +use crate::msgs::persist; use crate::sign; +use crate::NamedGroup; +use crate::ServerName; +use std::collections::VecDeque; use std::sync::{Arc, Mutex}; /// An implementer of `StoresClientSessions` which does nothing. pub struct NoClientSessionStorage {} impl client::StoresClientSessions for NoClientSessionStorage { - fn put(&self, _key: Vec, _value: Vec) -> bool { - false + fn put_kx_hint(&self, _: &ServerName, _: NamedGroup) {} + + fn get_kx_hint(&self, _: &ServerName) -> Option { + None } - fn get(&self, _key: &[u8]) -> Option> { + fn put_tls12_session(&self, _: &ServerName, _: persist::Tls12ClientSessionValue) {} + + fn get_tls12_session(&self, _: &ServerName) -> Option { None } + + fn add_tls13_ticket(&self, _: &ServerName, _: persist::Tls13ClientSessionValue) {} + + fn take_tls13_ticket(&self, _: &ServerName) -> Option { + None + } +} + +const MAX_TLS13_TICKETS_PER_SERVER: usize = 8; + +struct ServerData { + kx_hint: Option, + + // Zero or one TLS1.2 sessions. + tls12: Option, + + // Up to MAX_TLS13_TICKETS_PER_SERVER TLS1.3 tickets, oldest first. + tls13: VecDeque, +} + +impl Default for ServerData { + fn default() -> Self { + Self { + kx_hint: None, + tls12: None, + tls13: VecDeque::with_capacity(MAX_TLS13_TICKETS_PER_SERVER), + } + } } /// An implementer of `StoresClientSessions` that stores everything /// in memory. It enforces a limit on the number of entries /// to bound memory usage. pub struct ClientSessionMemoryCache { - cache: Mutex, Vec>>, + servers: Mutex>, } impl ClientSessionMemoryCache { /// Make a new ClientSessionMemoryCache. `size` is the /// maximum number of stored sessions. pub fn new(size: usize) -> Arc { - debug_assert!(size > 0); + let max_servers = + size.saturating_add(MAX_TLS13_TICKETS_PER_SERVER - 1) / MAX_TLS13_TICKETS_PER_SERVER; Arc::new(Self { - cache: Mutex::new(limited_cache::LimitedCache::new(size)), + servers: Mutex::new(limited_cache::LimitedCache::new(max_servers)), }) } } impl client::StoresClientSessions for ClientSessionMemoryCache { - fn put(&self, key: Vec, value: Vec) -> bool { - self.cache + fn put_kx_hint(&self, server_name: &ServerName, group: NamedGroup) { + self.servers .lock() .unwrap() - .insert(key, value); - true + .get_or_insert_default_and_edit(server_name.clone(), |data| data.kx_hint = Some(group)); } - fn get(&self, key: &[u8]) -> Option> { - self.cache + fn get_kx_hint(&self, server_name: &ServerName) -> Option { + self.servers .lock() .unwrap() - .get(key) - .cloned() + .get(server_name) + .and_then(|sd| sd.kx_hint) + } + + fn put_tls12_session(&self, server_name: &ServerName, value: persist::Tls12ClientSessionValue) { + self.servers + .lock() + .unwrap() + .get_or_insert_default_and_edit(server_name.clone(), |data| data.tls12 = Some(value)); + } + + fn get_tls12_session( + &self, + server_name: &ServerName, + ) -> Option { + self.servers + .lock() + .unwrap() + .get(server_name) + .and_then(|sd| sd.tls12.as_ref().cloned()) + } + + fn add_tls13_ticket(&self, server_name: &ServerName, value: persist::Tls13ClientSessionValue) { + self.servers + .lock() + .unwrap() + .get_or_insert_default_and_edit(server_name.clone(), |data| { + if data.tls13.len() == data.tls13.capacity() { + data.tls13.pop_front(); + } + data.tls13.push_back(value); + }) + } + + fn take_tls13_ticket( + &self, + server_name: &ServerName, + ) -> Option { + self.servers + .lock() + .unwrap() + .get_mut(server_name) + .and_then(|data| data.tls13.pop_front()) } } @@ -102,60 +179,57 @@ impl client::ResolvesClientCert for AlwaysResolvesClientCert { #[cfg(test)] mod test { use super::*; - use crate::client::StoresClientSessions; + use crate::client::ClientSessionStore; + use crate::internal::msgs::handshake::SessionID; + use std::convert::TryInto; + #[cfg(feature = "tls12")] #[test] - fn test_noclientsessionstorage_drops_put() { + fn test_noclientsessionstorage_does_nothing() { let c = NoClientSessionStorage {}; - assert!(!c.put(vec![0x01], vec![0x02])); - } + let name = "example.com".try_into().unwrap(); + let tls12_suite = match crate::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 { + crate::suites::SupportedCipherSuite::Tls12(inner) => inner, + _ => unreachable!(), + }; + let tls13_suite = match crate::cipher_suite::TLS13_AES_256_GCM_SHA384 { + crate::suites::SupportedCipherSuite::Tls13(inner) => inner, + _ => unreachable!(), + }; + let now = crate::ticketer::TimeBase::now().unwrap(); - #[test] - fn test_noclientsessionstorage_denies_gets() { - let c = NoClientSessionStorage {}; - c.put(vec![0x01], vec![0x02]); - assert_eq!(c.get(&[]), None); - assert_eq!(c.get(&[0x01]), None); - assert_eq!(c.get(&[0x02]), None); - } + c.set_kx_hint(&name, NamedGroup::X25519); + assert_eq!(None, c.kx_hint(&name)); - #[test] - fn test_clientsessionmemorycache_accepts_put() { - let c = ClientSessionMemoryCache::new(4); - assert!(c.put(vec![0x01], vec![0x02])); - } + c.set_tls12_session( + &name, + persist::Tls12ClientSessionValue::new( + &tls12_suite, + SessionID::empty(), + Vec::new(), + Vec::new(), + Vec::new(), + now, + 0, + true, + ), + ); + assert!(c.tls12_session(&name).is_none()); + c.remove_tls12_session(&name); - #[test] - fn test_clientsessionmemorycache_persists_put() { - let c = ClientSessionMemoryCache::new(4); - assert!(c.put(vec![0x01], vec![0x02])); - assert_eq!(c.get(&[0x01]), Some(vec![0x02])); - assert_eq!(c.get(&[0x01]), Some(vec![0x02])); - } - - #[test] - fn test_clientsessionmemorycache_overwrites_put() { - let c = ClientSessionMemoryCache::new(4); - assert!(c.put(vec![0x01], vec![0x02])); - assert!(c.put(vec![0x01], vec![0x04])); - assert_eq!(c.get(&[0x01]), Some(vec![0x04])); - } - - #[test] - fn test_clientsessionmemorycache_drops_to_maintain_size_invariant() { - let c = ClientSessionMemoryCache::new(2); - assert!(c.put(vec![0x01], vec![0x02])); - assert!(c.put(vec![0x03], vec![0x04])); - assert!(c.put(vec![0x05], vec![0x06])); - assert!(c.put(vec![0x07], vec![0x08])); - assert!(c.put(vec![0x09], vec![0x0a])); - - let count = c.get(&[0x01]).iter().count() - + c.get(&[0x03]).iter().count() - + c.get(&[0x05]).iter().count() - + c.get(&[0x07]).iter().count() - + c.get(&[0x09]).iter().count(); - - assert!(count < 5); + c.insert_tls13_ticket( + &name, + persist::Tls13ClientSessionValue::new( + &tls13_suite, + Vec::new(), + Vec::new(), + Vec::new(), + now, + 0, + 0, + 0, + ), + ); + assert!(c.take_tls13_ticket(&name).is_none()); } } diff --git a/rustls/src/client/hs.rs b/rustls/src/client/hs.rs index e325d7b3..a7c0efc7 100644 --- a/rustls/src/client/hs.rs +++ b/rustls/src/client/hs.rs @@ -9,7 +9,6 @@ use crate::kx; #[cfg(feature = "logging")] use crate::log::{debug, trace}; use crate::msgs::base::Payload; -use crate::msgs::codec::{Codec, Reader}; use crate::msgs::enums::{AlertDescription, Compression, ContentType}; use crate::msgs::enums::{ECPointFormat, PSKKeyExchangeMode}; use crate::msgs::enums::{ExtensionType, HandshakeType}; @@ -43,25 +42,16 @@ fn find_session( config: &ClientConfig, #[cfg(feature = "quic")] cx: &mut ClientContext<'_>, ) -> Option> { - let value = config + #[allow(clippy::let_and_return)] + let found = config .session_storage .take_tls13_ticket(server_name) + .map(persist::ClientSessionValue::from) .or_else(|| { config .session_storage .get_tls12_session(server_name) - }) - .or_else(|| { - debug!("No cached session for {:?}", server_name); - None - })?; - - #[allow(unused_mut)] - let mut reader = Reader::init(&value[2..]); - #[allow(clippy::bind_instead_of_map)] // https://github.com/rust-lang/rust-clippy/issues/8082 - CipherSuite::read_bytes(&value[..2]) - .and_then(|suite| { - persist::ClientSessionValue::read(&mut reader, suite, &config.cipher_suites) + .map(persist::ClientSessionValue::from) }) .and_then(|resuming| { let retrieved = persist::Retrieved::new(resuming, TimeBase::now().ok()?); @@ -70,15 +60,21 @@ fn find_session( true => None, } }) - .and_then(|resuming| { - #[cfg(feature = "quic")] - if cx.common.is_quic() { - cx.common.quic.params = resuming - .tls13() - .map(|v| v.quic_params()); - } - Some(resuming) - }) + .or_else(|| { + debug!("No cached session for {:?}", server_name); + None + }); + + #[cfg(feature = "quic")] + if let Some(resuming) = &found { + if cx.common.is_quic() { + cx.common.quic.params = resuming + .tls13() + .map(|v| v.quic_params()); + } + } + + found } pub(super) fn start_handshake( diff --git a/rustls/src/client/tls12.rs b/rustls/src/client/tls12.rs index cee69257..e1c62b24 100644 --- a/rustls/src/client/tls12.rs +++ b/rustls/src/client/tls12.rs @@ -1013,7 +1013,7 @@ impl ExpectFinished { self.config .session_storage - .put_tls12_session(&self.server_name, &stored_value); + .put_tls12_session(&self.server_name, stored_value); } } diff --git a/rustls/src/client/tls13.rs b/rustls/src/client/tls13.rs index 48f925a9..522c2be1 100644 --- a/rustls/src/client/tls13.rs +++ b/rustls/src/client/tls13.rs @@ -141,7 +141,7 @@ pub(super) fn handle_server_hello( // Remember what KX group the server liked for next time. config .session_storage - .put_kx_hint_for_server(&server_name, their_key_share.group); + .put_kx_hint(&server_name, their_key_share.group); // If we change keying when a subsequent handshake message is being joined, // the two halves will have different record layer protections. Disallow this. @@ -191,7 +191,7 @@ pub(super) fn initial_key_share( ) -> Result { let group = config .session_storage - .get_kx_hint_for_server(server_name) + .get_kx_hint(server_name) .and_then(|group| kx::KeyExchange::choose(group, &config.kx_groups)) .unwrap_or_else(|| { config @@ -970,7 +970,7 @@ impl ExpectTraffic { } self.session_storage - .add_tls13_ticket(&self.server_name, &value); + .add_tls13_ticket(&self.server_name, value); Ok(()) } diff --git a/rustls/src/lib.rs b/rustls/src/lib.rs index 00029aa2..766952c0 100644 --- a/rustls/src/lib.rs +++ b/rustls/src/lib.rs @@ -420,6 +420,8 @@ pub mod client { }; #[cfg(feature = "dangerous_configuration")] pub use client_conn::danger::DangerousClientConfig; + + pub use crate::msgs::persist::{Tls12ClientSessionValue, Tls13ClientSessionValue}; } pub use client::{ClientConfig, ClientConnection, ServerName}; diff --git a/rustls/src/limited_cache.rs b/rustls/src/limited_cache.rs index 6994d881..6bb99f95 100644 --- a/rustls/src/limited_cache.rs +++ b/rustls/src/limited_cache.rs @@ -21,6 +21,7 @@ pub(crate) struct LimitedCache { impl LimitedCache where K: Eq + Hash + Clone + std::fmt::Debug, + V: Default, { /// Create a new LimitedCache with the given rough capacity. pub(crate) fn new(capacity_order_of_magnitude: usize) -> Self { @@ -30,6 +31,28 @@ where } } + pub(crate) fn get_or_insert_default_and_edit(&mut self, k: K, edit: impl FnOnce(&mut V)) { + let inserted_new_item = match self.map.entry(k) { + Entry::Occupied(value) => { + edit(value.into_mut()); + false + } + entry @ Entry::Vacant(_) => { + self.oldest + .push_back(entry.key().clone()); + edit(entry.or_insert_with(V::default)); + true + } + }; + + // ensure next insertion does not require a realloc + if inserted_new_item && self.oldest.capacity() == self.oldest.len() { + if let Some(oldest_key) = self.oldest.pop_front() { + self.map.remove(&oldest_key); + } + } + } + pub(crate) fn insert(&mut self, k: K, v: V) { let inserted_new_item = match self.map.entry(k) { Entry::Occupied(mut old) => { @@ -46,7 +69,7 @@ where } }; - // ensure next insert() does not require a realloc + // ensure next insertion does not require a realloc if inserted_new_item && self.oldest.capacity() == self.oldest.len() { if let Some(oldest_key) = self.oldest.pop_front() { self.map.remove(&oldest_key); @@ -62,6 +85,14 @@ where self.map.get(k) } + pub(crate) fn get_mut(&mut self, k: &Q) -> Option<&mut V> + where + K: Borrow, + Q: Hash + Eq, + { + self.map.get_mut(k) + } + pub(crate) fn remove(&mut self, k: &Q) -> Option where K: Borrow, @@ -172,4 +203,43 @@ mod test { t.insert("ghi".into(), 3); } } + + #[test] + fn test_get_or_insert_default_and_edit_evicts_old_items_to_meet_capacity() { + let mut t = Test::new(3); + + t.get_or_insert_default_and_edit("abc".into(), |v| *v += 1); + t.get_or_insert_default_and_edit("def".into(), |v| *v += 2); + + // evicts "abc" + t.get_or_insert_default_and_edit("ghi".into(), |v| *v += 3); + assert_eq!(t.get("abc"), None); + + // evicts "def" + t.get_or_insert_default_and_edit("jkl".into(), |v| *v += 4); + assert_eq!(t.get("def"), None); + + // evicts "ghi" + t.get_or_insert_default_and_edit("abc".into(), |v| *v += 5); + assert_eq!(t.get("ghi"), None); + + // evicts "jkl" + t.get_or_insert_default_and_edit("def".into(), |v| *v += 6); + + assert_eq!(t.get("abc"), Some(&5)); + assert_eq!(t.get("def"), Some(&6)); + assert_eq!(t.get("ghi"), None); + assert_eq!(t.get("jkl"), None); + } + + #[test] + fn test_get_or_insert_default_and_edit_edits_existing_item() { + let mut t = Test::new(3); + + t.get_or_insert_default_and_edit("abc".into(), |v| *v += 1); + t.get_or_insert_default_and_edit("abc".into(), |v| *v += 2); + t.get_or_insert_default_and_edit("abc".into(), |v| *v += 3); + + assert_eq!(t.get("abc"), Some(&6)); + } } diff --git a/rustls/src/msgs/persist.rs b/rustls/src/msgs/persist.rs index aa132e18..5e46bed8 100644 --- a/rustls/src/msgs/persist.rs +++ b/rustls/src/msgs/persist.rs @@ -1,11 +1,9 @@ -use crate::client::ServerName; use crate::enums::{CipherSuite, ProtocolVersion}; use crate::key; use crate::msgs::base::{PayloadU16, PayloadU8}; use crate::msgs::codec::{Codec, Reader}; use crate::msgs::handshake::CertificatePayload; use crate::msgs::handshake::SessionID; -use crate::suites::SupportedCipherSuite; use crate::ticketer::TimeBase; #[cfg(feature = "tls12")] use crate::tls12::Tls12CipherSuite; @@ -15,45 +13,6 @@ use std::cmp; #[cfg(feature = "tls12")] use std::mem; -// These are the keys and values we store in session storage. - -// --- Client types --- -/// Keys for session resumption and tickets. -/// Matching value is a `ClientSessionValue`. -#[derive(Debug)] -pub struct ClientSessionKey { - kind: &'static [u8], - name: Vec, -} - -impl Codec for ClientSessionKey { - fn encode(&self, bytes: &mut Vec) { - bytes.extend_from_slice(self.kind); - bytes.extend_from_slice(&self.name); - } - - // Don't need to read these. - fn read(_r: &mut Reader) -> Option { - None - } -} - -impl ClientSessionKey { - pub fn session_for_server_name(server_name: &ServerName) -> Self { - Self { - kind: b"session", - name: server_name.encode(), - } - } - - pub fn hint_for_server_name(server_name: &ServerName) -> Self { - Self { - kind: b"kx-hint", - name: server_name.encode(), - } - } -} - #[derive(Debug)] pub enum ClientSessionValue { Tls13(Tls13ClientSessionValue), @@ -62,25 +21,6 @@ pub enum ClientSessionValue { } impl ClientSessionValue { - pub fn read( - reader: &mut Reader<'_>, - suite: CipherSuite, - supported: &[SupportedCipherSuite], - ) -> Option { - match supported - .iter() - .find(|s| s.suite() == suite)? - { - SupportedCipherSuite::Tls13(inner) => { - Tls13ClientSessionValue::read(inner, reader).map(ClientSessionValue::Tls13) - } - #[cfg(feature = "tls12")] - SupportedCipherSuite::Tls12(inner) => { - Tls12ClientSessionValue::read(inner, reader).map(ClientSessionValue::Tls12) - } - } - } - fn common(&self) -> &ClientSessionCommon { match self { Self::Tls13(inner) => &inner.common, @@ -192,39 +132,6 @@ impl Tls13ClientSessionValue { } } - /// [`Codec::read()`] with an extra `suite` argument. - /// - /// We decode the `suite` argument separately because it allows us to - /// decide whether we're decoding an 1.2 or 1.3 session value. - pub fn read(suite: &'static Tls13CipherSuite, r: &mut Reader) -> Option { - Some(Self { - suite, - age_add: u32::read(r)?, - max_early_data_size: u32::read(r)?, - common: ClientSessionCommon::read(r)?, - #[cfg(feature = "quic")] - quic_params: PayloadU16::read(r)?, - }) - } - - /// Inherent implementation of the [`Codec::get_encoding()`] method. - /// - /// (See `read()` for why this is inherent here.) - pub fn get_encoding(&self) -> Vec { - let mut bytes = Vec::with_capacity(16); - self.suite - .common - .suite - .encode(&mut bytes); - self.age_add.encode(&mut bytes); - self.max_early_data_size - .encode(&mut bytes); - self.common.encode(&mut bytes); - #[cfg(feature = "quic")] - self.quic_params.encode(&mut bytes); - bytes - } - pub fn max_early_data_size(&self) -> u32 { self.max_early_data_size } @@ -253,7 +160,7 @@ impl std::ops::Deref for Tls13ClientSessionValue { } #[cfg(feature = "tls12")] -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Tls12ClientSessionValue { suite: &'static Tls12CipherSuite, pub session_id: SessionID, @@ -287,34 +194,6 @@ impl Tls12ClientSessionValue { } } - /// [`Codec::read()`] with an extra `suite` argument. - /// - /// We decode the `suite` argument separately because it allows us to - /// decide whether we're decoding an 1.2 or 1.3 session value. - fn read(suite: &'static Tls12CipherSuite, r: &mut Reader) -> Option { - Some(Self { - suite, - session_id: SessionID::read(r)?, - extended_ms: u8::read(r)? == 1, - common: ClientSessionCommon::read(r)?, - }) - } - - /// Inherent implementation of the [`Codec::get_encoding()`] method. - /// - /// (See `read()` for why this is inherent here.) - pub fn get_encoding(&self) -> Vec { - let mut bytes = Vec::with_capacity(16); - self.suite - .common - .suite - .encode(&mut bytes); - self.session_id.encode(&mut bytes); - (u8::from(self.extended_ms)).encode(&mut bytes); - self.common.encode(&mut bytes); - bytes - } - pub fn take_ticket(&mut self) -> Vec { mem::take(&mut self.common.ticket.0) } @@ -337,7 +216,7 @@ impl std::ops::Deref for Tls12ClientSessionValue { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ClientSessionCommon { ticket: PayloadU16, secret: PayloadU8, @@ -363,30 +242,6 @@ impl ClientSessionCommon { } } - /// [`Codec::read()`] is inherent here to avoid leaking the [`Codec`] - /// implementation through [`Deref`] implementations on - /// [`Tls12ClientSessionValue`] and [`Tls13ClientSessionValue`]. - fn read(r: &mut Reader) -> Option { - Some(Self { - ticket: PayloadU16::read(r)?, - secret: PayloadU8::read(r)?, - epoch: u64::read(r)?, - lifetime_secs: u32::read(r)?, - server_cert_chain: CertificatePayload::read(r)?, - }) - } - - /// [`Codec::encode()`] is inherent here to avoid leaking the [`Codec`] - /// implementation through [`Deref`] implementations on - /// [`Tls12ClientSessionValue`] and [`Tls13ClientSessionValue`]. - fn encode(&self, bytes: &mut Vec) { - self.ticket.encode(bytes); - self.secret.encode(bytes); - self.epoch.encode(bytes); - self.lifetime_secs.encode(bytes); - self.server_cert_chain.encode(bytes); - } - pub fn server_cert_chain(&self) -> &[key::Certificate] { self.server_cert_chain.as_ref() } diff --git a/rustls/src/msgs/persist_test.rs b/rustls/src/msgs/persist_test.rs index 996da015..fd65d94c 100644 --- a/rustls/src/msgs/persist_test.rs +++ b/rustls/src/msgs/persist_test.rs @@ -6,20 +6,6 @@ use crate::key::Certificate; use crate::ticketer::TimeBase; use crate::tls13::TLS13_AES_128_GCM_SHA256; -#[test] -fn clientsessionkey_is_debug() { - let name = "hello".try_into().unwrap(); - let csk = ClientSessionKey::session_for_server_name(&name); - println!("{:?}", csk); -} - -#[test] -fn clientsessionkey_cannot_be_read() { - let bytes = [0; 1]; - let mut rd = Reader::init(&bytes); - assert!(ClientSessionKey::read(&mut rd).is_none()); -} - #[test] fn clientsessionvalue_is_debug() { let csv = ClientSessionValue::from(Tls13ClientSessionValue::new( diff --git a/rustls/tests/api.rs b/rustls/tests/api.rs index 1e3c1025..4ddf0a5c 100644 --- a/rustls/tests/api.rs +++ b/rustls/tests/api.rs @@ -2663,53 +2663,117 @@ impl rustls::server::StoresServerSessions for ServerStorage { } } +#[derive(Debug, Clone)] +enum ClientStorageOp { + PutKxHint(rustls::ServerName, rustls::NamedGroup), + GetKxHint(rustls::ServerName, Option), + PutTls12Session(rustls::ServerName), + GetTls12Session(rustls::ServerName, bool), + AddTls13Ticket(rustls::ServerName), + TakeTls13Ticket(rustls::ServerName, bool), +} + struct ClientStorage { storage: Arc, - put_count: AtomicUsize, - get_count: AtomicUsize, - last_put_key: Mutex>>, + ops: Mutex>, } impl ClientStorage { fn new() -> Self { ClientStorage { storage: rustls::client::ClientSessionMemoryCache::new(1024), - put_count: AtomicUsize::new(0), - get_count: AtomicUsize::new(0), - last_put_key: Mutex::new(None), + ops: Mutex::new(Vec::new()), } } - fn puts(&self) -> usize { - self.put_count.load(Ordering::SeqCst) - } - fn gets(&self) -> usize { - self.get_count.load(Ordering::SeqCst) + fn ops(&self) -> Vec { + self.ops.lock().unwrap().clone() } } impl fmt::Debug for ClientStorage { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "(puts: {:?}, gets: {:?} )", - self.put_count, self.get_count - ) + write!(f, "(ops: {:?})", self.ops.lock().unwrap()) } } impl rustls::client::StoresClientSessions for ClientStorage { - fn put(&self, key: Vec, value: Vec) -> bool { - self.put_count - .fetch_add(1, Ordering::SeqCst); - *self.last_put_key.lock().unwrap() = Some(key.clone()); - self.storage.put(key, value) + fn put_kx_hint(&self, server_name: &rustls::ServerName, group: rustls::NamedGroup) { + self.ops + .lock() + .unwrap() + .push(ClientStorageOp::PutKxHint(server_name.clone(), group)); + self.storage + .put_kx_hint(server_name, group) } - fn get(&self, key: &[u8]) -> Option> { - self.get_count - .fetch_add(1, Ordering::SeqCst); - self.storage.get(key) + fn get_kx_hint(&self, server_name: &rustls::ServerName) -> Option { + let rc = self.storage.get_kx_hint(server_name); + self.ops + .lock() + .unwrap() + .push(ClientStorageOp::GetKxHint(server_name.clone(), rc)); + rc + } + + fn put_tls12_session( + &self, + server_name: &rustls::ServerName, + value: rustls::client::Tls12ClientSessionValue, + ) { + self.ops + .lock() + .unwrap() + .push(ClientStorageOp::PutTls12Session(server_name.clone())); + self.storage + .put_tls12_session(server_name, value) + } + + fn get_tls12_session( + &self, + server_name: &rustls::ServerName, + ) -> Option { + let rc = self + .storage + .get_tls12_session(server_name); + self.ops + .lock() + .unwrap() + .push(ClientStorageOp::GetTls12Session( + server_name.clone(), + rc.is_some(), + )); + rc + } + + fn add_tls13_ticket( + &self, + server_name: &rustls::ServerName, + value: rustls::client::Tls13ClientSessionValue, + ) { + self.ops + .lock() + .unwrap() + .push(ClientStorageOp::AddTls13Ticket(server_name.clone())); + self.storage + .add_tls13_ticket(server_name, value); + } + + fn take_tls13_ticket( + &self, + server_name: &rustls::ServerName, + ) -> Option { + let rc = self + .storage + .take_tls13_ticket(server_name); + self.ops + .lock() + .unwrap() + .push(ClientStorageOp::TakeTls13Ticket( + server_name.clone(), + rc.is_some(), + )); + rc } } @@ -3105,6 +3169,7 @@ mod test_quic { server_params.into(), ) .unwrap(); + server.reject_early_data(); step(&mut client, &mut server).unwrap(); assert_eq!(client.quic_transport_parameters(), Some(server_params)); @@ -3705,9 +3770,29 @@ fn test_client_sends_helloretryrequest() { do_handshake_until_error(&mut client, &mut server).unwrap(); - // client only did three storage queries: two for sessions, another for a kx type - assert_eq!(storage.gets(), 3); - assert_eq!(storage.puts(), 2); + // client only did following storage queries: + println!("storage {:#?}", storage.ops()); + assert_eq!(storage.ops().len(), 5); + assert!(matches!( + storage.ops()[0], + ClientStorageOp::TakeTls13Ticket(_, false) + )); + assert!(matches!( + storage.ops()[1], + ClientStorageOp::GetTls12Session(_, false) + )); + assert!(matches!( + storage.ops()[2], + ClientStorageOp::GetKxHint(_, None) + )); + assert!(matches!( + storage.ops()[3], + ClientStorageOp::PutKxHint(_, rustls::NamedGroup::X25519) + )); + assert!(matches!( + storage.ops()[4], + ClientStorageOp::AddTls13Ticket(_) + )); } #[test] @@ -4019,14 +4104,37 @@ fn test_client_tls12_no_resume_after_server_downgrade() { ClientConnection::new(client_config.clone(), "localhost".try_into().unwrap()).unwrap(); let mut server_1 = ServerConnection::new(server_config_1).unwrap(); common::do_handshake(&mut client_1, &mut server_1); - assert_eq!(client_storage.puts(), 2); + + assert_eq!(client_storage.ops().len(), 5); + println!("hs1 storage ops: {:#?}", client_storage.ops()); + assert!(matches!( + client_storage.ops()[3], + ClientStorageOp::PutKxHint(_, _) + )); + assert!(matches!( + client_storage.ops()[4], + ClientStorageOp::AddTls13Ticket(_) + )); dbg!("handshake 2"); let mut client_2 = ClientConnection::new(client_config, "localhost".try_into().unwrap()).unwrap(); let mut server_2 = ServerConnection::new(Arc::new(server_config_2)).unwrap(); common::do_handshake(&mut client_2, &mut server_2); - assert_eq!(client_storage.puts(), 3); + println!("hs2 storage ops: {:#?}", client_storage.ops()); + assert_eq!(client_storage.ops().len(), 7); + + // attempt consumes a TLS1.3 ticket + assert!(matches!( + client_storage.ops()[5], + ClientStorageOp::TakeTls13Ticket(_, true) + )); + + // but ends up with TLS1.2 + assert_eq!( + client_2.protocol_version(), + Some(rustls::ProtocolVersion::TLSv1_2) + ); } #[test]