rustls/rustls/src/record_layer.rs

314 lines
11 KiB
Rust

use core::num::NonZeroU64;
use crate::crypto::cipher::{BorrowedOpaqueMessage, MessageDecrypter, MessageEncrypter};
use crate::error::Error;
use crate::msgs::message::{BorrowedPlainMessage, OpaqueMessage, PlainMessage};
#[cfg(feature = "logging")]
use crate::log::trace;
use alloc::boxed::Box;
static SEQ_SOFT_LIMIT: u64 = 0xffff_ffff_ffff_0000u64;
static SEQ_HARD_LIMIT: u64 = 0xffff_ffff_ffff_fffeu64;
#[derive(PartialEq)]
enum DirectionState {
/// No keying material.
Invalid,
/// Keying material present, but not yet in use.
Prepared,
/// Keying material in use.
Active,
}
/// Record layer that tracks decryption and encryption keys.
pub struct RecordLayer {
message_encrypter: Box<dyn MessageEncrypter>,
message_decrypter: Box<dyn MessageDecrypter>,
write_seq: u64,
read_seq: u64,
has_decrypted: bool,
encrypt_state: DirectionState,
decrypt_state: DirectionState,
// Message encrypted with other keys may be encountered, so failures
// should be swallowed by the caller. This struct tracks the amount
// of message size this is allowed for.
trial_decryption_len: Option<usize>,
}
impl RecordLayer {
/// Create new record layer with no keys.
pub fn new() -> Self {
Self {
message_encrypter: <dyn MessageEncrypter>::invalid(),
message_decrypter: <dyn MessageDecrypter>::invalid(),
write_seq: 0,
read_seq: 0,
has_decrypted: false,
encrypt_state: DirectionState::Invalid,
decrypt_state: DirectionState::Invalid,
trial_decryption_len: None,
}
}
/// Decrypt a TLS message.
///
/// `encr` is a decoded message allegedly received from the peer.
/// If it can be decrypted, its decryption is returned. Otherwise,
/// an error is returned.
pub(crate) fn decrypt_incoming(
&mut self,
encr: BorrowedOpaqueMessage<'_>,
) -> Result<Option<Decrypted>, Error> {
if self.decrypt_state != DirectionState::Active {
return Ok(Some(Decrypted {
want_close_before_decrypt: false,
plaintext: encr.into_plain_message().into_owned(),
}));
}
// Set to `true` if the peer appears to getting close to encrypting
// too many messages with this key.
//
// Perhaps if we send an alert well before their counter wraps, a
// buggy peer won't make a terrible mistake here?
//
// Note that there's no reason to refuse to decrypt: the security
// failure has already happened.
let want_close_before_decrypt = self.read_seq == SEQ_SOFT_LIMIT;
let encrypted_len = encr.payload.len();
match self
.message_decrypter
.decrypt(encr, self.read_seq)
{
Ok(plaintext) => {
self.read_seq += 1;
if !self.has_decrypted {
self.has_decrypted = true;
}
Ok(Some(Decrypted {
want_close_before_decrypt,
plaintext: plaintext.into_owned(),
}))
}
Err(Error::DecryptError) if self.doing_trial_decryption(encrypted_len) => {
trace!("Dropping undecryptable message after aborted early_data");
Ok(None)
}
Err(err) => Err(err),
}
}
/// Encrypt a TLS message.
///
/// `plain` is a TLS message we'd like to send. This function
/// panics if the requisite keying material hasn't been established yet.
pub(crate) fn encrypt_outgoing(&mut self, plain: BorrowedPlainMessage) -> OpaqueMessage {
debug_assert!(self.encrypt_state == DirectionState::Active);
assert!(!self.encrypt_exhausted());
let seq = self.write_seq;
self.write_seq += 1;
self.message_encrypter
.encrypt(plain, seq)
.unwrap()
}
/// Prepare to use the given `MessageEncrypter` for future message encryption.
/// It is not used until you call `start_encrypting`.
pub(crate) fn prepare_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
self.message_encrypter = cipher;
self.write_seq = 0;
self.encrypt_state = DirectionState::Prepared;
}
/// Prepare to use the given `MessageDecrypter` for future message decryption.
/// It is not used until you call `start_decrypting`.
pub(crate) fn prepare_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
self.message_decrypter = cipher;
self.read_seq = 0;
self.decrypt_state = DirectionState::Prepared;
}
/// Start using the `MessageEncrypter` previously provided to the previous
/// call to `prepare_message_encrypter`.
pub(crate) fn start_encrypting(&mut self) {
debug_assert!(self.encrypt_state == DirectionState::Prepared);
self.encrypt_state = DirectionState::Active;
}
/// Start using the `MessageDecrypter` previously provided to the previous
/// call to `prepare_message_decrypter`.
pub(crate) fn start_decrypting(&mut self) {
debug_assert!(self.decrypt_state == DirectionState::Prepared);
self.decrypt_state = DirectionState::Active;
}
/// Set and start using the given `MessageEncrypter` for future outgoing
/// message encryption.
pub(crate) fn set_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
self.prepare_message_encrypter(cipher);
self.start_encrypting();
}
/// Set and start using the given `MessageDecrypter` for future incoming
/// message decryption.
pub(crate) fn set_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
self.prepare_message_decrypter(cipher);
self.start_decrypting();
self.trial_decryption_len = None;
}
/// Set and start using the given `MessageDecrypter` for future incoming
/// message decryption, and enable "trial decryption" mode for when TLS1.3
/// 0-RTT is attempted but rejected by the server.
pub(crate) fn set_message_decrypter_with_trial_decryption(
&mut self,
cipher: Box<dyn MessageDecrypter>,
max_length: usize,
) {
self.prepare_message_decrypter(cipher);
self.start_decrypting();
self.trial_decryption_len = Some(max_length);
}
pub(crate) fn finish_trial_decryption(&mut self) {
self.trial_decryption_len = None;
}
/// Return true if we are getting close to encrypting too many
/// messages with our encryption key.
pub(crate) fn wants_close_before_encrypt(&self) -> bool {
self.write_seq == SEQ_SOFT_LIMIT
}
/// Return true if we outright refuse to do anything with the
/// encryption key.
pub(crate) fn encrypt_exhausted(&self) -> bool {
self.write_seq >= SEQ_HARD_LIMIT
}
pub(crate) fn is_encrypting(&self) -> bool {
self.encrypt_state == DirectionState::Active
}
/// Return true if we have ever decrypted a message. This is used in place
/// of checking the read_seq since that will be reset on key updates.
pub(crate) fn has_decrypted(&self) -> bool {
self.has_decrypted
}
pub(crate) fn write_seq(&self) -> u64 {
self.write_seq
}
/// Returns the number of remaining write sequences
pub(crate) fn remaining_write_seq(&self) -> Option<NonZeroU64> {
SEQ_SOFT_LIMIT
.checked_sub(self.write_seq)
.and_then(NonZeroU64::new)
}
pub(crate) fn read_seq(&self) -> u64 {
self.read_seq
}
pub(crate) fn encrypted_len(&self, payload_len: usize) -> usize {
self.message_encrypter
.encrypted_payload_len(payload_len)
}
fn doing_trial_decryption(&mut self, requested: usize) -> bool {
match self
.trial_decryption_len
.and_then(|value| value.checked_sub(requested))
{
Some(remaining) => {
self.trial_decryption_len = Some(remaining);
true
}
_ => false,
}
}
}
/// Result of decryption.
#[derive(Debug)]
pub(crate) struct Decrypted {
/// Whether the peer appears to be getting close to encrypting too many messages with this key.
pub(crate) want_close_before_decrypt: bool,
/// The decrypted message.
pub(crate) plaintext: PlainMessage,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_has_decrypted() {
use crate::{ContentType, ProtocolVersion};
struct PassThroughDecrypter;
impl MessageDecrypter for PassThroughDecrypter {
fn decrypt<'a>(
&mut self,
m: BorrowedOpaqueMessage<'a>,
_: u64,
) -> Result<BorrowedPlainMessage<'a>, Error> {
Ok(m.into_plain_message())
}
}
// A record layer starts out invalid, having never decrypted.
let mut record_layer = RecordLayer::new();
assert!(matches!(
record_layer.decrypt_state,
DirectionState::Invalid
));
assert_eq!(record_layer.read_seq, 0);
assert!(!record_layer.has_decrypted());
// Preparing the record layer should update the decrypt state, but shouldn't affect whether it
// has decrypted.
record_layer.prepare_message_decrypter(Box::new(PassThroughDecrypter));
assert!(matches!(
record_layer.decrypt_state,
DirectionState::Prepared
));
assert_eq!(record_layer.read_seq, 0);
assert!(!record_layer.has_decrypted());
// Starting decryption should update the decrypt state, but not affect whether it has decrypted.
record_layer.start_decrypting();
assert!(matches!(record_layer.decrypt_state, DirectionState::Active));
assert_eq!(record_layer.read_seq, 0);
assert!(!record_layer.has_decrypted());
// Decrypting a message should update the read_seq and track that we have now performed
// a decryption.
let mut msg = OpaqueMessage::new(
ContentType::Handshake,
ProtocolVersion::TLSv1_2,
vec![0xC0, 0xFF, 0xEE],
);
record_layer
.decrypt_incoming(msg.borrow())
.unwrap();
assert!(matches!(record_layer.decrypt_state, DirectionState::Active));
assert_eq!(record_layer.read_seq, 1);
assert!(record_layer.has_decrypted());
// Resetting the record layer message decrypter (as if a key update occurred) should reset
// the read_seq number, but not our knowledge of whether we have decrypted previously.
record_layer.set_message_decrypter(Box::new(PassThroughDecrypter));
assert!(matches!(record_layer.decrypt_state, DirectionState::Active));
assert_eq!(record_layer.read_seq, 0);
assert!(record_layer.has_decrypted());
}
}