From 2c643313f0128f9d3256a842a29c5ad02d6f37a0 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Mon, 8 Mar 2021 12:55:26 +0100 Subject: [PATCH] Clarify availability of client auth state in HandshakeHash --- rustls/src/client/tls12.rs | 3 +-- rustls/src/hash_hs.rs | 39 ++++++++++++++++++-------------------- rustls/src/server/tls12.rs | 2 +- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/rustls/src/client/tls12.rs b/rustls/src/client/tls12.rs index a35caac4..d4cc1c19 100644 --- a/rustls/src/client/tls12.rs +++ b/rustls/src/client/tls12.rs @@ -270,8 +270,7 @@ fn emit_certverify( } }; - let message = transcript - .take_handshake_buf(); + let message = transcript.take_handshake_buf().unwrap(); let scheme = signer.get_scheme(); let sig = signer.sign(&message)?; let body = DigitallySignedStruct::new(scheme, sig); diff --git a/rustls/src/hash_hs.rs b/rustls/src/hash_hs.rs index 364616a5..80f4b155 100644 --- a/rustls/src/hash_hs.rs +++ b/rustls/src/hash_hs.rs @@ -59,8 +59,10 @@ impl HandshakeHashBuffer { HandshakeHash { ctx, - client_auth_enabled: self.client_auth_enabled, - buffer: self.buffer, + client_auth: match self.client_auth_enabled { + true => Some(self.buffer), + false => None, + } } } } @@ -77,18 +79,14 @@ pub struct HandshakeHash { ctx: digest::Context, /// true if we need to keep all messages - client_auth_enabled: bool, - - /// buffer for client-auth. - buffer: Vec, + client_auth: Option>, } impl HandshakeHash { /// We decided not to do client auth after all, so discard /// the transcript. pub fn abandon_client_auth(&mut self) { - self.client_auth_enabled = false; - self.buffer.drain(..); + self.client_auth = None; } /// Hash/buffer a handshake message. @@ -107,8 +105,8 @@ impl HandshakeHash { fn update_raw(&mut self, buf: &[u8]) -> &mut Self { self.ctx.update(buf); - if self.client_auth_enabled { - self.buffer.extend_from_slice(buf); + if let Some(buffer) = &mut self.client_auth { + buffer.extend_from_slice(buf); } self @@ -128,7 +126,7 @@ impl HandshakeHash { HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref()); HandshakeHashBuffer { - client_auth_enabled: self.client_auth_enabled, + client_auth_enabled: self.client_auth.is_some(), buffer: old_handshake_hash_msg.get_encoding(), } } @@ -155,9 +153,8 @@ impl HandshakeHash { /// Takes this object's buffer containing all handshake messages /// so far. This method only works once; it resets the buffer /// to empty. - pub fn take_handshake_buf(&mut self) -> Vec { - debug_assert!(self.client_auth_enabled); - mem::replace(&mut self.buffer, Vec::new()) + pub fn take_handshake_buf(&mut self) -> Option> { + self.client_auth.take() } /// The digest algorithm @@ -177,7 +174,7 @@ mod test { hhb.update_raw(b"hello"); assert_eq!(hhb.buffer.len(), 5); let mut hh = hhb.start_hash(&digest::SHA256); - assert_eq!(hh.buffer.len(), 0); + assert!(hh.client_auth.is_none()); hh.update_raw(b"world"); let h = hh.get_current_hash(); let h = h.as_ref(); @@ -194,9 +191,9 @@ mod test { hhb.update_raw(b"hello"); assert_eq!(hhb.buffer.len(), 5); let mut hh = hhb.start_hash(&digest::SHA256); - assert_eq!(hh.buffer.len(), 5); + assert_eq!(hh.client_auth.as_ref().map(|buf| buf.len()), Some(5)); hh.update_raw(b"world"); - assert_eq!(hh.buffer.len(), 10); + assert_eq!(hh.client_auth.as_ref().map(|buf| buf.len()), Some(10)); let h = hh.get_current_hash(); let h = h.as_ref(); assert_eq!(h[0], 0x93); @@ -204,7 +201,7 @@ mod test { assert_eq!(h[2], 0x18); assert_eq!(h[3], 0x5c); let buf = hh.take_handshake_buf(); - assert_eq!(b"helloworld".to_vec(), buf); + assert_eq!(Some(b"helloworld".to_vec()), buf); } #[test] @@ -214,11 +211,11 @@ mod test { hhb.update_raw(b"hello"); assert_eq!(hhb.buffer.len(), 5); let mut hh = hhb.start_hash(&digest::SHA256); - assert_eq!(hh.buffer.len(), 5); + assert_eq!(hh.client_auth.as_ref().map(|buf| buf.len()), Some(5)); hh.abandon_client_auth(); - assert_eq!(hh.buffer.len(), 0); + assert_eq!(hh.client_auth, None); hh.update_raw(b"world"); - assert_eq!(hh.buffer.len(), 0); + assert_eq!(hh.client_auth, None); let h = hh.get_current_hash(); let h = h.as_ref(); assert_eq!(h[0], 0x93); diff --git a/rustls/src/server/tls12.rs b/rustls/src/server/tls12.rs index 398352f8..d6c3b5f9 100644 --- a/rustls/src/server/tls12.rs +++ b/rustls/src/server/tls12.rs @@ -183,7 +183,7 @@ impl hs::State for ExpectCertificateVerify { HandshakeType::CertificateVerify, HandshakePayload::CertificateVerify )?; - let handshake_msgs = self.transcript.take_handshake_buf(); + let handshake_msgs = self.transcript.take_handshake_buf().unwrap(); let certs = &self.client_cert.cert_chain; sess.config