mirror of https://github.com/ctz/rustls
Bring back support for encoding/decoding client session values
This commit is contained in:
parent
a44d1669fd
commit
0bcdf119c5
|
@ -1,4 +1,4 @@
|
|||
use crate::enums::{AlertDescription, ContentType, HandshakeType};
|
||||
use crate::enums::{AlertDescription, CipherSuite, ContentType, HandshakeType};
|
||||
use crate::msgs::enums::{CertificateStatusType, ECCurveType, KeyUpdateRequest};
|
||||
use crate::msgs::handshake::KeyExchangeAlgorithm;
|
||||
use crate::rand;
|
||||
|
@ -132,6 +132,8 @@ pub enum InvalidMessage {
|
|||
UnexpectedMessage(&'static str),
|
||||
/// An unknown TLS protocol was encountered during message decoding.
|
||||
UnknownProtocolVersion,
|
||||
/// An unsupported cipher suite was encountered during message decoding.
|
||||
UnsupportedCipherSuite(CipherSuite),
|
||||
/// A peer sent a non-null compression method.
|
||||
UnsupportedCompression,
|
||||
/// A peer sent an unknown elliptic curve type.
|
||||
|
|
|
@ -65,7 +65,7 @@ impl<T> std::ops::Deref for Retrieved<T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct Tls13ClientSessionValue {
|
||||
suite: &'static Tls13CipherSuite,
|
||||
age_add: u32,
|
||||
|
@ -102,6 +102,41 @@ impl Tls13ClientSessionValue {
|
|||
}
|
||||
}
|
||||
|
||||
/// [`Codec::read()`] with an extra `suites` argument.
|
||||
///
|
||||
/// We need a reference to the available cipher suites in order to check
|
||||
/// that the cipher suite in the ticket is supported.
|
||||
pub fn read(
|
||||
r: &mut Reader,
|
||||
mut suites: impl Iterator<Item = &'static Tls13CipherSuite>,
|
||||
) -> Result<Self, InvalidMessage> {
|
||||
let suite = CipherSuite::read(r)?;
|
||||
let suite = suites
|
||||
.find(|s| s.common.suite == suite)
|
||||
.ok_or(InvalidMessage::UnsupportedCipherSuite(suite))?;
|
||||
|
||||
Ok(Self {
|
||||
suite,
|
||||
age_add: u32::read(r)?,
|
||||
max_early_data_size: u32::read(r)?,
|
||||
common: ClientSessionCommon::read(r)?,
|
||||
#[cfg(feature = "quic")]
|
||||
quic_params: PayloadU16::read(r)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// Inherent implementation of the [`Codec::encode()`] method.
|
||||
///
|
||||
/// (See `read()` for why this is inherent here.)
|
||||
pub fn encode(&self, bytes: &mut Vec<u8>) {
|
||||
self.suite.common.suite.encode(bytes);
|
||||
self.age_add.encode(bytes);
|
||||
self.max_early_data_size.encode(bytes);
|
||||
self.common.encode(bytes);
|
||||
#[cfg(feature = "quic")]
|
||||
self.quic_params.encode(bytes);
|
||||
}
|
||||
|
||||
pub fn max_early_data_size(&self) -> u32 {
|
||||
self.max_early_data_size
|
||||
}
|
||||
|
@ -129,7 +164,7 @@ impl std::ops::Deref for Tls13ClientSessionValue {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct Tls12ClientSessionValue {
|
||||
#[cfg(feature = "tls12")]
|
||||
suite: &'static Tls12CipherSuite,
|
||||
|
@ -166,6 +201,37 @@ impl Tls12ClientSessionValue {
|
|||
}
|
||||
}
|
||||
|
||||
/// [`Codec::read()`] with an extra `suites` argument.
|
||||
///
|
||||
/// We need a reference to the available cipher suites in order to check
|
||||
/// that the cipher suite in the ticket is supported.
|
||||
pub fn read(
|
||||
r: &mut Reader,
|
||||
mut suites: impl Iterator<Item = &'static Tls12CipherSuite>,
|
||||
) -> Result<Self, InvalidMessage> {
|
||||
let suite = CipherSuite::read(r)?;
|
||||
let suite = suites
|
||||
.find(|s| s.common.suite == suite)
|
||||
.ok_or(InvalidMessage::UnsupportedCipherSuite(suite))?;
|
||||
|
||||
Ok(Self {
|
||||
suite,
|
||||
session_id: SessionID::read(r)?,
|
||||
extended_ms: u8::read(r)? == 1,
|
||||
common: ClientSessionCommon::read(r)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// Inherent implementation of the [`Codec::encode()`] method.
|
||||
///
|
||||
/// (See `read()` for why this is inherent here.)
|
||||
pub fn encode(&self, bytes: &mut Vec<u8>) {
|
||||
self.suite.common.suite.encode(bytes);
|
||||
self.session_id.encode(bytes);
|
||||
(u8::from(self.extended_ms)).encode(bytes);
|
||||
self.common.encode(bytes);
|
||||
}
|
||||
|
||||
pub fn take_ticket(&mut self) -> Vec<u8> {
|
||||
mem::take(&mut self.common.ticket.0)
|
||||
}
|
||||
|
@ -188,7 +254,7 @@ impl std::ops::Deref for Tls12ClientSessionValue {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct ClientSessionCommon {
|
||||
ticket: PayloadU16,
|
||||
secret: PayloadU8,
|
||||
|
@ -214,6 +280,30 @@ impl ClientSessionCommon {
|
|||
}
|
||||
}
|
||||
|
||||
/// [`Codec::read()`] is inherent here to avoid leaking the [`Codec`]
|
||||
/// implementation through [`Deref`] implementations on
|
||||
/// [`Tls12ClientSessionValue`] and [`Tls13ClientSessionValue`].
|
||||
fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
|
||||
Ok(Self {
|
||||
ticket: PayloadU16::read(r)?,
|
||||
secret: PayloadU8::read(r)?,
|
||||
epoch: u64::read(r)?,
|
||||
lifetime_secs: u32::read(r)?,
|
||||
server_cert_chain: CertificatePayload::read(r)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// [`Codec::encode()`] is inherent here to avoid leaking the [`Codec`]
|
||||
/// implementation through [`Deref`] implementations on
|
||||
/// [`Tls12ClientSessionValue`] and [`Tls13ClientSessionValue`].
|
||||
fn encode(&self, bytes: &mut Vec<u8>) {
|
||||
self.ticket.encode(bytes);
|
||||
self.secret.encode(bytes);
|
||||
self.epoch.encode(bytes);
|
||||
self.lifetime_secs.encode(bytes);
|
||||
self.server_cert_chain.encode(bytes);
|
||||
}
|
||||
|
||||
pub fn server_cert_chain(&self) -> &[key::Certificate] {
|
||||
self.server_cert_chain.as_ref()
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ use std::sync::Mutex;
|
|||
|
||||
use rustls::client::ResolvesClientCert;
|
||||
use rustls::internal::msgs::base::Payload;
|
||||
use rustls::internal::msgs::codec::Codec;
|
||||
use rustls::internal::msgs::codec::{Codec, Reader};
|
||||
#[cfg(feature = "quic")]
|
||||
use rustls::quic::{self, ClientQuicExt, QuicExt, ServerQuicExt};
|
||||
use rustls::server::{AllowAnyAnonymousOrAuthenticatedClient, ClientHello, ResolvesServerCert};
|
||||
|
@ -2730,6 +2730,22 @@ impl rustls::client::ClientSessionStore for ClientStorage {
|
|||
server_name: &rustls::ServerName,
|
||||
value: rustls::client::Tls12ClientSessionValue,
|
||||
) {
|
||||
#[cfg(feature = "tls12")]
|
||||
{
|
||||
let mut bytes = Vec::new();
|
||||
value.encode(&mut bytes);
|
||||
let mut reader = Reader::init(&bytes);
|
||||
let tls12_suites = ALL_CIPHER_SUITES
|
||||
.iter()
|
||||
.filter_map(|&suite| match suite {
|
||||
rustls::SupportedCipherSuite::Tls12(suite) => Some(suite),
|
||||
_ => None,
|
||||
});
|
||||
let decoded =
|
||||
rustls::client::Tls12ClientSessionValue::read(&mut reader, tls12_suites).unwrap();
|
||||
assert_eq!(value, decoded);
|
||||
}
|
||||
|
||||
self.ops
|
||||
.lock()
|
||||
.unwrap()
|
||||
|
@ -2767,6 +2783,20 @@ impl rustls::client::ClientSessionStore for ClientStorage {
|
|||
server_name: &rustls::ServerName,
|
||||
value: rustls::client::Tls13ClientSessionValue,
|
||||
) {
|
||||
let mut bytes = Vec::new();
|
||||
value.encode(&mut bytes);
|
||||
let mut reader = Reader::init(&bytes);
|
||||
let tls13_suites = ALL_CIPHER_SUITES
|
||||
.iter()
|
||||
.filter_map(|&suite| match suite {
|
||||
rustls::SupportedCipherSuite::Tls13(suite) => Some(suite),
|
||||
#[cfg(feature = "tls12")]
|
||||
_ => None,
|
||||
});
|
||||
let decoded =
|
||||
rustls::client::Tls13ClientSessionValue::read(&mut reader, tls13_suites).unwrap();
|
||||
assert_eq!(value, decoded);
|
||||
|
||||
self.ops
|
||||
.lock()
|
||||
.unwrap()
|
||||
|
|
Loading…
Reference in New Issue