Refactor state machine message checking

Instead of having check_message called separately, do all
checking inside the state transition functions.

This means certain errors need to be detected to get
the right alert behaviour.  But it dramatically
decreases the number of .unwrap()s and makes things
simpler.
This commit is contained in:
Joseph Birr-Pixton 2020-06-07 11:34:29 +01:00
parent 46c259bd8e
commit 39175e7252
12 changed files with 186 additions and 299 deletions

75
rustls/src/check.rs Normal file
View File

@ -0,0 +1,75 @@
use crate::msgs::enums::{ContentType, HandshakeType};
use crate::msgs::message::{Message, MessagePayload};
use crate::error::TLSError;
#[cfg(feature = "logging")]
use crate::log::warn;
/// For a Message $m, and a HandshakePayload enum member $payload_type,
/// return Ok(payload) if $m is both a handshake message and one that
/// has the given $payload_type. If not, return Err(TLSError) quoting
/// $handshake_type as the expected handshake type.
macro_rules! require_handshake_msg(
( $m:expr, $handshake_type:path, $payload_type:path ) => (
match $m.payload {
MessagePayload::Handshake(ref hsp) => match hsp.payload {
$payload_type(ref hm) => Ok(hm),
_ => Err(TLSError::InappropriateHandshakeMessage {
expect_types: vec![ $handshake_type ],
got_type: hsp.typ})
}
_ => Err(TLSError::InappropriateMessage {
expect_types: vec![ ContentType::Handshake ],
got_type: $m.typ})
}
)
);
/// Like require_handshake_msg, but moves the payload out of $m.
macro_rules! require_handshake_msg_mut(
( $m:expr, $handshake_type:path, $payload_type:path ) => (
match $m.payload {
MessagePayload::Handshake(hsp) => match hsp.payload {
$payload_type(hm) => Ok(hm),
_ => Err(TLSError::InappropriateHandshakeMessage {
expect_types: vec![ $handshake_type ],
got_type: hsp.typ})
}
_ => Err(TLSError::InappropriateMessage {
expect_types: vec![ ContentType::Handshake ],
got_type: $m.typ})
}
)
);
/// Validate the message `m`: return an error if:
///
/// - the type of m does not appear in `content_types`.
/// - if m is a handshake message, the handshake message type does
/// not appear in `handshake_types`.
pub fn check_message(m: &Message,
content_types: &[ContentType],
handshake_types: &[HandshakeType]) -> Result<(), TLSError> {
if !content_types.contains(&m.typ) {
warn!("Received a {:?} message while expecting {:?}",
m.typ,
content_types);
return Err(TLSError::InappropriateMessage {
expect_types: content_types.to_vec(),
got_type: m.typ,
});
}
if let MessagePayload::Handshake(ref hsp) = m.payload {
if !handshake_types.is_empty() && !handshake_types.contains(&hsp.typ) {
warn!("Received a {:?} handshake message while expecting {:?}",
hsp.typ,
handshake_types);
return Err(TLSError::InappropriateHandshakeMessage {
expect_types: handshake_types.to_vec(),
got_type: hsp.typ,
});
}
}
Ok(())
}

View File

@ -22,8 +22,8 @@ use crate::rand;
use crate::ticketer;
#[cfg(feature = "logging")]
use crate::log::{debug, trace};
use crate::check::check_message;
use crate::error::TLSError;
use crate::handshake::check_handshake_message;
#[cfg(feature = "quic")]
use crate::msgs::base::PayloadU16;
@ -33,36 +33,12 @@ use crate::client::{tls12, tls13};
use webpki;
macro_rules! extract_handshake(
( $m:expr, $t:path ) => (
match $m.payload {
MessagePayload::Handshake(ref hsp) => match hsp.payload {
$t(ref hm) => Some(hm),
_ => None
},
_ => None
}
)
);
macro_rules! extract_handshake_mut(
( $m:expr, $t:path ) => (
match $m.payload {
MessagePayload::Handshake(hsp) => match hsp.payload {
$t(hm) => Some(hm),
_ => None
},
_ => None
}
)
);
pub type CheckResult = Result<(), TLSError>;
pub type NextState = Box<dyn State + Send + Sync>;
pub type NextStateOrError = Result<NextState, TLSError>;
pub trait State {
fn check_message(&self, m: &Message) -> CheckResult;
/// Each handle() implementation consumes a whole TLS message, and returns
/// either an error or the next state.
fn handle(self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> NextStateOrError;
fn export_keying_material(&self,
@ -441,12 +417,8 @@ impl ExpectServerHello {
}
impl State for ExpectServerHello {
fn check_message(&self, m: &Message) -> CheckResult {
check_handshake_message(m, &[HandshakeType::ServerHello])
}
fn handle(mut self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> NextStateOrError {
let server_hello = extract_handshake!(m, HandshakePayload::ServerHello).unwrap();
let server_hello = require_handshake_msg!(m, HandshakeType::ServerHello, HandshakePayload::ServerHello)?;
trace!("We got ServerHello {:#?}", server_hello);
use crate::ProtocolVersion::{TLSv1_2, TLSv1_3};
@ -540,10 +512,10 @@ impl State for ExpectServerHello {
// For TLS1.3, start message encryption using
// handshake_traffic_secret.
if sess.common.is_tls13() {
tls13::validate_server_hello(sess, server_hello)?;
tls13::validate_server_hello(sess, &server_hello)?;
let key_schedule = tls13::start_handshake_traffic(sess,
self.early_key_schedule.take(),
server_hello,
&server_hello,
&mut self.handshake,
&mut self.hello)?;
tls13::emit_fake_ccs(&mut self.handshake, sess);
@ -641,9 +613,7 @@ impl ExpectServerHelloOrHelloRetryRequest {
}
fn handle_hello_retry_request(mut self, sess: &mut ClientSessionImpl, m: Message) -> NextStateOrError {
check_handshake_message(&m, &[HandshakeType::HelloRetryRequest])?;
let hrr = extract_handshake!(m, HandshakePayload::HelloRetryRequest).unwrap();
let hrr = require_handshake_msg!(m, HandshakeType::HelloRetryRequest, HandshakePayload::HelloRetryRequest)?;
trace!("Got HRR {:?}", hrr);
check_aligned_handshake(sess)?;
@ -721,18 +691,15 @@ impl ExpectServerHelloOrHelloRetryRequest {
Ok(emit_client_hello_for_retry(sess,
self.0.handshake,
self.0.hello,
Some(hrr)))
Some(&hrr)))
}
}
impl State for ExpectServerHelloOrHelloRetryRequest {
fn check_message(&self, m: &Message) -> CheckResult {
check_handshake_message(m,
&[HandshakeType::ServerHello,
HandshakeType::HelloRetryRequest])
}
fn handle(self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> NextStateOrError {
check_message(&m,
&[ContentType::Handshake],
&[HandshakeType::ServerHello, HandshakeType::HelloRetryRequest])?;
if m.is_handshake_type(HandshakeType::ServerHello) {
self.into_expect_server_hello().handle(sess, m)
} else {

View File

@ -371,7 +371,7 @@ pub struct ClientSessionImpl {
pub alpn_protocol: Option<Vec<u8>>,
pub common: SessionCommon,
pub error: Option<TLSError>,
pub state: Option<Box<dyn hs::State + Send + Sync>>,
pub state: Option<hs::NextState>,
pub server_cert_chain: CertificatePayload,
pub early_data: EarlyData,
pub resumption_ciphersuite: Option<&'static SupportedCipherSuite>,
@ -495,13 +495,24 @@ impl ClientSessionImpl {
Ok(())
}
fn reject_renegotiation_attempt(&mut self) -> Result<(), TLSError> {
self.common.send_warning_alert(AlertDescription::NoRenegotiation);
Ok(())
}
fn queue_unexpected_alert(&mut self) {
self.common.send_fatal_alert(AlertDescription::UnexpectedMessage);
}
fn reject_renegotiation_attempt(&mut self) -> Result<(), TLSError> {
self.common.send_warning_alert(AlertDescription::NoRenegotiation);
Ok(())
fn maybe_send_unexpected_alert(&mut self, rc: hs::NextStateOrError) -> hs::NextStateOrError {
match rc {
Err(TLSError::InappropriateMessage { .. }) |
Err(TLSError::InappropriateHandshakeMessage { .. }) => {
self.queue_unexpected_alert();
}
_ => {}
};
rc
}
/// Process `msg`. First, we get the current state. Then we ask what messages
@ -517,13 +528,9 @@ impl ClientSessionImpl {
}
let state = self.state.take().unwrap();
state
.check_message(&msg)
.map_err(|err| {
self.queue_unexpected_alert();
err
})?;
self.state = Some(state.handle(self, msg)?);
let maybe_next_state = state.handle(self, msg);
let next_state = self.maybe_send_unexpected_alert(maybe_next_state)?;
self.state = Some(next_state);
Ok(())
}

View File

@ -17,7 +17,7 @@ use crate::ticketer;
#[cfg(feature = "logging")]
use crate::log::{debug, trace};
use crate::error::TLSError;
use crate::handshake::{check_message, check_handshake_message};
use crate::check::check_message;
use crate::client::common::{ServerCertDetails, ServerKXDetails, HandshakeDetails};
use crate::client::common::{ReceivedTicketDetails, ClientAuthDetails};
@ -52,12 +52,8 @@ impl ExpectCertificate {
}
impl hs::State for ExpectCertificate {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_handshake_message(m, &[HandshakeType::Certificate])
}
fn handle(mut self: Box<Self>, _sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
let cert_chain = extract_handshake!(m, HandshakePayload::Certificate).unwrap();
let cert_chain = require_handshake_msg!(m, HandshakeType::Certificate, HandshakePayload::Certificate)?;
self.handshake.transcript.add_message(&m);
self.server_cert.cert_chain = cert_chain.clone();
@ -87,13 +83,9 @@ impl ExpectCertificateStatus {
}
impl hs::State for ExpectCertificateStatus {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_handshake_message(m, &[HandshakeType::CertificateStatus])
}
fn handle(mut self: Box<Self>, _sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
self.handshake.transcript.add_message(&m);
let mut status = extract_handshake_mut!(m, HandshakePayload::CertificateStatus).unwrap();
let mut status = require_handshake_msg_mut!(m, HandshakeType::CertificateStatus, HandshakePayload::CertificateStatus)?;
self.server_cert.ocsp_response = status.take_ocsp_response();
debug!("Server stapled OCSP response is {:?}", self.server_cert.ocsp_response);
@ -126,13 +118,10 @@ impl ExpectCertificateStatusOrServerKX {
}
impl hs::State for ExpectCertificateStatusOrServerKX {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_handshake_message(m,
&[HandshakeType::ServerKeyExchange,
HandshakeType::CertificateStatus])
}
fn handle(self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
check_message(&m,
&[ContentType::Handshake],
&[HandshakeType::ServerKeyExchange, HandshakeType::CertificateStatus])?;
if m.is_handshake_type(HandshakeType::ServerKeyExchange) {
self.into_expect_server_kx().handle(sess, m)
} else {
@ -159,12 +148,8 @@ impl ExpectServerKX {
}
impl hs::State for ExpectServerKX {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_handshake_message(m, &[HandshakeType::ServerKeyExchange])
}
fn handle(mut self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
let opaque_kx = extract_handshake!(m, HandshakePayload::ServerKeyExchange).unwrap();
let opaque_kx = require_handshake_msg!(m, HandshakeType::ServerKeyExchange, HandshakePayload::ServerKeyExchange)?;
let maybe_decoded_kx = opaque_kx.unwrap_given_kxa(&sess.common.get_suite_assert().kx);
self.handshake.transcript.add_message(&m);
@ -312,12 +297,8 @@ impl ExpectCertificateRequest {
}
impl hs::State for ExpectCertificateRequest {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_handshake_message(m, &[HandshakeType::CertificateRequest])
}
fn handle(mut self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
let certreq = extract_handshake!(m, HandshakePayload::CertificateRequest).unwrap();
let certreq = require_handshake_msg!(m, HandshakeType::CertificateRequest, HandshakePayload::CertificateRequest)?;
self.handshake.transcript.add_message(&m);
debug!("Got CertificateRequest {:?}", certreq);
@ -381,14 +362,8 @@ impl ExpectServerDoneOrCertReq {
}
impl hs::State for ExpectServerDoneOrCertReq {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_handshake_message(m,
&[HandshakeType::CertificateRequest,
HandshakeType::ServerHelloDone])
}
fn handle(mut self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
if extract_handshake!(m, HandshakePayload::CertificateRequest).is_some() {
if require_handshake_msg!(m, HandshakeType::CertificateRequest, HandshakePayload::CertificateRequest).is_ok() {
self.into_expect_certificate_req().handle(sess, m)
} else {
self.handshake.transcript.abandon_client_auth();
@ -436,12 +411,9 @@ impl ExpectServerDone {
}
impl hs::State for ExpectServerDone {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_handshake_message(m, &[HandshakeType::ServerHelloDone])
}
fn handle(self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
let mut st = *self;
check_message(&m, &[ContentType::Handshake], &[HandshakeType::ServerHelloDone])?;
st.handshake.transcript.add_message(&m);
hs::check_aligned_handshake(sess)?;
@ -594,11 +566,8 @@ impl ExpectCCS {
}
impl hs::State for ExpectCCS {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_message(m, &[ContentType::ChangeCipherSpec], &[])
}
fn handle(self: Box<Self>, sess: &mut ClientSessionImpl, _m: Message) -> hs::NextStateOrError {
fn handle(self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
check_message(&m, &[ContentType::ChangeCipherSpec], &[])?;
// CCS should not be received interleaved with fragmented handshake-level
// message.
hs::check_aligned_handshake(sess)?;
@ -634,14 +603,10 @@ impl ExpectNewTicket {
}
impl hs::State for ExpectNewTicket {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_handshake_message(m, &[HandshakeType::NewSessionTicket])
}
fn handle(mut self: Box<Self>, _sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
self.handshake.transcript.add_message(&m);
let nst = extract_handshake_mut!(m, HandshakePayload::NewSessionTicket).unwrap();
let nst = require_handshake_msg_mut!(m, HandshakeType::NewSessionTicket, HandshakePayload::NewSessionTicket)?;
let recvd = ReceivedTicketDetails::from(nst.ticket.0, nst.lifetime_hint);
Ok(self.into_expect_ccs(recvd))
}
@ -712,13 +677,9 @@ impl ExpectFinished {
}
impl hs::State for ExpectFinished {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_handshake_message(m, &[HandshakeType::Finished])
}
fn handle(self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
let mut st = *self;
let finished = extract_handshake!(m, HandshakePayload::Finished).unwrap();
let finished = require_handshake_msg!(m, HandshakeType::Finished, HandshakePayload::Finished)?;
hs::check_aligned_handshake(sess)?;
@ -766,11 +727,8 @@ struct ExpectTraffic {
}
impl hs::State for ExpectTraffic {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_message(m, &[ContentType::ApplicationData], &[])
}
fn handle(self: Box<Self>, sess: &mut ClientSessionImpl, mut m: Message) -> hs::NextStateOrError {
check_message(&m, &[ContentType::ApplicationData], &[])?;
sess.common.take_received_plaintext(m.take_opaque_payload().unwrap());
Ok(self)
}

View File

@ -10,6 +10,7 @@ use crate::msgs::handshake::EncryptedExtensions;
use crate::msgs::handshake::{CertificatePayloadTLS13, CertificateEntry};
use crate::msgs::handshake::{PresharedKeyIdentity, PresharedKeyOffer};
use crate::msgs::handshake::DigitallySignedStruct;
use crate::msgs::handshake::NewSessionTicketPayloadTLS13;
use crate::msgs::ccs::ChangeCipherSpecPayload;
use crate::msgs::codec::Codec;
use crate::msgs::persist;
@ -29,7 +30,7 @@ use crate::ticketer;
#[cfg(feature = "logging")]
use crate::log::{debug, warn};
use crate::error::TLSError;
use crate::handshake::{check_message, check_handshake_message};
use crate::check::check_message;
#[cfg(feature = "quic")]
use crate::{
quic,
@ -377,16 +378,12 @@ impl ExpectEncryptedExtensions {
}
impl hs::State for ExpectEncryptedExtensions {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_handshake_message(m, &[HandshakeType::EncryptedExtensions])
}
fn handle(mut self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
let exts = extract_handshake!(m, HandshakePayload::EncryptedExtensions).unwrap();
let exts = require_handshake_msg!(m, HandshakeType::EncryptedExtensions, HandshakePayload::EncryptedExtensions)?;
debug!("TLS1.3 encrypted extensions: {:?}", exts);
self.handshake.transcript.add_message(&m);
validate_encrypted_extensions(sess, &self.hello, exts)?;
validate_encrypted_extensions(sess, &self.hello, &exts)?;
hs::process_alpn_protocol(sess, exts.get_alpn_protocol())?;
#[cfg(feature = "quic")] {
@ -451,12 +448,8 @@ impl ExpectCertificate {
}
impl hs::State for ExpectCertificate {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_handshake_message(m, &[HandshakeType::Certificate])
}
fn handle(mut self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
let cert_chain = extract_handshake!(m, HandshakePayload::CertificateTLS13).unwrap();
let cert_chain = require_handshake_msg!(m, HandshakeType::Certificate, HandshakePayload::CertificateTLS13)?;
self.handshake.transcript.add_message(&m);
// This is only non-empty for client auth.
@ -519,13 +512,10 @@ impl ExpectCertificateOrCertReq {
}
impl hs::State for ExpectCertificateOrCertReq {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_handshake_message(m,
&[HandshakeType::Certificate,
HandshakeType::CertificateRequest])
}
fn handle(self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
check_message(&m,
&[ContentType::Handshake],
&[HandshakeType::Certificate, HandshakeType::CertificateRequest])?;
if m.is_handshake_type(HandshakeType::Certificate) {
self.into_expect_certificate().handle(sess, m)
} else {
@ -573,12 +563,8 @@ fn send_cert_error_alert(sess: &mut ClientSessionImpl, err: TLSError) -> TLSErro
}
impl hs::State for ExpectCertificateVerify {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_handshake_message(m, &[HandshakeType::CertificateVerify])
}
fn handle(mut self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
let cert_verify = extract_handshake!(m, HandshakePayload::CertificateVerify).unwrap();
let cert_verify = require_handshake_msg!(m, HandshakeType::CertificateVerify, HandshakePayload::CertificateVerify)?;
debug!("Server cert is {:?}", self.server_cert.cert_chain);
@ -601,7 +587,7 @@ impl hs::State for ExpectCertificateVerify {
.get_verifier()
.verify_tls13_signature(&verify::construct_tls13_server_verify_message(&handshake_hash),
&self.server_cert.cert_chain[0],
cert_verify)
&cert_verify)
.map_err(|err| send_cert_error_alert(sess, err))?;
// 3. Verify any included SCTs.
@ -642,12 +628,8 @@ impl ExpectCertificateRequest {
}
impl hs::State for ExpectCertificateRequest {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_handshake_message(m, &[HandshakeType::CertificateRequest])
}
fn handle(mut self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
let certreq = &extract_handshake!(m, HandshakePayload::CertificateRequestTLS13).unwrap();
let certreq = &require_handshake_msg!(m, HandshakeType::CertificateRequest, HandshakePayload::CertificateRequestTLS13)?;
self.handshake.transcript.add_message(&m);
debug!("Got CertificateRequest {:?}", certreq);
@ -823,13 +805,9 @@ impl ExpectFinished {
}
impl hs::State for ExpectFinished {
fn check_message(&self, m: &Message) -> hs::CheckResult {
check_handshake_message(m, &[HandshakeType::Finished])
}
fn handle(self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
let mut st = *self;
let finished = extract_handshake!(m, HandshakePayload::Finished).unwrap();
let finished = require_handshake_msg!(m, HandshakeType::Finished, HandshakePayload::Finished)?;
let handshake_hash = st.handshake.transcript.get_current_hash();
let expect_verify_data = st.key_schedule.sign_server_finish(&handshake_hash);
@ -942,8 +920,7 @@ struct ExpectTraffic {
}
impl ExpectTraffic {
fn handle_new_ticket_tls13(&mut self, sess: &mut ClientSessionImpl, m: Message) -> Result<(), TLSError> {
let nst = extract_handshake!(m, HandshakePayload::NewSessionTicketTLS13).unwrap();
fn handle_new_ticket_tls13(&mut self, sess: &mut ClientSessionImpl, nst: &NewSessionTicketPayloadTLS13) -> Result<(), TLSError> {
let handshake_hash = self.handshake.transcript.get_current_hash();
let secret = self.key_schedule
.resumption_master_secret_and_derive_ticket_psk(&handshake_hash, &nst.nonce.0);
@ -989,9 +966,7 @@ impl ExpectTraffic {
Ok(())
}
fn handle_key_update(&mut self, sess: &mut ClientSessionImpl, m: Message) -> Result<(), TLSError> {
let kur = extract_handshake!(m, HandshakePayload::KeyUpdate).unwrap();
fn handle_key_update(&mut self, sess: &mut ClientSessionImpl, kur: &KeyUpdateRequest) -> Result<(), TLSError> {
#[cfg(feature = "quic")]
{
if let Protocol::Quic = sess.common.protocol {
@ -1026,19 +1001,17 @@ impl ExpectTraffic {
}
impl hs::State for ExpectTraffic {
fn check_message(&self, m: &Message) -> Result<(), TLSError> {
check_message(m,
&[ContentType::ApplicationData, ContentType::Handshake],
&[HandshakeType::NewSessionTicket, HandshakeType::KeyUpdate])
}
fn handle(mut self: Box<Self>, sess: &mut ClientSessionImpl, mut m: Message) -> hs::NextStateOrError {
if m.is_content_type(ContentType::ApplicationData) {
sess.common.take_received_plaintext(m.take_opaque_payload().unwrap());
} else if m.is_handshake_type(HandshakeType::NewSessionTicket) {
self.handle_new_ticket_tls13(sess, m)?;
} else if m.is_handshake_type(HandshakeType::KeyUpdate) {
self.handle_key_update(sess, m)?;
} else if let Ok(ref new_ticket) = require_handshake_msg!(m, HandshakeType::NewSessionTicket, HandshakePayload::NewSessionTicketTLS13) {
self.handle_new_ticket_tls13(sess, new_ticket)?;
} else if let Ok(ref key_update) = require_handshake_msg!(m, HandshakeType::KeyUpdate, HandshakePayload::KeyUpdate) {
self.handle_key_update(sess, key_update)?;
} else {
check_message(&m,
&[ContentType::ApplicationData, ContentType::Handshake],
&[HandshakeType::NewSessionTicket, HandshakeType::KeyUpdate])?;
}
Ok(self)
@ -1068,12 +1041,9 @@ pub struct ExpectQUICTraffic(ExpectTraffic);
#[cfg(feature = "quic")]
impl hs::State for ExpectQUICTraffic {
fn check_message(&self, m: &Message) -> hs::CheckResult {
check_message(m, &[ContentType::Handshake], &[HandshakeType::NewSessionTicket])
}
fn handle(mut self: Box<Self>, sess: &mut ClientSessionImpl, m: Message) -> hs::NextStateOrError {
self.0.handle_new_ticket_tls13(sess, m)?;
let nst = require_handshake_msg!(m, HandshakeType::NewSessionTicket, HandshakePayload::NewSessionTicketTLS13)?;
self.0.handle_new_ticket_tls13(sess, nst)?;
Ok(self)
}
}

View File

@ -1,40 +0,0 @@
use crate::msgs::enums::{ContentType, HandshakeType};
use crate::msgs::message::{Message, MessagePayload};
use crate::error::TLSError;
#[cfg(feature = "logging")]
use crate::log::warn;
pub fn check_handshake_message(m: &Message,
handshake_types: &[HandshakeType]) -> Result<(), TLSError> {
check_message(m,
&[ContentType::Handshake],
handshake_types)
}
pub fn check_message(m: &Message,
content_types: &[ContentType],
handshake_types: &[HandshakeType]) -> Result<(), TLSError> {
if !content_types.contains(&m.typ) {
warn!("Received a {:?} message while expecting {:?}",
m.typ,
content_types);
return Err(TLSError::InappropriateMessage {
expect_types: content_types.to_vec(),
got_type: m.typ,
});
}
if let MessagePayload::Handshake(ref hsp) = m.payload {
if !handshake_types.is_empty() && !handshake_types.contains(&hsp.typ) {
warn!("Received a {:?} handshake message while expecting {:?}",
hsp.typ,
handshake_types);
return Err(TLSError::InappropriateHandshakeMessage {
expect_types: handshake_types.to_vec(),
got_type: hsp.typ,
});
}
}
Ok(())
}

View File

@ -86,7 +86,7 @@ impl HandshakeHash {
let buf = hs.get_encoding();
self.update_raw(&buf);
}
_ => unreachable!(),
_ => {},
};
self
}

View File

@ -234,7 +234,8 @@ mod anchors;
mod verify;
#[cfg(test)]
mod verifybench;
mod handshake;
#[macro_use]
mod check;
mod suites;
mod ticketer;
mod server;

View File

@ -23,7 +23,6 @@ use crate::sign;
#[cfg(feature = "logging")]
use crate::log::{trace, debug};
use crate::error::TLSError;
use crate::handshake::check_handshake_message;
use webpki;
#[cfg(feature = "quic")]
use crate::session::Protocol;
@ -31,24 +30,10 @@ use crate::session::Protocol;
use crate::server::common::{HandshakeDetails, ServerKXDetails};
use crate::server::{tls12, tls13};
macro_rules! extract_handshake(
( $m:expr, $t:path ) => (
match $m.payload {
MessagePayload::Handshake(ref hsp) => match hsp.payload {
$t(ref hm) => Some(hm),
_ => None
},
_ => None
}
)
);
pub type CheckResult = Result<(), TLSError>;
pub type NextState = Box<dyn State + Send + Sync>;
pub type NextStateOrError = Result<NextState, TLSError>;
pub trait State {
fn check_message(&self, m: &Message) -> CheckResult;
fn handle(self: Box<Self>, sess: &mut ServerSessionImpl, m: Message) -> NextStateOrError;
fn export_keying_material(&self,
@ -558,12 +543,8 @@ impl ExpectClientHello {
}
impl State for ExpectClientHello {
fn check_message(&self, m: &Message) -> CheckResult {
check_handshake_message(m, &[HandshakeType::ClientHello])
}
fn handle(mut self: Box<Self>, sess: &mut ServerSessionImpl, m: Message) -> NextStateOrError {
let client_hello = extract_handshake!(m, HandshakePayload::ClientHello).unwrap();
let client_hello = require_handshake_msg!(m, HandshakeType::ClientHello, HandshakePayload::ClientHello)?;
let tls13_enabled = sess.config.supports_version(ProtocolVersion::TLSv1_3);
let tls12_enabled = sess.config.supports_version(ProtocolVersion::TLSv1_2);
trace!("we got a clienthello {:?}", client_hello);

View File

@ -405,6 +405,17 @@ impl ServerSessionImpl {
self.common.send_fatal_alert(AlertDescription::UnexpectedMessage);
}
fn maybe_send_unexpected_alert(&mut self, rc: hs::NextStateOrError) -> hs::NextStateOrError {
match rc {
Err(TLSError::InappropriateMessage { .. }) |
Err(TLSError::InappropriateHandshakeMessage { .. }) => {
self.queue_unexpected_alert();
}
_ => {}
};
rc
}
pub fn process_main_protocol(&mut self, msg: Message) -> Result<(), TLSError> {
if self.common.traffic && !self.common.is_tls13() &&
msg.is_handshake_type(HandshakeType::ClientHello) {
@ -412,11 +423,10 @@ impl ServerSessionImpl {
return Ok(());
}
let st = self.state.take().unwrap();
st.check_message(&msg)
.map_err(|err| { self.queue_unexpected_alert(); err })?;
self.state = Some(st.handle(self, msg)?);
let state = self.state.take().unwrap();
let maybe_next_state = state.handle(self, msg);
let next_state = self.maybe_send_unexpected_alert(maybe_next_state)?;
self.state = Some(next_state);
Ok(())
}

View File

@ -14,7 +14,7 @@ use crate::verify;
#[cfg(feature = "logging")]
use crate::log::{trace, debug};
use crate::error::TLSError;
use crate::handshake::{check_handshake_message, check_message};
use crate::check::check_message;
use crate::server::common::{HandshakeDetails, ServerKXDetails, ClientCertDetails};
use crate::server::hs;
@ -40,12 +40,8 @@ impl ExpectCertificate {
}
impl hs::State for ExpectCertificate {
fn check_message(&self, m: &Message) -> hs::CheckResult {
check_handshake_message(m, &[HandshakeType::Certificate])
}
fn handle(mut self: Box<Self>, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError {
let cert_chain = extract_handshake!(m, HandshakePayload::Certificate).unwrap();
let cert_chain = require_handshake_msg!(m, HandshakeType::Certificate, HandshakePayload::Certificate)?;
self.handshake.transcript.add_message(&m);
// If we can't determine if the auth is mandatory, abort
@ -108,12 +104,8 @@ impl ExpectClientKX {
}
impl hs::State for ExpectClientKX {
fn check_message(&self, m: &Message) -> hs::CheckResult {
check_handshake_message(m, &[HandshakeType::ClientKeyExchange])
}
fn handle(mut self: Box<Self>, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError {
let client_kx = extract_handshake!(m, HandshakePayload::ClientKeyExchange).unwrap();
let client_kx = require_handshake_msg!(m, HandshakeType::ClientKeyExchange, HandshakePayload::ClientKeyExchange)?;
self.handshake.transcript.add_message(&m);
// Complete key agreement, and set up encryption with the
@ -173,13 +165,9 @@ impl ExpectCertificateVerify {
}
impl hs::State for ExpectCertificateVerify {
fn check_message(&self, m: &Message) -> hs::CheckResult {
check_handshake_message(m, &[HandshakeType::CertificateVerify])
}
fn handle(mut self: Box<Self>, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError {
let rc = {
let sig = extract_handshake!(m, HandshakePayload::CertificateVerify).unwrap();
let sig = require_handshake_msg!(m, HandshakeType::CertificateVerify, HandshakePayload::CertificateVerify)?;
let handshake_msgs = self.handshake.transcript.take_handshake_buf();
let certs = &self.client_cert.cert_chain;
@ -221,11 +209,9 @@ impl ExpectCCS {
}
impl hs::State for ExpectCCS {
fn check_message(&self, m: &Message) -> hs::CheckResult {
check_message(m, &[ContentType::ChangeCipherSpec], &[])
}
fn handle(self: Box<Self>, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError {
check_message(&m, &[ContentType::ChangeCipherSpec], &[])?;
fn handle(self: Box<Self>, sess: &mut ServerSessionImpl, _m: Message) -> hs::NextStateOrError {
// CCS should not be received interleaved with fragmented handshake-level
// message.
hs::check_aligned_handshake(sess)?;
@ -335,12 +321,8 @@ impl ExpectFinished {
}
impl hs::State for ExpectFinished {
fn check_message(&self, m: &Message) -> hs::CheckResult {
check_handshake_message(m, &[HandshakeType::Finished])
}
fn handle(mut self: Box<Self>, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError {
let finished = extract_handshake!(m, HandshakePayload::Finished).unwrap();
let finished = require_handshake_msg!(m, HandshakeType::Finished, HandshakePayload::Finished)?;
hs::check_aligned_handshake(sess)?;
@ -397,11 +379,8 @@ impl ExpectTraffic {
}
impl hs::State for ExpectTraffic {
fn check_message(&self, m: &Message) -> hs::CheckResult {
check_message(m, &[ContentType::ApplicationData], &[])
}
fn handle(self: Box<Self>, sess: &mut ServerSessionImpl, mut m: Message) -> hs::NextStateOrError {
check_message(&m, &[ContentType::ApplicationData], &[])?;
sess.common.take_received_plaintext(m.take_opaque_payload().unwrap());
Ok(self)
}

View File

@ -41,7 +41,7 @@ use crate::suites;
#[cfg(feature = "logging")]
use crate::log::{warn, trace, debug};
use crate::error::TLSError;
use crate::handshake::{check_handshake_message, check_message};
use crate::check::check_message;
#[cfg(feature = "quic")]
use crate::{
quic,
@ -467,7 +467,7 @@ impl CompleteClientHelloHandling {
sni: Option<webpki::DNSName>,
mut server_key: sign::CertifiedKey,
chm: &Message) -> hs::NextStateOrError {
let client_hello = extract_handshake!(chm, HandshakePayload::ClientHello).unwrap();
let client_hello = require_handshake_msg!(chm, HandshakeType::ClientHello, HandshakePayload::ClientHello)?;
if client_hello.compression_methods.len() != 1 {
return Err(hs::illegal_param(sess, "client offered wrong compressions"));
@ -636,12 +636,8 @@ impl ExpectCertificate {
}
impl hs::State for ExpectCertificate {
fn check_message(&self, m: &Message) -> hs::CheckResult {
check_handshake_message(m, &[HandshakeType::Certificate])
}
fn handle(mut self: Box<Self>, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError {
let certp = extract_handshake!(m, HandshakePayload::CertificateTLS13).unwrap();
let certp = require_handshake_msg!(m, HandshakeType::Certificate, HandshakePayload::CertificateTLS13)?;
self.handshake.transcript.add_message(&m);
// We don't send any CertificateRequest extensions, so any extensions
@ -700,13 +696,9 @@ impl ExpectCertificateVerify {
}
impl hs::State for ExpectCertificateVerify {
fn check_message(&self, m: &Message) -> hs::CheckResult {
check_handshake_message(m, &[HandshakeType::CertificateVerify])
}
fn handle(mut self: Box<Self>, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError {
let rc = {
let sig = extract_handshake!(m, HandshakePayload::CertificateVerify).unwrap();
let sig = require_handshake_msg!(m, HandshakeType::CertificateVerify, HandshakePayload::CertificateVerify)?;
let handshake_hash = self.handshake.transcript.get_current_hash();
self.handshake.transcript.abandon_client_auth();
let certs = &self.client_cert.cert_chain;
@ -849,12 +841,8 @@ impl ExpectFinished {
}
impl hs::State for ExpectFinished {
fn check_message(&self, m: &Message) -> hs::CheckResult {
check_handshake_message(m, &[HandshakeType::Finished])
}
fn handle(mut self: Box<Self>, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError {
let finished = extract_handshake!(m, HandshakePayload::Finished).unwrap();
let finished = require_handshake_msg!(m, HandshakeType::Finished, HandshakePayload::Finished)?;
let handshake_hash = self.handshake.transcript.get_current_hash();
let expect_verify_data = self.key_schedule.sign_client_finish(&handshake_hash);
@ -919,9 +907,7 @@ impl ExpectTraffic {
Ok(())
}
fn handle_key_update(&mut self, sess: &mut ServerSessionImpl, m: Message) -> Result<(), TLSError> {
let kur = extract_handshake!(m, HandshakePayload::KeyUpdate).unwrap();
fn handle_key_update(&mut self, sess: &mut ServerSessionImpl, kur: &KeyUpdateRequest) -> Result<(), TLSError> {
#[cfg(feature = "quic")]
{
if let Protocol::Quic = sess.common.protocol {
@ -955,17 +941,15 @@ impl ExpectTraffic {
}
impl hs::State for ExpectTraffic {
fn check_message(&self, m: &Message) -> hs::CheckResult {
check_message(m,
&[ContentType::ApplicationData, ContentType::Handshake],
&[HandshakeType::KeyUpdate])
}
fn handle(mut self: Box<Self>, sess: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError {
if m.is_content_type(ContentType::ApplicationData) {
self.handle_traffic(sess, m)?;
} else if m.is_handshake_type(HandshakeType::KeyUpdate) {
self.handle_key_update(sess, m)?;
} else if let Ok(key_update) = require_handshake_msg!(m, HandshakeType::KeyUpdate, HandshakePayload::KeyUpdate) {
self.handle_key_update(sess, key_update)?;
} else {
check_message(&m,
&[ContentType::ApplicationData, ContentType::Handshake],
&[HandshakeType::KeyUpdate])?;
}
Ok(self)
@ -997,14 +981,9 @@ pub struct ExpectQUICTraffic {
#[cfg(feature = "quic")]
impl hs::State for ExpectQUICTraffic {
fn check_message(&self, m: &Message) -> hs::CheckResult {
Err(TLSError::InappropriateMessage {
expect_types: Vec::new(),
got_type: m.typ,
})
}
fn handle(self: Box<Self>, _: &mut ServerSessionImpl, _: Message) -> hs::NextStateOrError {
unreachable!("check_message always fails");
fn handle(self: Box<Self>, _: &mut ServerSessionImpl, m: Message) -> hs::NextStateOrError {
// reject all messages
check_message(&m, &[], &[])?;
unreachable!();
}
}