diff --git a/rustls/src/check.rs b/rustls/src/check.rs new file mode 100644 index 00000000..2789cf8b --- /dev/null +++ b/rustls/src/check.rs @@ -0,0 +1,75 @@ +use crate::msgs::enums::{ContentType, HandshakeType}; +use crate::msgs::message::{Message, MessagePayload}; +use crate::error::TLSError; +#[cfg(feature = "logging")] +use crate::log::warn; + +/// For a Message $m, and a HandshakePayload enum member $payload_type, +/// return Ok(payload) if $m is both a handshake message and one that +/// has the given $payload_type. If not, return Err(TLSError) quoting +/// $handshake_type as the expected handshake type. +macro_rules! require_handshake_msg( + ( $m:expr, $handshake_type:path, $payload_type:path ) => ( + match $m.payload { + MessagePayload::Handshake(ref hsp) => match hsp.payload { + $payload_type(ref hm) => Ok(hm), + _ => Err(TLSError::InappropriateHandshakeMessage { + expect_types: vec![ $handshake_type ], + got_type: hsp.typ}) + } + _ => Err(TLSError::InappropriateMessage { + expect_types: vec![ ContentType::Handshake ], + got_type: $m.typ}) + } + ) +); + +/// Like require_handshake_msg, but moves the payload out of $m. +macro_rules! require_handshake_msg_mut( + ( $m:expr, $handshake_type:path, $payload_type:path ) => ( + match $m.payload { + MessagePayload::Handshake(hsp) => match hsp.payload { + $payload_type(hm) => Ok(hm), + _ => Err(TLSError::InappropriateHandshakeMessage { + expect_types: vec![ $handshake_type ], + got_type: hsp.typ}) + } + _ => Err(TLSError::InappropriateMessage { + expect_types: vec![ ContentType::Handshake ], + got_type: $m.typ}) + } + ) +); + +/// Validate the message `m`: return an error if: +/// +/// - the type of m does not appear in `content_types`. +/// - if m is a handshake message, the handshake message type does +/// not appear in `handshake_types`. +pub fn check_message(m: &Message, + content_types: &[ContentType], + handshake_types: &[HandshakeType]) -> Result<(), TLSError> { + if !content_types.contains(&m.typ) { + warn!("Received a {:?} message while expecting {:?}", + m.typ, + content_types); + return Err(TLSError::InappropriateMessage { + expect_types: content_types.to_vec(), + got_type: m.typ, + }); + } + + if let MessagePayload::Handshake(ref hsp) = m.payload { + if !handshake_types.is_empty() && !handshake_types.contains(&hsp.typ) { + warn!("Received a {:?} handshake message while expecting {:?}", + hsp.typ, + handshake_types); + return Err(TLSError::InappropriateHandshakeMessage { + expect_types: handshake_types.to_vec(), + got_type: hsp.typ, + }); + } + } + + Ok(()) +} diff --git a/rustls/src/client/hs.rs b/rustls/src/client/hs.rs index ef28fa43..66e9cef7 100644 --- a/rustls/src/client/hs.rs +++ b/rustls/src/client/hs.rs @@ -22,8 +22,8 @@ use crate::rand; use crate::ticketer; #[cfg(feature = "logging")] use crate::log::{debug, trace}; +use crate::check::check_message; use crate::error::TLSError; -use crate::handshake::check_handshake_message; #[cfg(feature = "quic")] use crate::msgs::base::PayloadU16; @@ -33,36 +33,12 @@ use crate::client::{tls12, tls13}; use webpki; -macro_rules! extract_handshake( - ( $m:expr, $t:path ) => ( - match $m.payload { - MessagePayload::Handshake(ref hsp) => match hsp.payload { - $t(ref hm) => Some(hm), - _ => None - }, - _ => None - } - ) -); - -macro_rules! extract_handshake_mut( - ( $m:expr, $t:path ) => ( - match $m.payload { - MessagePayload::Handshake(hsp) => match hsp.payload { - $t(hm) => Some(hm), - _ => None - }, - _ => None - } - ) -); - -pub type CheckResult = Result<(), TLSError>; pub type NextState = Box; pub type NextStateOrError = Result; pub trait State { - fn check_message(&self, m: &Message) -> CheckResult; + /// Each handle() implementation consumes a whole TLS message, and returns + /// either an error or the next state. fn handle(self: Box, sess: &mut ClientSessionImpl, m: Message) -> NextStateOrError; fn export_keying_material(&self, @@ -441,12 +417,8 @@ impl ExpectServerHello { } impl State for ExpectServerHello { - fn check_message(&self, m: &Message) -> CheckResult { - check_handshake_message(m, &[HandshakeType::ServerHello]) - } - fn handle(mut self: Box, sess: &mut ClientSessionImpl, m: Message) -> NextStateOrError { - let server_hello = extract_handshake!(m, HandshakePayload::ServerHello).unwrap(); + let server_hello = require_handshake_msg!(m, HandshakeType::ServerHello, HandshakePayload::ServerHello)?; trace!("We got ServerHello {:#?}", server_hello); use crate::ProtocolVersion::{TLSv1_2, TLSv1_3}; @@ -540,10 +512,10 @@ impl State for ExpectServerHello { // For TLS1.3, start message encryption using // handshake_traffic_secret. if sess.common.is_tls13() { - tls13::validate_server_hello(sess, server_hello)?; + tls13::validate_server_hello(sess, &server_hello)?; let key_schedule = tls13::start_handshake_traffic(sess, self.early_key_schedule.take(), - server_hello, + &server_hello, &mut self.handshake, &mut self.hello)?; tls13::emit_fake_ccs(&mut self.handshake, sess); @@ -641,9 +613,7 @@ impl ExpectServerHelloOrHelloRetryRequest { } fn handle_hello_retry_request(mut self, sess: &mut ClientSessionImpl, m: Message) -> NextStateOrError { - check_handshake_message(&m, &[HandshakeType::HelloRetryRequest])?; - - let hrr = extract_handshake!(m, HandshakePayload::HelloRetryRequest).unwrap(); + let hrr = require_handshake_msg!(m, HandshakeType::HelloRetryRequest, HandshakePayload::HelloRetryRequest)?; trace!("Got HRR {:?}", hrr); check_aligned_handshake(sess)?; @@ -721,18 +691,15 @@ impl ExpectServerHelloOrHelloRetryRequest { Ok(emit_client_hello_for_retry(sess, self.0.handshake, self.0.hello, - Some(hrr))) + Some(&hrr))) } } impl State for ExpectServerHelloOrHelloRetryRequest { - fn check_message(&self, m: &Message) -> CheckResult { - check_handshake_message(m, - &[HandshakeType::ServerHello, - HandshakeType::HelloRetryRequest]) - } - fn handle(self: Box, sess: &mut ClientSessionImpl, m: Message) -> NextStateOrError { + check_message(&m, + &[ContentType::Handshake], + &[HandshakeType::ServerHello, HandshakeType::HelloRetryRequest])?; if m.is_handshake_type(HandshakeType::ServerHello) { self.into_expect_server_hello().handle(sess, m) } else { diff --git a/rustls/src/client/mod.rs b/rustls/src/client/mod.rs index 6c086deb..df7b60e5 100644 --- a/rustls/src/client/mod.rs +++ b/rustls/src/client/mod.rs @@ -371,7 +371,7 @@ pub struct ClientSessionImpl { pub alpn_protocol: Option>, pub common: SessionCommon, pub error: Option, - pub state: Option>, + pub state: Option, pub server_cert_chain: CertificatePayload, pub early_data: EarlyData, pub resumption_ciphersuite: Option<&'static SupportedCipherSuite>, @@ -495,13 +495,24 @@ impl ClientSessionImpl { Ok(()) } + fn reject_renegotiation_attempt(&mut self) -> Result<(), TLSError> { + self.common.send_warning_alert(AlertDescription::NoRenegotiation); + Ok(()) + } + fn queue_unexpected_alert(&mut self) { self.common.send_fatal_alert(AlertDescription::UnexpectedMessage); } - fn reject_renegotiation_attempt(&mut self) -> Result<(), TLSError> { - self.common.send_warning_alert(AlertDescription::NoRenegotiation); - Ok(()) + fn maybe_send_unexpected_alert(&mut self, rc: hs::NextStateOrError) -> hs::NextStateOrError { + match rc { + Err(TLSError::InappropriateMessage { .. }) | + Err(TLSError::InappropriateHandshakeMessage { .. }) => { + self.queue_unexpected_alert(); + } + _ => {} + }; + rc } /// Process `msg`. First, we get the current state. Then we ask what messages @@ -517,13 +528,9 @@ impl ClientSessionImpl { } let state = self.state.take().unwrap(); - state - .check_message(&msg) - .map_err(|err| { - self.queue_unexpected_alert(); - err - })?; - self.state = Some(state.handle(self, msg)?); + let maybe_next_state = state.handle(self, msg); + let next_state = self.maybe_send_unexpected_alert(maybe_next_state)?; + self.state = Some(next_state); Ok(()) } diff --git a/rustls/src/client/tls12.rs b/rustls/src/client/tls12.rs index 781fe454..99797a9b 100644 --- a/rustls/src/client/tls12.rs +++ b/rustls/src/client/tls12.rs @@ -17,7 +17,7 @@ use crate::ticketer; #[cfg(feature = "logging")] use crate::log::{debug, trace}; use crate::error::TLSError; -use crate::handshake::{check_message, check_handshake_message}; +use crate::check::check_message; use crate::client::common::{ServerCertDetails, ServerKXDetails, HandshakeDetails}; use crate::client::common::{ReceivedTicketDetails, ClientAuthDetails}; @@ -52,12 +52,8 @@ impl ExpectCertificate { } impl hs::State for ExpectCertificate { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_handshake_message(m, &[HandshakeType::Certificate]) - } - fn handle(mut self: Box, _sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { - let cert_chain = extract_handshake!(m, HandshakePayload::Certificate).unwrap(); + let cert_chain = require_handshake_msg!(m, HandshakeType::Certificate, HandshakePayload::Certificate)?; self.handshake.transcript.add_message(&m); self.server_cert.cert_chain = cert_chain.clone(); @@ -87,13 +83,9 @@ impl ExpectCertificateStatus { } impl hs::State for ExpectCertificateStatus { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_handshake_message(m, &[HandshakeType::CertificateStatus]) - } - fn handle(mut self: Box, _sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { self.handshake.transcript.add_message(&m); - let mut status = extract_handshake_mut!(m, HandshakePayload::CertificateStatus).unwrap(); + let mut status = require_handshake_msg_mut!(m, HandshakeType::CertificateStatus, HandshakePayload::CertificateStatus)?; self.server_cert.ocsp_response = status.take_ocsp_response(); debug!("Server stapled OCSP response is {:?}", self.server_cert.ocsp_response); @@ -126,13 +118,10 @@ impl ExpectCertificateStatusOrServerKX { } impl hs::State for ExpectCertificateStatusOrServerKX { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_handshake_message(m, - &[HandshakeType::ServerKeyExchange, - HandshakeType::CertificateStatus]) - } - fn handle(self: Box, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { + check_message(&m, + &[ContentType::Handshake], + &[HandshakeType::ServerKeyExchange, HandshakeType::CertificateStatus])?; if m.is_handshake_type(HandshakeType::ServerKeyExchange) { self.into_expect_server_kx().handle(sess, m) } else { @@ -159,12 +148,8 @@ impl ExpectServerKX { } impl hs::State for ExpectServerKX { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_handshake_message(m, &[HandshakeType::ServerKeyExchange]) - } - fn handle(mut self: Box, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { - let opaque_kx = extract_handshake!(m, HandshakePayload::ServerKeyExchange).unwrap(); + let opaque_kx = require_handshake_msg!(m, HandshakeType::ServerKeyExchange, HandshakePayload::ServerKeyExchange)?; let maybe_decoded_kx = opaque_kx.unwrap_given_kxa(&sess.common.get_suite_assert().kx); self.handshake.transcript.add_message(&m); @@ -312,12 +297,8 @@ impl ExpectCertificateRequest { } impl hs::State for ExpectCertificateRequest { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_handshake_message(m, &[HandshakeType::CertificateRequest]) - } - fn handle(mut self: Box, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { - let certreq = extract_handshake!(m, HandshakePayload::CertificateRequest).unwrap(); + let certreq = require_handshake_msg!(m, HandshakeType::CertificateRequest, HandshakePayload::CertificateRequest)?; self.handshake.transcript.add_message(&m); debug!("Got CertificateRequest {:?}", certreq); @@ -381,14 +362,8 @@ impl ExpectServerDoneOrCertReq { } impl hs::State for ExpectServerDoneOrCertReq { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_handshake_message(m, - &[HandshakeType::CertificateRequest, - HandshakeType::ServerHelloDone]) - } - fn handle(mut self: Box, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { - if extract_handshake!(m, HandshakePayload::CertificateRequest).is_some() { + if require_handshake_msg!(m, HandshakeType::CertificateRequest, HandshakePayload::CertificateRequest).is_ok() { self.into_expect_certificate_req().handle(sess, m) } else { self.handshake.transcript.abandon_client_auth(); @@ -436,12 +411,9 @@ impl ExpectServerDone { } impl hs::State for ExpectServerDone { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_handshake_message(m, &[HandshakeType::ServerHelloDone]) - } - fn handle(self: Box, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { let mut st = *self; + check_message(&m, &[ContentType::Handshake], &[HandshakeType::ServerHelloDone])?; st.handshake.transcript.add_message(&m); hs::check_aligned_handshake(sess)?; @@ -594,11 +566,8 @@ impl ExpectCCS { } impl hs::State for ExpectCCS { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_message(m, &[ContentType::ChangeCipherSpec], &[]) - } - - fn handle(self: Box, sess: &mut ClientSessionImpl, _m: Message) -> hs::NextStateOrError { + fn handle(self: Box, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { + check_message(&m, &[ContentType::ChangeCipherSpec], &[])?; // CCS should not be received interleaved with fragmented handshake-level // message. hs::check_aligned_handshake(sess)?; @@ -634,14 +603,10 @@ impl ExpectNewTicket { } impl hs::State for ExpectNewTicket { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_handshake_message(m, &[HandshakeType::NewSessionTicket]) - } - fn handle(mut self: Box, _sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { self.handshake.transcript.add_message(&m); - let nst = extract_handshake_mut!(m, HandshakePayload::NewSessionTicket).unwrap(); + let nst = require_handshake_msg_mut!(m, HandshakeType::NewSessionTicket, HandshakePayload::NewSessionTicket)?; let recvd = ReceivedTicketDetails::from(nst.ticket.0, nst.lifetime_hint); Ok(self.into_expect_ccs(recvd)) } @@ -712,13 +677,9 @@ impl ExpectFinished { } impl hs::State for ExpectFinished { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_handshake_message(m, &[HandshakeType::Finished]) - } - fn handle(self: Box, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { let mut st = *self; - let finished = extract_handshake!(m, HandshakePayload::Finished).unwrap(); + let finished = require_handshake_msg!(m, HandshakeType::Finished, HandshakePayload::Finished)?; hs::check_aligned_handshake(sess)?; @@ -766,11 +727,8 @@ struct ExpectTraffic { } impl hs::State for ExpectTraffic { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_message(m, &[ContentType::ApplicationData], &[]) - } - fn handle(self: Box, sess: &mut ClientSessionImpl, mut m: Message) -> hs::NextStateOrError { + check_message(&m, &[ContentType::ApplicationData], &[])?; sess.common.take_received_plaintext(m.take_opaque_payload().unwrap()); Ok(self) } diff --git a/rustls/src/client/tls13.rs b/rustls/src/client/tls13.rs index c2de45d7..db9dfc01 100644 --- a/rustls/src/client/tls13.rs +++ b/rustls/src/client/tls13.rs @@ -10,6 +10,7 @@ use crate::msgs::handshake::EncryptedExtensions; use crate::msgs::handshake::{CertificatePayloadTLS13, CertificateEntry}; use crate::msgs::handshake::{PresharedKeyIdentity, PresharedKeyOffer}; use crate::msgs::handshake::DigitallySignedStruct; +use crate::msgs::handshake::NewSessionTicketPayloadTLS13; use crate::msgs::ccs::ChangeCipherSpecPayload; use crate::msgs::codec::Codec; use crate::msgs::persist; @@ -29,7 +30,7 @@ use crate::ticketer; #[cfg(feature = "logging")] use crate::log::{debug, warn}; use crate::error::TLSError; -use crate::handshake::{check_message, check_handshake_message}; +use crate::check::check_message; #[cfg(feature = "quic")] use crate::{ quic, @@ -377,16 +378,12 @@ impl ExpectEncryptedExtensions { } impl hs::State for ExpectEncryptedExtensions { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_handshake_message(m, &[HandshakeType::EncryptedExtensions]) - } - fn handle(mut self: Box, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { - let exts = extract_handshake!(m, HandshakePayload::EncryptedExtensions).unwrap(); + let exts = require_handshake_msg!(m, HandshakeType::EncryptedExtensions, HandshakePayload::EncryptedExtensions)?; debug!("TLS1.3 encrypted extensions: {:?}", exts); self.handshake.transcript.add_message(&m); - validate_encrypted_extensions(sess, &self.hello, exts)?; + validate_encrypted_extensions(sess, &self.hello, &exts)?; hs::process_alpn_protocol(sess, exts.get_alpn_protocol())?; #[cfg(feature = "quic")] { @@ -451,12 +448,8 @@ impl ExpectCertificate { } impl hs::State for ExpectCertificate { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_handshake_message(m, &[HandshakeType::Certificate]) - } - fn handle(mut self: Box, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { - let cert_chain = extract_handshake!(m, HandshakePayload::CertificateTLS13).unwrap(); + let cert_chain = require_handshake_msg!(m, HandshakeType::Certificate, HandshakePayload::CertificateTLS13)?; self.handshake.transcript.add_message(&m); // This is only non-empty for client auth. @@ -519,13 +512,10 @@ impl ExpectCertificateOrCertReq { } impl hs::State for ExpectCertificateOrCertReq { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_handshake_message(m, - &[HandshakeType::Certificate, - HandshakeType::CertificateRequest]) - } - fn handle(self: Box, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { + check_message(&m, + &[ContentType::Handshake], + &[HandshakeType::Certificate, HandshakeType::CertificateRequest])?; if m.is_handshake_type(HandshakeType::Certificate) { self.into_expect_certificate().handle(sess, m) } else { @@ -573,12 +563,8 @@ fn send_cert_error_alert(sess: &mut ClientSessionImpl, err: TLSError) -> TLSErro } impl hs::State for ExpectCertificateVerify { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_handshake_message(m, &[HandshakeType::CertificateVerify]) - } - fn handle(mut self: Box, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { - let cert_verify = extract_handshake!(m, HandshakePayload::CertificateVerify).unwrap(); + let cert_verify = require_handshake_msg!(m, HandshakeType::CertificateVerify, HandshakePayload::CertificateVerify)?; debug!("Server cert is {:?}", self.server_cert.cert_chain); @@ -601,7 +587,7 @@ impl hs::State for ExpectCertificateVerify { .get_verifier() .verify_tls13_signature(&verify::construct_tls13_server_verify_message(&handshake_hash), &self.server_cert.cert_chain[0], - cert_verify) + &cert_verify) .map_err(|err| send_cert_error_alert(sess, err))?; // 3. Verify any included SCTs. @@ -642,12 +628,8 @@ impl ExpectCertificateRequest { } impl hs::State for ExpectCertificateRequest { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_handshake_message(m, &[HandshakeType::CertificateRequest]) - } - fn handle(mut self: Box, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { - let certreq = &extract_handshake!(m, HandshakePayload::CertificateRequestTLS13).unwrap(); + let certreq = &require_handshake_msg!(m, HandshakeType::CertificateRequest, HandshakePayload::CertificateRequestTLS13)?; self.handshake.transcript.add_message(&m); debug!("Got CertificateRequest {:?}", certreq); @@ -823,13 +805,9 @@ impl ExpectFinished { } impl hs::State for ExpectFinished { - fn check_message(&self, m: &Message) -> hs::CheckResult { - check_handshake_message(m, &[HandshakeType::Finished]) - } - fn handle(self: Box, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { let mut st = *self; - let finished = extract_handshake!(m, HandshakePayload::Finished).unwrap(); + let finished = require_handshake_msg!(m, HandshakeType::Finished, HandshakePayload::Finished)?; let handshake_hash = st.handshake.transcript.get_current_hash(); let expect_verify_data = st.key_schedule.sign_server_finish(&handshake_hash); @@ -942,8 +920,7 @@ struct ExpectTraffic { } impl ExpectTraffic { - fn handle_new_ticket_tls13(&mut self, sess: &mut ClientSessionImpl, m: Message) -> Result<(), TLSError> { - let nst = extract_handshake!(m, HandshakePayload::NewSessionTicketTLS13).unwrap(); + fn handle_new_ticket_tls13(&mut self, sess: &mut ClientSessionImpl, nst: &NewSessionTicketPayloadTLS13) -> Result<(), TLSError> { let handshake_hash = self.handshake.transcript.get_current_hash(); let secret = self.key_schedule .resumption_master_secret_and_derive_ticket_psk(&handshake_hash, &nst.nonce.0); @@ -989,9 +966,7 @@ impl ExpectTraffic { Ok(()) } - fn handle_key_update(&mut self, sess: &mut ClientSessionImpl, m: Message) -> Result<(), TLSError> { - let kur = extract_handshake!(m, HandshakePayload::KeyUpdate).unwrap(); - + fn handle_key_update(&mut self, sess: &mut ClientSessionImpl, kur: &KeyUpdateRequest) -> Result<(), TLSError> { #[cfg(feature = "quic")] { if let Protocol::Quic = sess.common.protocol { @@ -1026,19 +1001,17 @@ impl ExpectTraffic { } impl hs::State for ExpectTraffic { - fn check_message(&self, m: &Message) -> Result<(), TLSError> { - check_message(m, - &[ContentType::ApplicationData, ContentType::Handshake], - &[HandshakeType::NewSessionTicket, HandshakeType::KeyUpdate]) - } - fn handle(mut self: Box, sess: &mut ClientSessionImpl, mut m: Message) -> hs::NextStateOrError { if m.is_content_type(ContentType::ApplicationData) { sess.common.take_received_plaintext(m.take_opaque_payload().unwrap()); - } else if m.is_handshake_type(HandshakeType::NewSessionTicket) { - self.handle_new_ticket_tls13(sess, m)?; - } else if m.is_handshake_type(HandshakeType::KeyUpdate) { - self.handle_key_update(sess, m)?; + } else if let Ok(ref new_ticket) = require_handshake_msg!(m, HandshakeType::NewSessionTicket, HandshakePayload::NewSessionTicketTLS13) { + self.handle_new_ticket_tls13(sess, new_ticket)?; + } else if let Ok(ref key_update) = require_handshake_msg!(m, HandshakeType::KeyUpdate, HandshakePayload::KeyUpdate) { + self.handle_key_update(sess, key_update)?; + } else { + check_message(&m, + &[ContentType::ApplicationData, ContentType::Handshake], + &[HandshakeType::NewSessionTicket, HandshakeType::KeyUpdate])?; } Ok(self) @@ -1068,12 +1041,9 @@ pub struct ExpectQUICTraffic(ExpectTraffic); #[cfg(feature = "quic")] impl hs::State for ExpectQUICTraffic { - fn check_message(&self, m: &Message) -> hs::CheckResult { - check_message(m, &[ContentType::Handshake], &[HandshakeType::NewSessionTicket]) - } - fn handle(mut self: Box, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError { - self.0.handle_new_ticket_tls13(sess, m)?; + let nst = require_handshake_msg!(m, HandshakeType::NewSessionTicket, HandshakePayload::NewSessionTicketTLS13)?; + self.0.handle_new_ticket_tls13(sess, nst)?; Ok(self) } } diff --git a/rustls/src/handshake.rs b/rustls/src/handshake.rs deleted file mode 100644 index 47f0a87a..00000000 --- a/rustls/src/handshake.rs +++ /dev/null @@ -1,40 +0,0 @@ -use crate::msgs::enums::{ContentType, HandshakeType}; -use crate::msgs::message::{Message, MessagePayload}; -use crate::error::TLSError; -#[cfg(feature = "logging")] -use crate::log::warn; - -pub fn check_handshake_message(m: &Message, - handshake_types: &[HandshakeType]) -> Result<(), TLSError> { - check_message(m, - &[ContentType::Handshake], - handshake_types) -} - -pub fn check_message(m: &Message, - content_types: &[ContentType], - handshake_types: &[HandshakeType]) -> Result<(), TLSError> { - if !content_types.contains(&m.typ) { - warn!("Received a {:?} message while expecting {:?}", - m.typ, - content_types); - return Err(TLSError::InappropriateMessage { - expect_types: content_types.to_vec(), - got_type: m.typ, - }); - } - - if let MessagePayload::Handshake(ref hsp) = m.payload { - if !handshake_types.is_empty() && !handshake_types.contains(&hsp.typ) { - warn!("Received a {:?} handshake message while expecting {:?}", - hsp.typ, - handshake_types); - return Err(TLSError::InappropriateHandshakeMessage { - expect_types: handshake_types.to_vec(), - got_type: hsp.typ, - }); - } - } - - Ok(()) -} diff --git a/rustls/src/hash_hs.rs b/rustls/src/hash_hs.rs index 409844d6..429796b3 100644 --- a/rustls/src/hash_hs.rs +++ b/rustls/src/hash_hs.rs @@ -86,7 +86,7 @@ impl HandshakeHash { let buf = hs.get_encoding(); self.update_raw(&buf); } - _ => unreachable!(), + _ => {}, }; self } diff --git a/rustls/src/lib.rs b/rustls/src/lib.rs index 6168ca7f..8f2d7d37 100644 --- a/rustls/src/lib.rs +++ b/rustls/src/lib.rs @@ -234,7 +234,8 @@ mod anchors; mod verify; #[cfg(test)] mod verifybench; -mod handshake; +#[macro_use] +mod check; mod suites; mod ticketer; mod server; diff --git a/rustls/src/server/hs.rs b/rustls/src/server/hs.rs index c4854bf9..b5d8b23f 100644 --- a/rustls/src/server/hs.rs +++ b/rustls/src/server/hs.rs @@ -23,7 +23,6 @@ use crate::sign; #[cfg(feature = "logging")] use crate::log::{trace, debug}; use crate::error::TLSError; -use crate::handshake::check_handshake_message; use webpki; #[cfg(feature = "quic")] use crate::session::Protocol; @@ -31,24 +30,10 @@ use crate::session::Protocol; use crate::server::common::{HandshakeDetails, ServerKXDetails}; use crate::server::{tls12, tls13}; -macro_rules! extract_handshake( - ( $m:expr, $t:path ) => ( - match $m.payload { - MessagePayload::Handshake(ref hsp) => match hsp.payload { - $t(ref hm) => Some(hm), - _ => None - }, - _ => None - } - ) -); - -pub type CheckResult = Result<(), TLSError>; pub type NextState = Box; pub type NextStateOrError = Result; pub trait State { - fn check_message(&self, m: &Message) -> CheckResult; fn handle(self: Box, sess: &mut ServerSessionImpl, m: Message) -> NextStateOrError; fn export_keying_material(&self, @@ -558,12 +543,8 @@ impl ExpectClientHello { } impl State for ExpectClientHello { - fn check_message(&self, m: &Message) -> CheckResult { - check_handshake_message(m, &[HandshakeType::ClientHello]) - } - fn handle(mut self: Box, sess: &mut ServerSessionImpl, m: Message) -> NextStateOrError { - let client_hello = extract_handshake!(m, HandshakePayload::ClientHello).unwrap(); + let client_hello = require_handshake_msg!(m, HandshakeType::ClientHello, HandshakePayload::ClientHello)?; let tls13_enabled = sess.config.supports_version(ProtocolVersion::TLSv1_3); let tls12_enabled = sess.config.supports_version(ProtocolVersion::TLSv1_2); trace!("we got a clienthello {:?}", client_hello); diff --git a/rustls/src/server/mod.rs b/rustls/src/server/mod.rs index 8c6906e9..6a4c60bf 100644 --- a/rustls/src/server/mod.rs +++ b/rustls/src/server/mod.rs @@ -405,6 +405,17 @@ impl ServerSessionImpl { self.common.send_fatal_alert(AlertDescription::UnexpectedMessage); } + fn maybe_send_unexpected_alert(&mut self, rc: hs::NextStateOrError) -> hs::NextStateOrError { + match rc { + Err(TLSError::InappropriateMessage { .. }) | + Err(TLSError::InappropriateHandshakeMessage { .. }) => { + self.queue_unexpected_alert(); + } + _ => {} + }; + rc + } + pub fn process_main_protocol(&mut self, msg: Message) -> Result<(), TLSError> { if self.common.traffic && !self.common.is_tls13() && msg.is_handshake_type(HandshakeType::ClientHello) { @@ -412,11 +423,10 @@ impl ServerSessionImpl { return Ok(()); } - let st = self.state.take().unwrap(); - st.check_message(&msg) - .map_err(|err| { self.queue_unexpected_alert(); err })?; - - self.state = Some(st.handle(self, msg)?); + let state = self.state.take().unwrap(); + let maybe_next_state = state.handle(self, msg); + let next_state = self.maybe_send_unexpected_alert(maybe_next_state)?; + self.state = Some(next_state); Ok(()) } diff --git a/rustls/src/server/tls12.rs b/rustls/src/server/tls12.rs index 65b3f794..cc65b884 100644 --- a/rustls/src/server/tls12.rs +++ b/rustls/src/server/tls12.rs @@ -14,7 +14,7 @@ use crate::verify; #[cfg(feature = "logging")] use crate::log::{trace, debug}; use crate::error::TLSError; -use crate::handshake::{check_handshake_message, check_message}; +use crate::check::check_message; use crate::server::common::{HandshakeDetails, ServerKXDetails, ClientCertDetails}; use crate::server::hs; @@ -40,12 +40,8 @@ impl ExpectCertificate { } impl hs::State for ExpectCertificate { - fn check_message(&self, m: &Message) -> hs::CheckResult { - check_handshake_message(m, &[HandshakeType::Certificate]) - } - fn handle(mut self: Box, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError { - let cert_chain = extract_handshake!(m, HandshakePayload::Certificate).unwrap(); + let cert_chain = require_handshake_msg!(m, HandshakeType::Certificate, HandshakePayload::Certificate)?; self.handshake.transcript.add_message(&m); // If we can't determine if the auth is mandatory, abort @@ -108,12 +104,8 @@ impl ExpectClientKX { } impl hs::State for ExpectClientKX { - fn check_message(&self, m: &Message) -> hs::CheckResult { - check_handshake_message(m, &[HandshakeType::ClientKeyExchange]) - } - fn handle(mut self: Box, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError { - let client_kx = extract_handshake!(m, HandshakePayload::ClientKeyExchange).unwrap(); + let client_kx = require_handshake_msg!(m, HandshakeType::ClientKeyExchange, HandshakePayload::ClientKeyExchange)?; self.handshake.transcript.add_message(&m); // Complete key agreement, and set up encryption with the @@ -173,13 +165,9 @@ impl ExpectCertificateVerify { } impl hs::State for ExpectCertificateVerify { - fn check_message(&self, m: &Message) -> hs::CheckResult { - check_handshake_message(m, &[HandshakeType::CertificateVerify]) - } - fn handle(mut self: Box, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError { let rc = { - let sig = extract_handshake!(m, HandshakePayload::CertificateVerify).unwrap(); + let sig = require_handshake_msg!(m, HandshakeType::CertificateVerify, HandshakePayload::CertificateVerify)?; let handshake_msgs = self.handshake.transcript.take_handshake_buf(); let certs = &self.client_cert.cert_chain; @@ -221,11 +209,9 @@ impl ExpectCCS { } impl hs::State for ExpectCCS { - fn check_message(&self, m: &Message) -> hs::CheckResult { - check_message(m, &[ContentType::ChangeCipherSpec], &[]) - } + fn handle(self: Box, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError { + check_message(&m, &[ContentType::ChangeCipherSpec], &[])?; - fn handle(self: Box, sess: &mut ServerSessionImpl, _m: Message) -> hs::NextStateOrError { // CCS should not be received interleaved with fragmented handshake-level // message. hs::check_aligned_handshake(sess)?; @@ -335,12 +321,8 @@ impl ExpectFinished { } impl hs::State for ExpectFinished { - fn check_message(&self, m: &Message) -> hs::CheckResult { - check_handshake_message(m, &[HandshakeType::Finished]) - } - fn handle(mut self: Box, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError { - let finished = extract_handshake!(m, HandshakePayload::Finished).unwrap(); + let finished = require_handshake_msg!(m, HandshakeType::Finished, HandshakePayload::Finished)?; hs::check_aligned_handshake(sess)?; @@ -397,11 +379,8 @@ impl ExpectTraffic { } impl hs::State for ExpectTraffic { - fn check_message(&self, m: &Message) -> hs::CheckResult { - check_message(m, &[ContentType::ApplicationData], &[]) - } - fn handle(self: Box, sess: &mut ServerSessionImpl, mut m: Message) -> hs::NextStateOrError { + check_message(&m, &[ContentType::ApplicationData], &[])?; sess.common.take_received_plaintext(m.take_opaque_payload().unwrap()); Ok(self) } diff --git a/rustls/src/server/tls13.rs b/rustls/src/server/tls13.rs index 6e912d6e..6aeae320 100644 --- a/rustls/src/server/tls13.rs +++ b/rustls/src/server/tls13.rs @@ -41,7 +41,7 @@ use crate::suites; #[cfg(feature = "logging")] use crate::log::{warn, trace, debug}; use crate::error::TLSError; -use crate::handshake::{check_handshake_message, check_message}; +use crate::check::check_message; #[cfg(feature = "quic")] use crate::{ quic, @@ -467,7 +467,7 @@ impl CompleteClientHelloHandling { sni: Option, mut server_key: sign::CertifiedKey, chm: &Message) -> hs::NextStateOrError { - let client_hello = extract_handshake!(chm, HandshakePayload::ClientHello).unwrap(); + let client_hello = require_handshake_msg!(chm, HandshakeType::ClientHello, HandshakePayload::ClientHello)?; if client_hello.compression_methods.len() != 1 { return Err(hs::illegal_param(sess, "client offered wrong compressions")); @@ -636,12 +636,8 @@ impl ExpectCertificate { } impl hs::State for ExpectCertificate { - fn check_message(&self, m: &Message) -> hs::CheckResult { - check_handshake_message(m, &[HandshakeType::Certificate]) - } - fn handle(mut self: Box, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError { - let certp = extract_handshake!(m, HandshakePayload::CertificateTLS13).unwrap(); + let certp = require_handshake_msg!(m, HandshakeType::Certificate, HandshakePayload::CertificateTLS13)?; self.handshake.transcript.add_message(&m); // We don't send any CertificateRequest extensions, so any extensions @@ -700,13 +696,9 @@ impl ExpectCertificateVerify { } impl hs::State for ExpectCertificateVerify { - fn check_message(&self, m: &Message) -> hs::CheckResult { - check_handshake_message(m, &[HandshakeType::CertificateVerify]) - } - fn handle(mut self: Box, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError { let rc = { - let sig = extract_handshake!(m, HandshakePayload::CertificateVerify).unwrap(); + let sig = require_handshake_msg!(m, HandshakeType::CertificateVerify, HandshakePayload::CertificateVerify)?; let handshake_hash = self.handshake.transcript.get_current_hash(); self.handshake.transcript.abandon_client_auth(); let certs = &self.client_cert.cert_chain; @@ -849,12 +841,8 @@ impl ExpectFinished { } impl hs::State for ExpectFinished { - fn check_message(&self, m: &Message) -> hs::CheckResult { - check_handshake_message(m, &[HandshakeType::Finished]) - } - fn handle(mut self: Box, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError { - let finished = extract_handshake!(m, HandshakePayload::Finished).unwrap(); + let finished = require_handshake_msg!(m, HandshakeType::Finished, HandshakePayload::Finished)?; let handshake_hash = self.handshake.transcript.get_current_hash(); let expect_verify_data = self.key_schedule.sign_client_finish(&handshake_hash); @@ -919,9 +907,7 @@ impl ExpectTraffic { Ok(()) } - fn handle_key_update(&mut self, sess: &mut ServerSessionImpl, m: Message) -> Result<(), TLSError> { - let kur = extract_handshake!(m, HandshakePayload::KeyUpdate).unwrap(); - + fn handle_key_update(&mut self, sess: &mut ServerSessionImpl, kur: &KeyUpdateRequest) -> Result<(), TLSError> { #[cfg(feature = "quic")] { if let Protocol::Quic = sess.common.protocol { @@ -955,17 +941,15 @@ impl ExpectTraffic { } impl hs::State for ExpectTraffic { - fn check_message(&self, m: &Message) -> hs::CheckResult { - check_message(m, - &[ContentType::ApplicationData, ContentType::Handshake], - &[HandshakeType::KeyUpdate]) - } - fn handle(mut self: Box, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError { if m.is_content_type(ContentType::ApplicationData) { self.handle_traffic(sess, m)?; - } else if m.is_handshake_type(HandshakeType::KeyUpdate) { - self.handle_key_update(sess, m)?; + } else if let Ok(key_update) = require_handshake_msg!(m, HandshakeType::KeyUpdate, HandshakePayload::KeyUpdate) { + self.handle_key_update(sess, key_update)?; + } else { + check_message(&m, + &[ContentType::ApplicationData, ContentType::Handshake], + &[HandshakeType::KeyUpdate])?; } Ok(self) @@ -997,14 +981,9 @@ pub struct ExpectQUICTraffic { #[cfg(feature = "quic")] impl hs::State for ExpectQUICTraffic { - fn check_message(&self, m: &Message) -> hs::CheckResult { - Err(TLSError::InappropriateMessage { - expect_types: Vec::new(), - got_type: m.typ, - }) - } - - fn handle(self: Box, _: &mut ServerSessionImpl, _: Message) -> hs::NextStateOrError { - unreachable!("check_message always fails"); + fn handle(self: Box, _: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError { + // reject all messages + check_message(&m, &[], &[])?; + unreachable!(); } }