mirror of https://github.com/ctz/rustls
Make ClientSessionValue private
This commit is contained in:
parent
0b0c7b7a9d
commit
1538c460b8
|
@ -31,6 +31,7 @@ use crate::client::client_conn::ClientConnectionData;
|
|||
use crate::client::common::ClientHelloDetails;
|
||||
use crate::client::{tls13, ClientConfig, ServerName};
|
||||
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub(super) type NextState = Box<dyn State<ClientConnectionData>>;
|
||||
|
@ -41,19 +42,19 @@ fn find_session(
|
|||
server_name: &ServerName,
|
||||
config: &ClientConfig,
|
||||
#[cfg(feature = "quic")] cx: &mut ClientContext<'_>,
|
||||
) -> Option<persist::Retrieved<persist::ClientSessionValue>> {
|
||||
) -> Option<persist::Retrieved<ClientSessionValue>> {
|
||||
#[allow(clippy::let_and_return)]
|
||||
let found = config
|
||||
.session_storage
|
||||
.take_tls13_ticket(server_name)
|
||||
.map(persist::ClientSessionValue::from)
|
||||
.map(ClientSessionValue::Tls13)
|
||||
.or_else(|| {
|
||||
#[cfg(feature = "tls12")]
|
||||
{
|
||||
config
|
||||
.session_storage
|
||||
.tls12_session(server_name)
|
||||
.map(persist::ClientSessionValue::from)
|
||||
.map(ClientSessionValue::Tls12)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "tls12"))]
|
||||
|
@ -115,7 +116,7 @@ pub(super) fn start_handshake(
|
|||
|
||||
if let Some(_resuming) = &mut resuming_session {
|
||||
#[cfg(feature = "tls12")]
|
||||
if let persist::ClientSessionValue::Tls12(inner) = &mut _resuming.value {
|
||||
if let ClientSessionValue::Tls12(inner) = &mut _resuming.value {
|
||||
// If we have a ticket, we use the sessionid as a signal that
|
||||
// we're doing an abbreviated handshake. See section 3.4 in
|
||||
// RFC5077.
|
||||
|
@ -161,7 +162,7 @@ pub(super) fn start_handshake(
|
|||
|
||||
struct ExpectServerHello {
|
||||
config: Arc<ClientConfig>,
|
||||
resuming_session: Option<persist::Retrieved<persist::ClientSessionValue>>,
|
||||
resuming_session: Option<persist::Retrieved<ClientSessionValue>>,
|
||||
server_name: ServerName,
|
||||
random: Random,
|
||||
using_ems: bool,
|
||||
|
@ -182,7 +183,7 @@ struct ExpectServerHelloOrHelloRetryRequest {
|
|||
fn emit_client_hello_for_retry(
|
||||
config: Arc<ClientConfig>,
|
||||
cx: &mut ClientContext<'_>,
|
||||
resuming_session: Option<persist::Retrieved<persist::ClientSessionValue>>,
|
||||
resuming_session: Option<persist::Retrieved<ClientSessionValue>>,
|
||||
random: Random,
|
||||
using_ems: bool,
|
||||
mut transcript_buffer: HandshakeHashBuffer,
|
||||
|
@ -199,13 +200,9 @@ fn emit_client_hello_for_retry(
|
|||
// Do we have a SessionID or ticket cached for this host?
|
||||
let (ticket, resume_version) = if let Some(resuming) = &resuming_session {
|
||||
match &resuming.value {
|
||||
persist::ClientSessionValue::Tls13(inner) => {
|
||||
(inner.ticket().to_vec(), ProtocolVersion::TLSv1_3)
|
||||
}
|
||||
ClientSessionValue::Tls13(inner) => (inner.ticket().to_vec(), ProtocolVersion::TLSv1_3),
|
||||
#[cfg(feature = "tls12")]
|
||||
persist::ClientSessionValue::Tls12(inner) => {
|
||||
(inner.ticket().to_vec(), ProtocolVersion::TLSv1_2)
|
||||
}
|
||||
ClientSessionValue::Tls12(inner) => (inner.ticket().to_vec(), ProtocolVersion::TLSv1_2),
|
||||
}
|
||||
} else {
|
||||
(Vec::new(), ProtocolVersion::Unknown(0))
|
||||
|
@ -595,9 +592,9 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
|||
let resuming_session = self
|
||||
.resuming_session
|
||||
.and_then(|resuming| match resuming.value {
|
||||
persist::ClientSessionValue::Tls13(inner) => Some(inner),
|
||||
ClientSessionValue::Tls13(inner) => Some(inner),
|
||||
#[cfg(feature = "tls12")]
|
||||
persist::ClientSessionValue::Tls12(_) => None,
|
||||
ClientSessionValue::Tls12(_) => None,
|
||||
});
|
||||
|
||||
tls13::handle_server_hello(
|
||||
|
@ -621,8 +618,8 @@ impl State<ClientConnectionData> for ExpectServerHello {
|
|||
let resuming_session = self
|
||||
.resuming_session
|
||||
.and_then(|resuming| match resuming.value {
|
||||
persist::ClientSessionValue::Tls12(inner) => Some(inner),
|
||||
persist::ClientSessionValue::Tls13(_) => None,
|
||||
ClientSessionValue::Tls12(inner) => Some(inner),
|
||||
ClientSessionValue::Tls13(_) => None,
|
||||
});
|
||||
|
||||
tls12::CompleteServerHelloHandling {
|
||||
|
@ -811,3 +808,37 @@ impl State<ClientConnectionData> for ExpectServerHelloOrHelloRetryRequest {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum ClientSessionValue {
|
||||
Tls13(persist::Tls13ClientSessionValue),
|
||||
#[cfg(feature = "tls12")]
|
||||
Tls12(persist::Tls12ClientSessionValue),
|
||||
}
|
||||
|
||||
impl ClientSessionValue {
|
||||
fn common(&self) -> &persist::ClientSessionCommon {
|
||||
match self {
|
||||
Self::Tls13(inner) => &inner.common,
|
||||
#[cfg(feature = "tls12")]
|
||||
Self::Tls12(inner) => &inner.common,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl persist::Retrieved<ClientSessionValue> {
|
||||
fn tls13(&self) -> Option<persist::Retrieved<&persist::Tls13ClientSessionValue>> {
|
||||
self.map(|v| match v {
|
||||
ClientSessionValue::Tls13(v) => Some(v),
|
||||
#[cfg(feature = "tls12")]
|
||||
ClientSessionValue::Tls12(_) => None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for ClientSessionValue {
|
||||
type Target = persist::ClientSessionCommon;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.common()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,36 +14,6 @@ use std::cmp;
|
|||
#[cfg(feature = "tls12")]
|
||||
use std::mem;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ClientSessionValue {
|
||||
Tls13(Tls13ClientSessionValue),
|
||||
#[cfg(feature = "tls12")]
|
||||
Tls12(Tls12ClientSessionValue),
|
||||
}
|
||||
|
||||
impl ClientSessionValue {
|
||||
fn common(&self) -> &ClientSessionCommon {
|
||||
match self {
|
||||
Self::Tls13(inner) => &inner.common,
|
||||
#[cfg(feature = "tls12")]
|
||||
Self::Tls12(inner) => &inner.common,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Tls13ClientSessionValue> for ClientSessionValue {
|
||||
fn from(v: Tls13ClientSessionValue) -> Self {
|
||||
Self::Tls13(v)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls12")]
|
||||
impl From<Tls12ClientSessionValue> for ClientSessionValue {
|
||||
fn from(v: Tls12ClientSessionValue) -> Self {
|
||||
Self::Tls12(v)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Retrieved<T> {
|
||||
pub value: T,
|
||||
retrieved_at: TimeBase,
|
||||
|
@ -56,6 +26,13 @@ impl<T> Retrieved<T> {
|
|||
retrieved_at,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map<M>(&self, f: impl FnOnce(&T) -> Option<&M>) -> Option<Retrieved<&M>> {
|
||||
Some(Retrieved {
|
||||
value: f(&self.value)?,
|
||||
retrieved_at: self.retrieved_at,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Retrieved<&Tls13ClientSessionValue> {
|
||||
|
@ -69,17 +46,9 @@ impl Retrieved<&Tls13ClientSessionValue> {
|
|||
}
|
||||
}
|
||||
|
||||
impl Retrieved<ClientSessionValue> {
|
||||
pub fn tls13(&self) -> Option<Retrieved<&Tls13ClientSessionValue>> {
|
||||
match &self.value {
|
||||
ClientSessionValue::Tls13(value) => Some(Retrieved::new(value, self.retrieved_at)),
|
||||
#[cfg(feature = "tls12")]
|
||||
ClientSessionValue::Tls12(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: std::ops::Deref<Target = ClientSessionCommon>> Retrieved<T> {
|
||||
pub fn has_expired(&self) -> bool {
|
||||
let common = self.value.common();
|
||||
let common = &*self.value;
|
||||
common.lifetime_secs != 0
|
||||
&& common
|
||||
.epoch
|
||||
|
@ -427,27 +396,8 @@ impl ServerSessionValue {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::enums::*;
|
||||
use crate::key::Certificate;
|
||||
use crate::msgs::codec::{Codec, Reader};
|
||||
use crate::ticketer::TimeBase;
|
||||
use crate::tls13::TLS13_AES_128_GCM_SHA256;
|
||||
|
||||
#[test]
|
||||
fn clientsessionvalue_is_debug() {
|
||||
let csv = ClientSessionValue::from(Tls13ClientSessionValue::new(
|
||||
TLS13_AES_128_GCM_SHA256
|
||||
.tls13()
|
||||
.unwrap(),
|
||||
vec![],
|
||||
vec![1, 2, 3],
|
||||
vec![Certificate(b"abc".to_vec()), Certificate(b"def".to_vec())],
|
||||
TimeBase::now().unwrap(),
|
||||
15,
|
||||
10,
|
||||
128,
|
||||
));
|
||||
println!("{:?}", csv);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serversessionvalue_is_debug() {
|
||||
|
|
Loading…
Reference in New Issue