diff --git a/rustls/src/client/tls13.rs b/rustls/src/client/tls13.rs index 68c86056..9e6f1c15 100644 --- a/rustls/src/client/tls13.rs +++ b/rustls/src/client/tls13.rs @@ -889,7 +889,6 @@ impl State for ExpectFinished { suite: st.suite, transcript: st.transcript, key_schedule: key_schedule_traffic, - want_write_key_update: false, _cert_verified: st.cert_verified, _sig_verified: st.sig_verified, _fin_verified: fin, @@ -913,7 +912,6 @@ struct ExpectTraffic { suite: &'static Tls13CipherSuite, transcript: HandshakeHash, key_schedule: KeyScheduleTraffic, - want_write_key_update: bool, _cert_verified: verify::ServerCertVerified, _sig_verified: verify::HandshakeSignatureValid, _fin_verified: verify::FinishedMessageVerified, @@ -983,7 +981,7 @@ impl ExpectTraffic { fn handle_key_update( &mut self, common: &mut CommonState, - kur: &KeyUpdateRequest, + key_update_request: &KeyUpdateRequest, ) -> Result<(), Error> { #[cfg(feature = "quic")] { @@ -997,15 +995,9 @@ impl ExpectTraffic { // Mustn't be interleaved with other handshake messages. common.check_aligned_handshake()?; - match kur { - KeyUpdateRequest::UpdateNotRequested => {} - KeyUpdateRequest::UpdateRequested => { - self.want_write_key_update = true; - } - _ => { - common.send_fatal_alert(AlertDescription::IllegalParameter); - return Err(Error::CorruptMessagePayload(ContentType::Handshake)); - } + if common.should_update_key(key_update_request)? { + self.key_schedule + .update_encrypter_and_notify(common); } // Update our read-side keys. @@ -1059,14 +1051,6 @@ impl State for ExpectTraffic { .export_keying_material(output, label, context) } - fn perhaps_write_key_update(&mut self, common: &mut CommonState) { - if self.want_write_key_update { - self.want_write_key_update = false; - self.key_schedule - .update_encrypter_and_notify(common); - } - } - #[cfg(feature = "secret_extraction")] fn extract_secrets(&self) -> Result { self.key_schedule diff --git a/rustls/src/conn.rs b/rustls/src/conn.rs index 5ef5c2fd..a89ecfe8 100644 --- a/rustls/src/conn.rs +++ b/rustls/src/conn.rs @@ -6,8 +6,8 @@ use crate::log::{debug, error, trace, warn}; use crate::msgs::alert::AlertMessagePayload; use crate::msgs::base::Payload; use crate::msgs::deframer::{Deframed, MessageDeframer}; -use crate::msgs::enums::HandshakeType; use crate::msgs::enums::{AlertDescription, AlertLevel, ContentType}; +use crate::msgs::enums::{HandshakeType, KeyUpdateRequest}; use crate::msgs::fragmenter::MessageFragmenter; use crate::msgs::handshake::Random; use crate::msgs::message::{ @@ -694,9 +694,8 @@ impl ConnectionCommon { } pub(crate) fn send_some_plaintext(&mut self, buf: &[u8]) -> usize { - if let Ok(st) = &mut self.state { - st.perhaps_write_key_update(&mut self.common_state); - } + self.common_state + .perhaps_write_key_update(); self.common_state .send_some_plaintext(buf) } @@ -823,6 +822,8 @@ pub struct CommonState { received_plaintext: ChunkVecBuffer, sendable_plaintext: ChunkVecBuffer, pub(crate) sendable_tls: ChunkVecBuffer, + queued_key_update_message: Option>, + #[allow(dead_code)] // only read for QUIC /// Protocol whose key schedule should be used. Unused for TLS < 1.3. pub(crate) protocol: Protocol, @@ -853,6 +854,7 @@ impl CommonState { received_plaintext: ChunkVecBuffer::new(Some(DEFAULT_RECEIVED_PLAINTEXT_LIMIT)), sendable_plaintext: ChunkVecBuffer::new(Some(DEFAULT_BUFFER_LIMIT)), sendable_tls: ChunkVecBuffer::new(Some(DEFAULT_BUFFER_LIMIT)), + queued_key_update_message: None, protocol: Protocol::Tcp, #[cfg(feature = "quic")] @@ -1317,6 +1319,35 @@ impl CommonState { #[cfg(not(feature = "quic"))] false } + + pub(crate) fn should_update_key( + &mut self, + key_update_request: &KeyUpdateRequest, + ) -> Result { + match key_update_request { + KeyUpdateRequest::UpdateNotRequested => Ok(false), + KeyUpdateRequest::UpdateRequested => Ok(self.queued_key_update_message.is_none()), + _ => { + self.send_fatal_alert(AlertDescription::IllegalParameter); + Err(Error::CorruptMessagePayload(ContentType::Handshake)) + } + } + } + + pub(crate) fn enqueue_key_update_notification(&mut self) { + let message = PlainMessage::from(Message::build_key_update_notify()); + self.queued_key_update_message = Some( + self.record_layer + .encrypt_outgoing(message.borrow()) + .encode(), + ); + } + + fn perhaps_write_key_update(&mut self) { + if let Some(message) = self.queued_key_update_message.take() { + self.sendable_tls.append(message); + } + } } pub(crate) trait State: Send + Sync { @@ -1339,8 +1370,6 @@ pub(crate) trait State: Send + Sync { fn extract_secrets(&self) -> Result { Err(Error::HandshakeNotComplete) } - - fn perhaps_write_key_update(&mut self, _cx: &mut CommonState) {} } pub(crate) struct Context<'a, Data> { diff --git a/rustls/src/server/tls13.rs b/rustls/src/server/tls13.rs index bb63c3e5..2ab36cdb 100644 --- a/rustls/src/server/tls13.rs +++ b/rustls/src/server/tls13.rs @@ -1194,7 +1194,6 @@ impl State for ExpectFinished { Ok(Box::new(ExpectTraffic { key_schedule: key_schedule_traffic, - want_write_key_update: false, _fin_verified: fin, })) } @@ -1203,7 +1202,6 @@ impl State for ExpectFinished { // --- Process traffic --- struct ExpectTraffic { key_schedule: KeyScheduleTraffic, - want_write_key_update: bool, _fin_verified: verify::FinishedMessageVerified, } @@ -1211,7 +1209,7 @@ impl ExpectTraffic { fn handle_key_update( &mut self, common: &mut CommonState, - kur: &KeyUpdateRequest, + key_update_request: &KeyUpdateRequest, ) -> Result<(), Error> { #[cfg(feature = "quic")] { @@ -1224,15 +1222,9 @@ impl ExpectTraffic { common.check_aligned_handshake()?; - match kur { - KeyUpdateRequest::UpdateNotRequested => {} - KeyUpdateRequest::UpdateRequested => { - self.want_write_key_update = true; - } - _ => { - common.send_fatal_alert(AlertDescription::IllegalParameter); - return Err(Error::CorruptMessagePayload(ContentType::Handshake)); - } + if common.should_update_key(key_update_request)? { + self.key_schedule + .update_encrypter_and_notify(common); } // Update our read-side keys. @@ -1278,14 +1270,6 @@ impl State for ExpectTraffic { .export_keying_material(output, label, context) } - fn perhaps_write_key_update(&mut self, common: &mut CommonState) { - if self.want_write_key_update { - self.want_write_key_update = false; - self.key_schedule - .update_encrypter_and_notify(common); - } - } - #[cfg(feature = "secret_extraction")] fn extract_secrets(&self) -> Result { self.key_schedule diff --git a/rustls/src/tls13/key_schedule.rs b/rustls/src/tls13/key_schedule.rs index d761a894..b5e70d7c 100644 --- a/rustls/src/tls13/key_schedule.rs +++ b/rustls/src/tls13/key_schedule.rs @@ -2,7 +2,6 @@ use crate::cipher::{Iv, IvLen, MessageDecrypter}; use crate::conn::{CommonState, Side}; use crate::error::Error; use crate::msgs::base::PayloadU8; -use crate::msgs::message::Message; #[cfg(feature = "quic")] use crate::quic; #[cfg(feature = "secret_extraction")] @@ -464,7 +463,7 @@ impl KeyScheduleTraffic { pub(crate) fn update_encrypter_and_notify(&mut self, common: &mut CommonState) { let secret = self.next_application_traffic_secret(common.side); - common.send_msg_encrypt(Message::build_key_update_notify().into()); + common.enqueue_key_update_notification(); self.ks.set_encrypter(&secret, common); }