Make ClientSessionValue private

This commit is contained in:
Dirkjan Ochtman 2023-03-21 15:26:35 +01:00
parent 0b0c7b7a9d
commit 1538c460b8
2 changed files with 56 additions and 75 deletions

View File

@ -31,6 +31,7 @@ use crate::client::client_conn::ClientConnectionData;
use crate::client::common::ClientHelloDetails;
use crate::client::{tls13, ClientConfig, ServerName};
use std::ops::Deref;
use std::sync::Arc;
pub(super) type NextState = Box<dyn State<ClientConnectionData>>;
@ -41,19 +42,19 @@ fn find_session(
server_name: &ServerName,
config: &ClientConfig,
#[cfg(feature = "quic")] cx: &mut ClientContext<'_>,
) -> Option<persist::Retrieved<persist::ClientSessionValue>> {
) -> Option<persist::Retrieved<ClientSessionValue>> {
#[allow(clippy::let_and_return)]
let found = config
.session_storage
.take_tls13_ticket(server_name)
.map(persist::ClientSessionValue::from)
.map(ClientSessionValue::Tls13)
.or_else(|| {
#[cfg(feature = "tls12")]
{
config
.session_storage
.tls12_session(server_name)
.map(persist::ClientSessionValue::from)
.map(ClientSessionValue::Tls12)
}
#[cfg(not(feature = "tls12"))]
@ -115,7 +116,7 @@ pub(super) fn start_handshake(
if let Some(_resuming) = &mut resuming_session {
#[cfg(feature = "tls12")]
if let persist::ClientSessionValue::Tls12(inner) = &mut _resuming.value {
if let ClientSessionValue::Tls12(inner) = &mut _resuming.value {
// If we have a ticket, we use the sessionid as a signal that
// we're doing an abbreviated handshake. See section 3.4 in
// RFC5077.
@ -161,7 +162,7 @@ pub(super) fn start_handshake(
struct ExpectServerHello {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Retrieved<persist::ClientSessionValue>>,
resuming_session: Option<persist::Retrieved<ClientSessionValue>>,
server_name: ServerName,
random: Random,
using_ems: bool,
@ -182,7 +183,7 @@ struct ExpectServerHelloOrHelloRetryRequest {
fn emit_client_hello_for_retry(
config: Arc<ClientConfig>,
cx: &mut ClientContext<'_>,
resuming_session: Option<persist::Retrieved<persist::ClientSessionValue>>,
resuming_session: Option<persist::Retrieved<ClientSessionValue>>,
random: Random,
using_ems: bool,
mut transcript_buffer: HandshakeHashBuffer,
@ -199,13 +200,9 @@ fn emit_client_hello_for_retry(
// Do we have a SessionID or ticket cached for this host?
let (ticket, resume_version) = if let Some(resuming) = &resuming_session {
match &resuming.value {
persist::ClientSessionValue::Tls13(inner) => {
(inner.ticket().to_vec(), ProtocolVersion::TLSv1_3)
}
ClientSessionValue::Tls13(inner) => (inner.ticket().to_vec(), ProtocolVersion::TLSv1_3),
#[cfg(feature = "tls12")]
persist::ClientSessionValue::Tls12(inner) => {
(inner.ticket().to_vec(), ProtocolVersion::TLSv1_2)
}
ClientSessionValue::Tls12(inner) => (inner.ticket().to_vec(), ProtocolVersion::TLSv1_2),
}
} else {
(Vec::new(), ProtocolVersion::Unknown(0))
@ -595,9 +592,9 @@ impl State<ClientConnectionData> for ExpectServerHello {
let resuming_session = self
.resuming_session
.and_then(|resuming| match resuming.value {
persist::ClientSessionValue::Tls13(inner) => Some(inner),
ClientSessionValue::Tls13(inner) => Some(inner),
#[cfg(feature = "tls12")]
persist::ClientSessionValue::Tls12(_) => None,
ClientSessionValue::Tls12(_) => None,
});
tls13::handle_server_hello(
@ -621,8 +618,8 @@ impl State<ClientConnectionData> for ExpectServerHello {
let resuming_session = self
.resuming_session
.and_then(|resuming| match resuming.value {
persist::ClientSessionValue::Tls12(inner) => Some(inner),
persist::ClientSessionValue::Tls13(_) => None,
ClientSessionValue::Tls12(inner) => Some(inner),
ClientSessionValue::Tls13(_) => None,
});
tls12::CompleteServerHelloHandling {
@ -811,3 +808,37 @@ impl State<ClientConnectionData> for ExpectServerHelloOrHelloRetryRequest {
}
}
}
enum ClientSessionValue {
Tls13(persist::Tls13ClientSessionValue),
#[cfg(feature = "tls12")]
Tls12(persist::Tls12ClientSessionValue),
}
impl ClientSessionValue {
fn common(&self) -> &persist::ClientSessionCommon {
match self {
Self::Tls13(inner) => &inner.common,
#[cfg(feature = "tls12")]
Self::Tls12(inner) => &inner.common,
}
}
}
impl persist::Retrieved<ClientSessionValue> {
fn tls13(&self) -> Option<persist::Retrieved<&persist::Tls13ClientSessionValue>> {
self.map(|v| match v {
ClientSessionValue::Tls13(v) => Some(v),
#[cfg(feature = "tls12")]
ClientSessionValue::Tls12(_) => None,
})
}
}
impl Deref for ClientSessionValue {
type Target = persist::ClientSessionCommon;
fn deref(&self) -> &Self::Target {
self.common()
}
}

View File

@ -14,36 +14,6 @@ use std::cmp;
#[cfg(feature = "tls12")]
use std::mem;
#[derive(Debug)]
pub enum ClientSessionValue {
Tls13(Tls13ClientSessionValue),
#[cfg(feature = "tls12")]
Tls12(Tls12ClientSessionValue),
}
impl ClientSessionValue {
fn common(&self) -> &ClientSessionCommon {
match self {
Self::Tls13(inner) => &inner.common,
#[cfg(feature = "tls12")]
Self::Tls12(inner) => &inner.common,
}
}
}
impl From<Tls13ClientSessionValue> for ClientSessionValue {
fn from(v: Tls13ClientSessionValue) -> Self {
Self::Tls13(v)
}
}
#[cfg(feature = "tls12")]
impl From<Tls12ClientSessionValue> for ClientSessionValue {
fn from(v: Tls12ClientSessionValue) -> Self {
Self::Tls12(v)
}
}
pub struct Retrieved<T> {
pub value: T,
retrieved_at: TimeBase,
@ -56,6 +26,13 @@ impl<T> Retrieved<T> {
retrieved_at,
}
}
pub fn map<M>(&self, f: impl FnOnce(&T) -> Option<&M>) -> Option<Retrieved<&M>> {
Some(Retrieved {
value: f(&self.value)?,
retrieved_at: self.retrieved_at,
})
}
}
impl Retrieved<&Tls13ClientSessionValue> {
@ -69,17 +46,9 @@ impl Retrieved<&Tls13ClientSessionValue> {
}
}
impl Retrieved<ClientSessionValue> {
pub fn tls13(&self) -> Option<Retrieved<&Tls13ClientSessionValue>> {
match &self.value {
ClientSessionValue::Tls13(value) => Some(Retrieved::new(value, self.retrieved_at)),
#[cfg(feature = "tls12")]
ClientSessionValue::Tls12(_) => None,
}
}
impl<T: std::ops::Deref<Target = ClientSessionCommon>> Retrieved<T> {
pub fn has_expired(&self) -> bool {
let common = self.value.common();
let common = &*self.value;
common.lifetime_secs != 0
&& common
.epoch
@ -427,27 +396,8 @@ impl ServerSessionValue {
mod tests {
use super::*;
use crate::enums::*;
use crate::key::Certificate;
use crate::msgs::codec::{Codec, Reader};
use crate::ticketer::TimeBase;
use crate::tls13::TLS13_AES_128_GCM_SHA256;
#[test]
fn clientsessionvalue_is_debug() {
let csv = ClientSessionValue::from(Tls13ClientSessionValue::new(
TLS13_AES_128_GCM_SHA256
.tls13()
.unwrap(),
vec![],
vec![1, 2, 3],
vec![Certificate(b"abc".to_vec()), Certificate(b"def".to_vec())],
TimeBase::now().unwrap(),
15,
10,
128,
));
println!("{:?}", csv);
}
#[test]
fn serversessionvalue_is_debug() {