diff --git a/rustls/src/server/hs.rs b/rustls/src/server/hs.rs index 6614b5df..5f7b6621 100644 --- a/rustls/src/server/hs.rs +++ b/rustls/src/server/hs.rs @@ -1,4 +1,5 @@ use crate::error::TlsError; +use crate::key::Certificate; use crate::kx; #[cfg(feature = "logging")] use crate::log::{debug, trace}; @@ -31,6 +32,8 @@ use webpki; use crate::server::common::{HandshakeDetails, ServerKXDetails}; use crate::server::{tls12, tls13}; +use std::sync::Arc; + pub type NextState = Box; pub type NextStateOrError = Result; @@ -150,7 +153,8 @@ impl ExtensionProcessing { pub(crate) fn process_common( &mut self, sess: &mut ServerSessionImpl, - server_key: Option<&mut sign::ActiveCertifiedKey>, + ocsp_response: &mut Option<&[u8]>, + sct_list: &mut Option<&[u8]>, hello: &ClientHelloPayload, resumedata: Option<&persist::ServerSessionValue>, handshake: &HandshakeDetails, @@ -223,44 +227,40 @@ impl ExtensionProcessing { .push(ServerExtension::ServerNameAck); } - if let Some(server_key) = server_key { - // Send status_request response if we have one. This is not allowed - // if we're resuming, and is only triggered if we have an OCSP response - // to send. - if !for_resume - && hello - .find_extension(ExtensionType::StatusRequest) - .is_some() - { - if server_key.has_ocsp() && !sess.common.is_tls13() { - // Only TLS1.2 sends confirmation in ServerHello - self.exts - .push(ServerExtension::CertificateStatusAck); - } - } else { - // Throw away any OCSP response so we don't try to send it later. - drop(server_key.take_ocsp()); + // Send status_request response if we have one. This is not allowed + // if we're resuming, and is only triggered if we have an OCSP response + // to send. + if !for_resume + && hello + .find_extension(ExtensionType::StatusRequest) + .is_some() + { + if ocsp_response.is_some() && !sess.common.is_tls13() { + // Only TLS1.2 sends confirmation in ServerHello + self.exts + .push(ServerExtension::CertificateStatusAck); } + } else { + // Throw away any OCSP response so we don't try to send it later. + ocsp_response.take(); + } - if !for_resume - && hello - .find_extension(ExtensionType::SCT) - .is_some() - { - if !sess.common.is_tls13() { - // Take the SCT list, if any, so we don't send it later, - // and put it in the legacy extension. - server_key - .take_sct_list() - .map(|sct_list| { - self.exts - .push(ServerExtension::make_sct(sct_list.to_vec())) - }); + if !for_resume + && hello + .find_extension(ExtensionType::SCT) + .is_some() + { + if !sess.common.is_tls13() { + // Take the SCT list, if any, so we don't send it later, + // and put it in the legacy extension. + if let Some(sct_list) = sct_list.take() { + self.exts + .push(ServerExtension::make_sct(sct_list.to_vec())); } - } else { - // Throw away any SCT list so we don't send it later. - drop(server_key.take_sct_list()); } + } else { + // Throw away any SCT list so we don't send it later. + sct_list.take(); } self.exts @@ -395,13 +395,14 @@ impl ExpectClientHello { fn emit_server_hello( &mut self, sess: &mut ServerSessionImpl, - server_key: Option<&mut sign::ActiveCertifiedKey>, + ocsp_response: &mut Option<&[u8]>, + sct_list: &mut Option<&[u8]>, hello: &ClientHelloPayload, resumedata: Option<&persist::ServerSessionValue>, randoms: &SessionRandoms, ) -> Result<(), TlsError> { let mut ep = ExtensionProcessing::new(); - ep.process_common(sess, server_key, hello, resumedata, &self.handshake)?; + ep.process_common(sess, ocsp_response, sct_list, hello, resumedata, &self.handshake)?; ep.process_tls12(sess, hello, self.using_ems); self.send_ticket = ep.send_ticket; @@ -433,10 +434,8 @@ impl ExpectClientHello { fn emit_certificate( &mut self, sess: &mut ServerSessionImpl, - server_certkey: &sign::ActiveCertifiedKey, + cert_chain: &[Certificate], ) { - let cert_chain = server_certkey.get_cert(); - let c = Message { typ: ContentType::Handshake, version: ProtocolVersion::TLSv1_2, @@ -455,13 +454,8 @@ impl ExpectClientHello { fn emit_cert_status( &mut self, sess: &mut ServerSessionImpl, - server_certkey: &mut sign::ActiveCertifiedKey, + ocsp: &[u8], ) { - let ocsp = match server_certkey.take_ocsp() { - Some(ocsp) => ocsp, - None => return, - }; - let st = CertificateStatus::new(ocsp.to_owned()); let c = Message { @@ -484,7 +478,7 @@ impl ExpectClientHello { sess: &mut ServerSessionImpl, sigschemes: Vec, skxg: &'static kx::SupportedKxGroup, - server_certkey: &sign::ActiveCertifiedKey, + signing_key: &Arc>, randoms: &SessionRandoms, ) -> Result { let kx = kx::KeyExchange::start(skxg) @@ -496,7 +490,6 @@ impl ExpectClientHello { msg.extend(&randoms.server); secdh.encode(&mut msg); - let signing_key = server_certkey.get_key(); let signer = signing_key .choose_scheme(&sigschemes) .ok_or_else(|| TlsError::General("incompatible signing key".to_string()))?; @@ -600,7 +593,7 @@ impl ExpectClientHello { } self.handshake.session_id = *id; - self.emit_server_hello(sess, None, client_hello, Some(&resumedata), randoms)?; + self.emit_server_hello(sess, &mut None, &mut None, client_hello, Some(&resumedata), randoms)?; let suite = sess.common.get_suite_assert(); let secrets = SessionSecrets::new_resume(&randoms, suite, &resumedata.master_secret.0); @@ -762,12 +755,11 @@ impl State for ExpectClientHello { TlsError::General("no server certificate chain resolved".to_string()) })? }; - let mut certkey = sign::ActiveCertifiedKey::from_certified_key(certkey.as_ref()); // Reduce our supported ciphersuites by the certificate. // (no-op for TLS1.3) let suitable_suites = - suites::reduce_given_sigalg(&sess.config.ciphersuites, certkey.get_key().algorithm()); + suites::reduce_given_sigalg(&sess.config.ciphersuites, certkey.key.algorithm()); // And version let suitable_suites = suites::reduce_given_version(&suitable_suites, version); @@ -814,7 +806,7 @@ impl State for ExpectClientHello { if sess.common.is_tls13() { return self .into_complete_tls13_client_hello_handling(randoms) - .handle_client_hello(ciphersuite, sess, certkey, &m); + .handle_client_hello(ciphersuite, sess, &certkey, &m); } // -- TLS1.2 only from hereon in -- @@ -944,10 +936,13 @@ impl State for ExpectClientHello { debug_assert_eq!(ecpoint, ECPointFormat::Uncompressed); - self.emit_server_hello(sess, Some(&mut certkey), client_hello, None, &randoms)?; - self.emit_certificate(sess, &certkey); - self.emit_cert_status(sess, &mut certkey); - let kx = self.emit_server_kx(sess, sigschemes, group, &certkey, &randoms)?; + let (mut ocsp_response, mut sct_list) = (certkey.ocsp.as_deref(), certkey.sct_list.as_deref()); + self.emit_server_hello(sess, &mut ocsp_response, &mut sct_list, client_hello, None, &randoms)?; + self.emit_certificate(sess, &certkey.cert); + if let Some(ocsp_response) = ocsp_response { + self.emit_cert_status(sess, ocsp_response); + } + let kx = self.emit_server_kx(sess, sigschemes, group, &certkey.key, &randoms)?; let doing_client_auth = self.emit_certificate_req(sess)?; self.emit_server_hello_done(sess); diff --git a/rustls/src/server/tls13.rs b/rustls/src/server/tls13.rs index 9cecd268..ed745b11 100644 --- a/rustls/src/server/tls13.rs +++ b/rustls/src/server/tls13.rs @@ -1,6 +1,7 @@ use crate::check::check_message; use crate::{cipher, SupportedCipherSuite}; use crate::error::TlsError; +use crate::key::Certificate; use crate::key_schedule::{ KeyScheduleEarly, KeyScheduleHandshake, KeyScheduleNonSecret, KeyScheduleTraffic, KeyScheduleTrafficWithClientFinishedPending, @@ -48,6 +49,8 @@ use crate::server::hs; use ring::constant_time; +use std::sync::Arc; + pub struct CompleteClientHelloHandling { pub handshake: HandshakeDetails, pub randoms: SessionRandoms, @@ -275,12 +278,13 @@ impl CompleteClientHelloHandling { fn emit_encrypted_extensions( &mut self, sess: &mut ServerSessionImpl, - server_key: &mut sign::ActiveCertifiedKey, + ocsp_response: &mut Option<&[u8]>, + sct_list: &mut Option<&[u8]>, hello: &ClientHelloPayload, resumedata: Option<&persist::ServerSessionValue>, ) -> Result<(), TlsError> { let mut ep = hs::ExtensionProcessing::new(); - ep.process_common(sess, Some(server_key), hello, resumedata, &self.handshake)?; + ep.process_common(sess, ocsp_response, sct_list, hello, resumedata, &self.handshake)?; let ee = Message { typ: ContentType::Handshake, @@ -355,10 +359,12 @@ impl CompleteClientHelloHandling { fn emit_certificate_tls13( &mut self, sess: &mut ServerSessionImpl, - server_key: &mut sign::ActiveCertifiedKey, + cert_chain: &[Certificate], + ocsp_response: Option<&[u8]>, + sct_list: Option<&[u8]>, ) { let mut cert_entries = vec![]; - for cert in server_key.get_cert() { + for cert in cert_chain { let entry = CertificateEntry { cert: cert.to_owned(), exts: Vec::new(), @@ -370,7 +376,7 @@ impl CompleteClientHelloHandling { if let Some(end_entity_cert) = cert_entries.first_mut() { // Apply OCSP response to first certificate (we don't support OCSP // except for leaf certs). - if let Some(ocsp) = server_key.take_ocsp() { + if let Some(ocsp) = ocsp_response { let cst = CertificateStatus::new(ocsp.to_owned()); end_entity_cert .exts @@ -378,7 +384,7 @@ impl CompleteClientHelloHandling { } // Likewise, SCT - if let Some(sct_list) = server_key.take_sct_list() { + if let Some(sct_list) = sct_list { end_entity_cert .exts .push(CertificateExtension::make_sct(sct_list.to_owned())); @@ -405,7 +411,7 @@ impl CompleteClientHelloHandling { fn emit_certificate_verify_tls13( &mut self, sess: &mut ServerSessionImpl, - server_key: &sign::ActiveCertifiedKey, + signing_key: &Arc>, schemes: &[SignatureScheme], ) -> Result<(), TlsError> { let message = verify::construct_tls13_server_verify_message( @@ -415,7 +421,6 @@ impl CompleteClientHelloHandling { .get_current_hash(), ); - let signing_key = server_key.get_key(); let signer = signing_key .choose_scheme(schemes) .ok_or_else(|| hs::incompatible(sess, "no overlapping sigschemes"))?; @@ -532,7 +537,7 @@ impl CompleteClientHelloHandling { mut self, suite: &'static SupportedCipherSuite, sess: &mut ServerSessionImpl, - mut server_key: sign::ActiveCertifiedKey, + server_key: &sign::CertifiedKey, chm: &Message, ) -> hs::NextStateOrError { let client_hello = require_handshake_msg!( @@ -669,12 +674,14 @@ impl CompleteClientHelloHandling { if !self.done_retry { self.emit_fake_ccs(sess); } - self.emit_encrypted_extensions(sess, &mut server_key, client_hello, resumedata.as_ref())?; + + let (mut ocsp_response, mut sct_list) = (server_key.ocsp.as_deref(), server_key.sct_list.as_deref()); + self.emit_encrypted_extensions(sess, &mut ocsp_response, &mut sct_list, client_hello, resumedata.as_ref())?; let doing_client_auth = if full_handshake { let client_auth = self.emit_certificate_req_tls13(sess)?; - self.emit_certificate_tls13(sess, &mut server_key); - self.emit_certificate_verify_tls13(sess, &mut server_key, &sigschemes_ext)?; + self.emit_certificate_tls13(sess, &server_key.cert, ocsp_response, sct_list); + self.emit_certificate_verify_tls13(sess, &server_key.key, &sigschemes_ext)?; client_auth } else { false diff --git a/rustls/src/sign.rs b/rustls/src/sign.rs index 983993e7..3a70e1e3 100644 --- a/rustls/src/sign.rs +++ b/rustls/src/sign.rs @@ -125,52 +125,6 @@ impl CertifiedKey { } } -/// ActiveCertifiedKey wraps CertifiedKey and prevents `ocsp` and `sct_list` from being -/// consumed more than once. -pub(crate) struct ActiveCertifiedKey<'a> { - key: &'a CertifiedKey, - ocsp: Option<&'a [u8]>, - sct_list: Option<&'a [u8]>, -} - -impl<'a> ActiveCertifiedKey<'a> { - pub fn from_certified_key<'k>(key: &'k CertifiedKey) -> ActiveCertifiedKey<'k> { - ActiveCertifiedKey { - key, - ocsp: key.ocsp.as_deref(), - sct_list: key.sct_list.as_deref(), - } - } - - /// Return true if there's an OCSP response. - #[inline] - pub fn has_ocsp(&self) -> bool { - self.ocsp.is_some() - } - - /// Get the certificate chain - #[inline] - pub fn get_cert(&self) -> &[key::Certificate] { - &self.key.cert - } - - /// Get the signing key - #[inline] - pub fn get_key(&self) -> &Arc> { - &self.key.key - } - - #[inline] - pub fn take_ocsp(&mut self) -> Option<&[u8]> { - self.ocsp.take() - } - - #[inline] - pub fn take_sct_list(&mut self) -> Option<&[u8]> { - self.sct_list.take() - } -} - /// Parse `der` as any supported key encoding/type, returning /// the first which works. pub fn any_supported_type(der: &key::PrivateKey) -> Result, ()> {