From 0bcdf119c573498ca5a2e4b84c92771480a8547e Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Tue, 21 Mar 2023 16:19:54 +0100 Subject: [PATCH] Bring back support for encoding/decoding client session values --- rustls/src/error.rs | 4 +- rustls/src/msgs/persist.rs | 96 ++++++++++++++++++++++++++++++++++++-- rustls/tests/api.rs | 32 ++++++++++++- 3 files changed, 127 insertions(+), 5 deletions(-) diff --git a/rustls/src/error.rs b/rustls/src/error.rs index 6a269f35..bb14d2bd 100644 --- a/rustls/src/error.rs +++ b/rustls/src/error.rs @@ -1,4 +1,4 @@ -use crate::enums::{AlertDescription, ContentType, HandshakeType}; +use crate::enums::{AlertDescription, CipherSuite, ContentType, HandshakeType}; use crate::msgs::enums::{CertificateStatusType, ECCurveType, KeyUpdateRequest}; use crate::msgs::handshake::KeyExchangeAlgorithm; use crate::rand; @@ -132,6 +132,8 @@ pub enum InvalidMessage { UnexpectedMessage(&'static str), /// An unknown TLS protocol was encountered during message decoding. UnknownProtocolVersion, + /// An unsupported cipher suite was encountered during message decoding. + UnsupportedCipherSuite(CipherSuite), /// A peer sent a non-null compression method. UnsupportedCompression, /// A peer sent an unknown elliptic curve type. diff --git a/rustls/src/msgs/persist.rs b/rustls/src/msgs/persist.rs index 6e3f12e4..ee915635 100644 --- a/rustls/src/msgs/persist.rs +++ b/rustls/src/msgs/persist.rs @@ -65,7 +65,7 @@ impl std::ops::Deref for Retrieved { } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct Tls13ClientSessionValue { suite: &'static Tls13CipherSuite, age_add: u32, @@ -102,6 +102,41 @@ impl Tls13ClientSessionValue { } } + /// [`Codec::read()`] with an extra `suites` argument. + /// + /// We need a reference to the available cipher suites in order to check + /// that the cipher suite in the ticket is supported. + pub fn read( + r: &mut Reader, + mut suites: impl Iterator, + ) -> Result { + let suite = CipherSuite::read(r)?; + let suite = suites + .find(|s| s.common.suite == suite) + .ok_or(InvalidMessage::UnsupportedCipherSuite(suite))?; + + Ok(Self { + suite, + age_add: u32::read(r)?, + max_early_data_size: u32::read(r)?, + common: ClientSessionCommon::read(r)?, + #[cfg(feature = "quic")] + quic_params: PayloadU16::read(r)?, + }) + } + + /// Inherent implementation of the [`Codec::encode()`] method. + /// + /// (See `read()` for why this is inherent here.) + pub fn encode(&self, bytes: &mut Vec) { + self.suite.common.suite.encode(bytes); + self.age_add.encode(bytes); + self.max_early_data_size.encode(bytes); + self.common.encode(bytes); + #[cfg(feature = "quic")] + self.quic_params.encode(bytes); + } + pub fn max_early_data_size(&self) -> u32 { self.max_early_data_size } @@ -129,7 +164,7 @@ impl std::ops::Deref for Tls13ClientSessionValue { } } -#[derive(Debug, Clone)] +#[derive(Clone, Debug, PartialEq)] pub struct Tls12ClientSessionValue { #[cfg(feature = "tls12")] suite: &'static Tls12CipherSuite, @@ -166,6 +201,37 @@ impl Tls12ClientSessionValue { } } + /// [`Codec::read()`] with an extra `suites` argument. + /// + /// We need a reference to the available cipher suites in order to check + /// that the cipher suite in the ticket is supported. + pub fn read( + r: &mut Reader, + mut suites: impl Iterator, + ) -> Result { + let suite = CipherSuite::read(r)?; + let suite = suites + .find(|s| s.common.suite == suite) + .ok_or(InvalidMessage::UnsupportedCipherSuite(suite))?; + + Ok(Self { + suite, + session_id: SessionID::read(r)?, + extended_ms: u8::read(r)? == 1, + common: ClientSessionCommon::read(r)?, + }) + } + + /// Inherent implementation of the [`Codec::encode()`] method. + /// + /// (See `read()` for why this is inherent here.) + pub fn encode(&self, bytes: &mut Vec) { + self.suite.common.suite.encode(bytes); + self.session_id.encode(bytes); + (u8::from(self.extended_ms)).encode(bytes); + self.common.encode(bytes); + } + pub fn take_ticket(&mut self) -> Vec { mem::take(&mut self.common.ticket.0) } @@ -188,7 +254,7 @@ impl std::ops::Deref for Tls12ClientSessionValue { } } -#[derive(Debug, Clone)] +#[derive(Clone, Debug, PartialEq)] pub struct ClientSessionCommon { ticket: PayloadU16, secret: PayloadU8, @@ -214,6 +280,30 @@ impl ClientSessionCommon { } } + /// [`Codec::read()`] is inherent here to avoid leaking the [`Codec`] + /// implementation through [`Deref`] implementations on + /// [`Tls12ClientSessionValue`] and [`Tls13ClientSessionValue`]. + fn read(r: &mut Reader) -> Result { + Ok(Self { + ticket: PayloadU16::read(r)?, + secret: PayloadU8::read(r)?, + epoch: u64::read(r)?, + lifetime_secs: u32::read(r)?, + server_cert_chain: CertificatePayload::read(r)?, + }) + } + + /// [`Codec::encode()`] is inherent here to avoid leaking the [`Codec`] + /// implementation through [`Deref`] implementations on + /// [`Tls12ClientSessionValue`] and [`Tls13ClientSessionValue`]. + fn encode(&self, bytes: &mut Vec) { + self.ticket.encode(bytes); + self.secret.encode(bytes); + self.epoch.encode(bytes); + self.lifetime_secs.encode(bytes); + self.server_cert_chain.encode(bytes); + } + pub fn server_cert_chain(&self) -> &[key::Certificate] { self.server_cert_chain.as_ref() } diff --git a/rustls/tests/api.rs b/rustls/tests/api.rs index a3e11848..4e062796 100644 --- a/rustls/tests/api.rs +++ b/rustls/tests/api.rs @@ -10,7 +10,7 @@ use std::sync::Mutex; use rustls::client::ResolvesClientCert; use rustls::internal::msgs::base::Payload; -use rustls::internal::msgs::codec::Codec; +use rustls::internal::msgs::codec::{Codec, Reader}; #[cfg(feature = "quic")] use rustls::quic::{self, ClientQuicExt, QuicExt, ServerQuicExt}; use rustls::server::{AllowAnyAnonymousOrAuthenticatedClient, ClientHello, ResolvesServerCert}; @@ -2730,6 +2730,22 @@ impl rustls::client::ClientSessionStore for ClientStorage { server_name: &rustls::ServerName, value: rustls::client::Tls12ClientSessionValue, ) { + #[cfg(feature = "tls12")] + { + let mut bytes = Vec::new(); + value.encode(&mut bytes); + let mut reader = Reader::init(&bytes); + let tls12_suites = ALL_CIPHER_SUITES + .iter() + .filter_map(|&suite| match suite { + rustls::SupportedCipherSuite::Tls12(suite) => Some(suite), + _ => None, + }); + let decoded = + rustls::client::Tls12ClientSessionValue::read(&mut reader, tls12_suites).unwrap(); + assert_eq!(value, decoded); + } + self.ops .lock() .unwrap() @@ -2767,6 +2783,20 @@ impl rustls::client::ClientSessionStore for ClientStorage { server_name: &rustls::ServerName, value: rustls::client::Tls13ClientSessionValue, ) { + let mut bytes = Vec::new(); + value.encode(&mut bytes); + let mut reader = Reader::init(&bytes); + let tls13_suites = ALL_CIPHER_SUITES + .iter() + .filter_map(|&suite| match suite { + rustls::SupportedCipherSuite::Tls13(suite) => Some(suite), + #[cfg(feature = "tls12")] + _ => None, + }); + let decoded = + rustls::client::Tls13ClientSessionValue::read(&mut reader, tls13_suites).unwrap(); + assert_eq!(value, decoded); + self.ops .lock() .unwrap()