Bring back support for encoding/decoding client session values

This commit is contained in:
Dirkjan Ochtman 2023-03-21 16:19:54 +01:00
parent a44d1669fd
commit 0bcdf119c5
3 changed files with 127 additions and 5 deletions

View File

@ -1,4 +1,4 @@
use crate::enums::{AlertDescription, ContentType, HandshakeType};
use crate::enums::{AlertDescription, CipherSuite, ContentType, HandshakeType};
use crate::msgs::enums::{CertificateStatusType, ECCurveType, KeyUpdateRequest};
use crate::msgs::handshake::KeyExchangeAlgorithm;
use crate::rand;
@ -132,6 +132,8 @@ pub enum InvalidMessage {
UnexpectedMessage(&'static str),
/// An unknown TLS protocol was encountered during message decoding.
UnknownProtocolVersion,
/// An unsupported cipher suite was encountered during message decoding.
UnsupportedCipherSuite(CipherSuite),
/// A peer sent a non-null compression method.
UnsupportedCompression,
/// A peer sent an unknown elliptic curve type.

View File

@ -65,7 +65,7 @@ impl<T> std::ops::Deref for Retrieved<T> {
}
}
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub struct Tls13ClientSessionValue {
suite: &'static Tls13CipherSuite,
age_add: u32,
@ -102,6 +102,41 @@ impl Tls13ClientSessionValue {
}
}
/// [`Codec::read()`] with an extra `suites` argument.
///
/// We need a reference to the available cipher suites in order to check
/// that the cipher suite in the ticket is supported.
pub fn read(
r: &mut Reader,
mut suites: impl Iterator<Item = &'static Tls13CipherSuite>,
) -> Result<Self, InvalidMessage> {
let suite = CipherSuite::read(r)?;
let suite = suites
.find(|s| s.common.suite == suite)
.ok_or(InvalidMessage::UnsupportedCipherSuite(suite))?;
Ok(Self {
suite,
age_add: u32::read(r)?,
max_early_data_size: u32::read(r)?,
common: ClientSessionCommon::read(r)?,
#[cfg(feature = "quic")]
quic_params: PayloadU16::read(r)?,
})
}
/// Inherent implementation of the [`Codec::encode()`] method.
///
/// (See `read()` for why this is inherent here.)
pub fn encode(&self, bytes: &mut Vec<u8>) {
self.suite.common.suite.encode(bytes);
self.age_add.encode(bytes);
self.max_early_data_size.encode(bytes);
self.common.encode(bytes);
#[cfg(feature = "quic")]
self.quic_params.encode(bytes);
}
pub fn max_early_data_size(&self) -> u32 {
self.max_early_data_size
}
@ -129,7 +164,7 @@ impl std::ops::Deref for Tls13ClientSessionValue {
}
}
#[derive(Debug, Clone)]
#[derive(Clone, Debug, PartialEq)]
pub struct Tls12ClientSessionValue {
#[cfg(feature = "tls12")]
suite: &'static Tls12CipherSuite,
@ -166,6 +201,37 @@ impl Tls12ClientSessionValue {
}
}
/// [`Codec::read()`] with an extra `suites` argument.
///
/// We need a reference to the available cipher suites in order to check
/// that the cipher suite in the ticket is supported.
pub fn read(
r: &mut Reader,
mut suites: impl Iterator<Item = &'static Tls12CipherSuite>,
) -> Result<Self, InvalidMessage> {
let suite = CipherSuite::read(r)?;
let suite = suites
.find(|s| s.common.suite == suite)
.ok_or(InvalidMessage::UnsupportedCipherSuite(suite))?;
Ok(Self {
suite,
session_id: SessionID::read(r)?,
extended_ms: u8::read(r)? == 1,
common: ClientSessionCommon::read(r)?,
})
}
/// Inherent implementation of the [`Codec::encode()`] method.
///
/// (See `read()` for why this is inherent here.)
pub fn encode(&self, bytes: &mut Vec<u8>) {
self.suite.common.suite.encode(bytes);
self.session_id.encode(bytes);
(u8::from(self.extended_ms)).encode(bytes);
self.common.encode(bytes);
}
pub fn take_ticket(&mut self) -> Vec<u8> {
mem::take(&mut self.common.ticket.0)
}
@ -188,7 +254,7 @@ impl std::ops::Deref for Tls12ClientSessionValue {
}
}
#[derive(Debug, Clone)]
#[derive(Clone, Debug, PartialEq)]
pub struct ClientSessionCommon {
ticket: PayloadU16,
secret: PayloadU8,
@ -214,6 +280,30 @@ impl ClientSessionCommon {
}
}
/// [`Codec::read()`] is inherent here to avoid leaking the [`Codec`]
/// implementation through [`Deref`] implementations on
/// [`Tls12ClientSessionValue`] and [`Tls13ClientSessionValue`].
fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
Ok(Self {
ticket: PayloadU16::read(r)?,
secret: PayloadU8::read(r)?,
epoch: u64::read(r)?,
lifetime_secs: u32::read(r)?,
server_cert_chain: CertificatePayload::read(r)?,
})
}
/// [`Codec::encode()`] is inherent here to avoid leaking the [`Codec`]
/// implementation through [`Deref`] implementations on
/// [`Tls12ClientSessionValue`] and [`Tls13ClientSessionValue`].
fn encode(&self, bytes: &mut Vec<u8>) {
self.ticket.encode(bytes);
self.secret.encode(bytes);
self.epoch.encode(bytes);
self.lifetime_secs.encode(bytes);
self.server_cert_chain.encode(bytes);
}
pub fn server_cert_chain(&self) -> &[key::Certificate] {
self.server_cert_chain.as_ref()
}

View File

@ -10,7 +10,7 @@ use std::sync::Mutex;
use rustls::client::ResolvesClientCert;
use rustls::internal::msgs::base::Payload;
use rustls::internal::msgs::codec::Codec;
use rustls::internal::msgs::codec::{Codec, Reader};
#[cfg(feature = "quic")]
use rustls::quic::{self, ClientQuicExt, QuicExt, ServerQuicExt};
use rustls::server::{AllowAnyAnonymousOrAuthenticatedClient, ClientHello, ResolvesServerCert};
@ -2730,6 +2730,22 @@ impl rustls::client::ClientSessionStore for ClientStorage {
server_name: &rustls::ServerName,
value: rustls::client::Tls12ClientSessionValue,
) {
#[cfg(feature = "tls12")]
{
let mut bytes = Vec::new();
value.encode(&mut bytes);
let mut reader = Reader::init(&bytes);
let tls12_suites = ALL_CIPHER_SUITES
.iter()
.filter_map(|&suite| match suite {
rustls::SupportedCipherSuite::Tls12(suite) => Some(suite),
_ => None,
});
let decoded =
rustls::client::Tls12ClientSessionValue::read(&mut reader, tls12_suites).unwrap();
assert_eq!(value, decoded);
}
self.ops
.lock()
.unwrap()
@ -2767,6 +2783,20 @@ impl rustls::client::ClientSessionStore for ClientStorage {
server_name: &rustls::ServerName,
value: rustls::client::Tls13ClientSessionValue,
) {
let mut bytes = Vec::new();
value.encode(&mut bytes);
let mut reader = Reader::init(&bytes);
let tls13_suites = ALL_CIPHER_SUITES
.iter()
.filter_map(|&suite| match suite {
rustls::SupportedCipherSuite::Tls13(suite) => Some(suite),
#[cfg(feature = "tls12")]
_ => None,
});
let decoded =
rustls::client::Tls13ClientSessionValue::read(&mut reader, tls13_suites).unwrap();
assert_eq!(value, decoded);
self.ops
.lock()
.unwrap()