Split BorrowedPlainMessage in inbound and outbound types

Signed-off-by: Eloi DEMOLIS <eloi.demolis@clever-cloud.com>
This commit is contained in:
Eloi DEMOLIS 2024-02-01 17:53:33 +01:00 committed by Dirkjan Ochtman
parent 1cdb10f8b4
commit 9af53f25f9
14 changed files with 183 additions and 123 deletions

View File

@ -86,7 +86,7 @@ struct Tls13Cipher(chacha20poly1305::ChaCha20Poly1305, cipher::Iv);
impl cipher::MessageEncrypter for Tls13Cipher {
fn encrypt(
&mut self,
m: cipher::BorrowedPlainMessage,
m: cipher::OutboundMessage,
seq: u64,
) -> Result<cipher::OpaqueMessage, rustls::Error> {
let total_len = self.encrypted_payload_len(m.payload.len());
@ -121,7 +121,7 @@ impl cipher::MessageDecrypter for Tls13Cipher {
&mut self,
mut m: cipher::BorrowedOpaqueMessage<'a>,
seq: u64,
) -> Result<cipher::BorrowedPlainMessage<'a>, rustls::Error> {
) -> Result<cipher::InboundMessage<'a>, rustls::Error> {
let payload = &mut m.payload;
let nonce = chacha20poly1305::Nonce::from(cipher::Nonce::new(&self.1, seq).0);
let aad = cipher::make_tls13_aad(payload.len());
@ -139,7 +139,7 @@ struct Tls12Cipher(chacha20poly1305::ChaCha20Poly1305, cipher::Iv);
impl cipher::MessageEncrypter for Tls12Cipher {
fn encrypt(
&mut self,
m: cipher::BorrowedPlainMessage,
m: cipher::OutboundMessage,
seq: u64,
) -> Result<cipher::OpaqueMessage, rustls::Error> {
let total_len = self.encrypted_payload_len(m.payload.len());
@ -166,7 +166,7 @@ impl cipher::MessageDecrypter for Tls12Cipher {
&mut self,
mut m: cipher::BorrowedOpaqueMessage<'a>,
seq: u64,
) -> Result<cipher::BorrowedPlainMessage<'a>, rustls::Error> {
) -> Result<cipher::InboundMessage<'a>, rustls::Error> {
let payload = &m.payload;
let nonce = chacha20poly1305::Nonce::from(cipher::Nonce::new(&self.1, seq).0);
let aad = cipher::make_tls12_aad(
@ -181,7 +181,7 @@ impl cipher::MessageDecrypter for Tls12Cipher {
.decrypt_in_place(&nonce, &aad, &mut BufferAdapter(payload))
.map_err(|_| rustls::Error::DecryptError)?;
Ok(m.into_plain_message())
Ok(m.into_inbound_message())
}
}

View File

@ -7,8 +7,9 @@ use crate::msgs::base::Payload;
use crate::msgs::enums::{AlertLevel, KeyUpdateRequest};
use crate::msgs::fragmenter::MessageFragmenter;
use crate::msgs::handshake::CertificateChain;
use crate::msgs::message::MessagePayload;
use crate::msgs::message::{BorrowedPlainMessage, Message, OpaqueMessage, PlainMessage};
use crate::msgs::message::{
BorrowedPlainMessage, Message, MessagePayload, OpaqueMessage, OutboundMessage, PlainMessage,
};
use crate::quic;
use crate::record_layer;
use crate::suites::PartiallyExtractedSecrets;
@ -300,7 +301,7 @@ impl CommonState {
len
}
fn send_single_fragment(&mut self, m: BorrowedPlainMessage) {
fn send_single_fragment(&mut self, m: OutboundMessage) {
// Close connection once we start to run out of
// sequence space.
if self
@ -548,7 +549,7 @@ impl CommonState {
&self,
outgoing_tls: &mut [u8],
opt_msg: Option<&[u8]>,
fragments: impl Iterator<Item = BorrowedPlainMessage<'a>>,
fragments: impl Iterator<Item = OutboundMessage<'a>>,
) -> Result<(), EncryptError> {
let mut required_size = 0;
if let Some(message) = opt_msg {
@ -572,7 +573,7 @@ impl CommonState {
&mut self,
outgoing_tls: &mut [u8],
opt_msg: Option<Vec<u8>>,
fragments: impl Iterator<Item = BorrowedPlainMessage<'a>>,
fragments: impl Iterator<Item = OutboundMessage<'a>>,
) -> usize {
let mut written = 0;
@ -658,7 +659,7 @@ impl CommonState {
let message = PlainMessage::from(Message::build_key_update_notify());
self.queued_key_update_message = Some(
self.record_layer
.encrypt_outgoing(message.borrow())
.encrypt_outgoing(message.borrow_outbound())
.encode(),
);
}

View File

@ -1,12 +1,11 @@
use crate::common_state::{CommonState, Context, IoState, State, DEFAULT_BUFFER_LIMIT};
use crate::crypto::cipher::BorrowedPlainMessage;
use crate::enums::{AlertDescription, ContentType};
use crate::error::{Error, PeerMisbehaved};
#[cfg(feature = "logging")]
use crate::log::trace;
use crate::msgs::deframer::{Deframed, DeframerSliceBuffer, DeframerVecBuffer, MessageDeframer};
use crate::msgs::handshake::Random;
use crate::msgs::message::{Message, MessagePayload};
use crate::msgs::message::{InboundMessage, Message, MessagePayload};
use crate::suites::{ExtractedSecrets, PartiallyExtractedSecrets};
use crate::vecbuf::ChunkVecBuffer;
@ -337,7 +336,7 @@ impl ConnectionRandoms {
// --- Common (to client and server) connection functions ---
fn is_valid_ccs(msg: &BorrowedPlainMessage) -> bool {
fn is_valid_ccs(msg: &InboundMessage) -> bool {
// We passthrough ChangeCipherSpec messages in the deframer without decrypting them.
// Note: this is prior to the record layer, so is unencrypted. See
// third paragraph of section 5 in RFC8446.
@ -778,7 +777,7 @@ impl<Data> ConnectionCore<Data> {
&mut self,
state: Option<&dyn State<Data>>,
deframer_buffer: &mut DeframerSliceBuffer<'b>,
) -> Result<Option<BorrowedPlainMessage<'b>>, Error> {
) -> Result<Option<InboundMessage<'b>>, Error> {
match self.message_deframer.pop(
&mut self.common_state.record_layer,
self.common_state.negotiated_version,
@ -833,7 +832,7 @@ impl<Data> ConnectionCore<Data> {
fn process_msg(
&mut self,
msg: BorrowedPlainMessage,
msg: InboundMessage,
state: Box<dyn State<Data>>,
sendable_plaintext: Option<&mut ChunkVecBuffer>,
) -> Result<Box<dyn State<Data>>, Error> {

View File

@ -7,7 +7,7 @@ use crate::crypto::{ActiveKeyExchange, KeyExchangeAlgorithm};
use crate::enums::{CipherSuite, SignatureScheme};
use crate::error::Error;
use crate::msgs::fragmenter::MAX_FRAGMENT_LEN;
use crate::msgs::message::{BorrowedPlainMessage, OpaqueMessage};
use crate::msgs::message::{InboundMessage, OpaqueMessage, OutboundMessage};
use crate::suites::{CipherSuiteCommon, ConnectionTrafficSecrets, SupportedCipherSuite};
use crate::tls12::Tls12CipherSuite;
use crate::version::TLS12;
@ -265,7 +265,7 @@ impl MessageDecrypter for GcmMessageDecrypter {
&mut self,
mut msg: BorrowedOpaqueMessage<'a>,
seq: u64,
) -> Result<BorrowedPlainMessage<'a>, Error> {
) -> Result<InboundMessage<'a>, Error> {
let payload = &msg.payload;
if payload.len() < GCM_OVERHEAD {
return Err(Error::DecryptError);
@ -297,12 +297,12 @@ impl MessageDecrypter for GcmMessageDecrypter {
}
payload.truncate(plain_len);
Ok(msg.into_plain_message())
Ok(msg.into_inbound_message())
}
}
impl MessageEncrypter for GcmMessageEncrypter {
fn encrypt(&mut self, msg: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
fn encrypt(&mut self, msg: OutboundMessage, seq: u64) -> Result<OpaqueMessage, Error> {
let nonce = aead::Nonce::assume_unique_for_key(Nonce::new(&self.iv, seq).0);
let aad = aead::Aad::from(make_tls12_aad(seq, msg.typ, msg.version, msg.payload.len()));
@ -347,7 +347,7 @@ impl MessageDecrypter for ChaCha20Poly1305MessageDecrypter {
&mut self,
mut msg: BorrowedOpaqueMessage<'a>,
seq: u64,
) -> Result<BorrowedPlainMessage<'a>, Error> {
) -> Result<InboundMessage<'a>, Error> {
let payload = &msg.payload;
if payload.len() < CHACHAPOLY1305_OVERHEAD {
@ -374,12 +374,12 @@ impl MessageDecrypter for ChaCha20Poly1305MessageDecrypter {
}
payload.truncate(plain_len);
Ok(msg.into_plain_message())
Ok(msg.into_inbound_message())
}
}
impl MessageEncrypter for ChaCha20Poly1305MessageEncrypter {
fn encrypt(&mut self, msg: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
fn encrypt(&mut self, msg: OutboundMessage, seq: u64) -> Result<OpaqueMessage, Error> {
let nonce = aead::Nonce::assume_unique_for_key(Nonce::new(&self.enc_offset, seq).0);
let aad = aead::Aad::from(make_tls12_aad(seq, msg.typ, msg.version, msg.payload.len()));

View File

@ -10,7 +10,7 @@ use crate::crypto::tls13::{Hkdf, HkdfExpander, OkmBlock, OutputLengthError};
use crate::enums::{CipherSuite, ContentType, ProtocolVersion};
use crate::error::Error;
use crate::msgs::codec::Codec;
use crate::msgs::message::{BorrowedPlainMessage, OpaqueMessage};
use crate::msgs::message::{InboundMessage, OpaqueMessage, OutboundMessage};
use crate::suites::{CipherSuiteCommon, ConnectionTrafficSecrets, SupportedCipherSuite};
use crate::tls13::Tls13CipherSuite;
@ -220,7 +220,7 @@ struct AeadMessageDecrypter {
}
impl MessageEncrypter for AeadMessageEncrypter {
fn encrypt(&mut self, msg: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
fn encrypt(&mut self, msg: OutboundMessage, seq: u64) -> Result<OpaqueMessage, Error> {
let total_len = self.encrypted_payload_len(msg.payload.len());
let mut payload = Vec::with_capacity(total_len);
payload.extend_from_slice(msg.payload);
@ -251,7 +251,7 @@ impl MessageDecrypter for AeadMessageDecrypter {
&mut self,
mut msg: BorrowedOpaqueMessage<'a>,
seq: u64,
) -> Result<BorrowedPlainMessage<'a>, Error> {
) -> Result<InboundMessage<'a>, Error> {
let payload = &mut msg.payload;
if payload.len() < self.dec_key.algorithm().tag_len() {
return Err(Error::DecryptError);
@ -276,7 +276,7 @@ struct GcmMessageEncrypter {
}
impl MessageEncrypter for GcmMessageEncrypter {
fn encrypt(&mut self, msg: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
fn encrypt(&mut self, msg: OutboundMessage, seq: u64) -> Result<OpaqueMessage, Error> {
let total_len = msg.payload.len() + 1 + self.enc_key.algorithm().tag_len();
let mut payload = Vec::with_capacity(total_len);
payload.extend_from_slice(msg.payload);
@ -310,7 +310,7 @@ impl MessageDecrypter for GcmMessageDecrypter {
&mut self,
mut msg: BorrowedOpaqueMessage<'a>,
seq: u64,
) -> Result<BorrowedPlainMessage<'a>, Error> {
) -> Result<InboundMessage<'a>, Error> {
let payload = &mut msg.payload;
if payload.len() < self.dec_key.algorithm().tag_len() {
return Err(Error::DecryptError);

View File

@ -8,7 +8,7 @@ use crate::error::Error;
pub use crate::msgs::base::BorrowedPayload;
use crate::msgs::codec;
pub use crate::msgs::message::{
BorrowedOpaqueMessage, BorrowedPlainMessage, OpaqueMessage, PlainMessage,
BorrowedOpaqueMessage, InboundMessage, OpaqueMessage, OutboundMessage, PlainMessage,
};
use crate::suites::ConnectionTrafficSecrets;
@ -141,14 +141,14 @@ pub trait MessageDecrypter: Send + Sync {
&mut self,
msg: BorrowedOpaqueMessage<'a>,
seq: u64,
) -> Result<BorrowedPlainMessage<'a>, Error>;
) -> Result<InboundMessage<'a>, Error>;
}
/// Objects with this trait can encrypt TLS messages.
pub trait MessageEncrypter: Send + Sync {
/// Encrypt the given TLS message `msg`, using the sequence number
/// `seq which can be used to derive a unique [`Nonce`].
fn encrypt(&mut self, msg: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error>;
fn encrypt(&mut self, msg: OutboundMessage, seq: u64) -> Result<OpaqueMessage, Error>;
/// Return the length of the ciphertext that results from encrypting plaintext of
/// length `payload_len`
@ -318,7 +318,7 @@ impl From<[u8; Self::MAX_LEN]> for AeadKey {
struct InvalidMessageEncrypter {}
impl MessageEncrypter for InvalidMessageEncrypter {
fn encrypt(&mut self, _m: BorrowedPlainMessage, _seq: u64) -> Result<OpaqueMessage, Error> {
fn encrypt(&mut self, _m: OutboundMessage, _seq: u64) -> Result<OpaqueMessage, Error> {
Err(Error::EncryptError)
}
@ -335,7 +335,7 @@ impl MessageDecrypter for InvalidMessageDecrypter {
&mut self,
_m: BorrowedOpaqueMessage<'a>,
_seq: u64,
) -> Result<BorrowedPlainMessage<'a>, Error> {
) -> Result<InboundMessage<'a>, Error> {
Err(Error::DecryptError)
}
}

View File

@ -7,7 +7,7 @@ use crate::crypto::KeyExchangeAlgorithm;
use crate::enums::{CipherSuite, SignatureScheme};
use crate::error::Error;
use crate::msgs::fragmenter::MAX_FRAGMENT_LEN;
use crate::msgs::message::{BorrowedPlainMessage, OpaqueMessage};
use crate::msgs::message::{InboundMessage, OpaqueMessage, OutboundMessage};
use crate::suites::{CipherSuiteCommon, ConnectionTrafficSecrets, SupportedCipherSuite};
use crate::tls12::Tls12CipherSuite;
@ -249,7 +249,7 @@ impl MessageDecrypter for GcmMessageDecrypter {
&mut self,
mut msg: BorrowedOpaqueMessage<'a>,
seq: u64,
) -> Result<BorrowedPlainMessage<'a>, Error> {
) -> Result<InboundMessage<'a>, Error> {
let payload = &msg.payload;
if payload.len() < GCM_OVERHEAD {
return Err(Error::DecryptError);
@ -281,12 +281,12 @@ impl MessageDecrypter for GcmMessageDecrypter {
}
payload.truncate(plain_len);
Ok(msg.into_plain_message())
Ok(msg.into_inbound_message())
}
}
impl MessageEncrypter for GcmMessageEncrypter {
fn encrypt(&mut self, msg: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
fn encrypt(&mut self, msg: OutboundMessage, seq: u64) -> Result<OpaqueMessage, Error> {
let nonce = aead::Nonce::assume_unique_for_key(Nonce::new(&self.iv, seq).0);
let aad = aead::Aad::from(make_tls12_aad(seq, msg.typ, msg.version, msg.payload.len()));
@ -331,7 +331,7 @@ impl MessageDecrypter for ChaCha20Poly1305MessageDecrypter {
&mut self,
mut msg: BorrowedOpaqueMessage<'a>,
seq: u64,
) -> Result<BorrowedPlainMessage<'a>, Error> {
) -> Result<InboundMessage<'a>, Error> {
let payload = &msg.payload;
if payload.len() < CHACHAPOLY1305_OVERHEAD {
@ -358,12 +358,12 @@ impl MessageDecrypter for ChaCha20Poly1305MessageDecrypter {
}
payload.truncate(plain_len);
Ok(msg.into_plain_message())
Ok(msg.into_inbound_message())
}
}
impl MessageEncrypter for ChaCha20Poly1305MessageEncrypter {
fn encrypt(&mut self, msg: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
fn encrypt(&mut self, msg: OutboundMessage, seq: u64) -> Result<OpaqueMessage, Error> {
let nonce = aead::Nonce::assume_unique_for_key(Nonce::new(&self.enc_offset, seq).0);
let aad = aead::Aad::from(make_tls12_aad(seq, msg.typ, msg.version, msg.payload.len()));

View File

@ -10,7 +10,7 @@ use crate::crypto::tls13::{Hkdf, HkdfExpander, OkmBlock, OutputLengthError};
use crate::enums::{CipherSuite, ContentType, ProtocolVersion};
use crate::error::Error;
use crate::msgs::codec::Codec;
use crate::msgs::message::{BorrowedPlainMessage, OpaqueMessage};
use crate::msgs::message::{InboundMessage, OpaqueMessage, OutboundMessage};
use crate::suites::{CipherSuiteCommon, ConnectionTrafficSecrets, SupportedCipherSuite};
use crate::tls13::Tls13CipherSuite;
@ -192,7 +192,7 @@ struct Tls13MessageDecrypter {
}
impl MessageEncrypter for Tls13MessageEncrypter {
fn encrypt(&mut self, msg: BorrowedPlainMessage, seq: u64) -> Result<OpaqueMessage, Error> {
fn encrypt(&mut self, msg: OutboundMessage, seq: u64) -> Result<OpaqueMessage, Error> {
let total_len = self.encrypted_payload_len(msg.payload.len());
let mut payload = Vec::with_capacity(total_len);
payload.extend_from_slice(msg.payload);
@ -223,7 +223,7 @@ impl MessageDecrypter for Tls13MessageDecrypter {
&mut self,
mut msg: BorrowedOpaqueMessage<'a>,
seq: u64,
) -> Result<BorrowedPlainMessage<'a>, Error> {
) -> Result<InboundMessage<'a>, Error> {
let payload = &mut msg.payload;
if payload.len() < self.dec_key.algorithm().tag_len() {
return Err(Error::DecryptError);

View File

@ -482,7 +482,9 @@ pub mod internal {
};
}
pub mod message {
pub use crate::msgs::message::{Message, MessagePayload, OpaqueMessage, PlainMessage};
pub use crate::msgs::message::{
BorrowedPlainMessage, Message, MessagePayload, OpaqueMessage, PlainMessage,
};
}
pub mod persist {
pub use crate::msgs::persist::ServerSessionValue;

View File

@ -4,11 +4,10 @@ use core::slice::SliceIndex;
use std::io;
use super::codec::Codec;
use super::message::{BorrowedOpaqueMessage, BorrowedPlainMessage};
use crate::enums::{ContentType, ProtocolVersion};
use crate::error::{Error, InvalidMessage, PeerMisbehaved};
use crate::msgs::codec;
use crate::msgs::message::{MessageError, OpaqueMessage};
use crate::msgs::message::{BorrowedOpaqueMessage, InboundMessage, MessageError, OpaqueMessage};
use crate::record_layer::{Decrypted, RecordLayer};
/// This deframer works to reconstruct TLS messages from a stream of arbitrary-sized reads.
@ -113,13 +112,13 @@ impl MessageDeframer {
version,
payload,
} = m;
let raw_payload = RawSlice::from(&*payload);
let raw_payload_slice = RawSlice::from(&*payload);
// This is unencrypted. We check the contents later.
buffer.queue_discard(end);
let message = BorrowedPlainMessage {
let message = InboundMessage {
typ,
version,
payload: buffer.take(raw_payload),
payload: buffer.take(raw_payload_slice),
};
return Ok(Some(Deframed {
want_close_before_decrypt: false,
@ -130,14 +129,19 @@ impl MessageDeframer {
}
// Decrypt the encrypted message (if necessary).
let msg = match record_layer.decrypt_incoming(m) {
let (typ, version, plain_payload_slice) = match record_layer.decrypt_incoming(m) {
Ok(Some(decrypted)) => {
let Decrypted {
want_close_before_decrypt,
plaintext,
plaintext:
InboundMessage {
typ,
version,
payload,
},
} = decrypted;
debug_assert!(!want_close_before_decrypt);
plaintext
(typ, version, RawSlice::from(payload))
}
// This was rejected early data, discard it. If we currently have a handshake
// payload in progress, this counts as interleaved, so we error out.
@ -153,7 +157,7 @@ impl MessageDeframer {
Err(e) => return Err(e),
};
if self.joining_hs.is_some() && msg.typ != ContentType::Handshake {
if self.joining_hs.is_some() && typ != ContentType::Handshake {
// "Handshake messages MUST NOT be interleaved with other record
// types. That is, if a handshake message is split over two or more
// records, there MUST NOT be any other records between them."
@ -162,19 +166,12 @@ impl MessageDeframer {
}
// If it's not a handshake message, just return it -- no joining necessary.
if msg.typ != ContentType::Handshake {
let BorrowedPlainMessage {
typ,
version,
payload,
} = msg;
let raw_payload = RawSlice::from(payload);
let end = start + rd.used();
if typ != ContentType::Handshake {
buffer.queue_discard(end);
let message = BorrowedPlainMessage {
let message = InboundMessage {
typ,
version,
payload: buffer.take(raw_payload),
payload: buffer.take(plain_payload_slice),
};
return Ok(Some(Deframed {
want_close_before_decrypt: false,
@ -186,9 +183,7 @@ impl MessageDeframer {
// If we don't know the payload size yet or if the payload size is larger
// than the currently buffered payload, we need to wait for more data.
let raw = RawSlice::from(msg.payload);
let version = msg.version;
let src = buffer.raw_slice_to_filled_range(raw);
let src = buffer.raw_slice_to_filled_range(plain_payload_slice);
match self.append_hs(version, InternalPayload(src), end, buffer)? {
HandshakePayloadState::Blocked => return Ok(None),
HandshakePayloadState::Complete(len) => break len,
@ -221,7 +216,7 @@ impl MessageDeframer {
buffer.queue_discard(end);
}
let message = BorrowedPlainMessage {
let message = InboundMessage {
typ,
version,
payload: buffer.take(raw_payload),
@ -686,7 +681,7 @@ pub struct Deframed<'a> {
pub(crate) want_close_before_decrypt: bool,
pub(crate) aligned: bool,
pub(crate) trial_decryption_finished: bool,
pub message: BorrowedPlainMessage<'a>,
pub message: InboundMessage<'a>,
}
const HEADER_SIZE: usize = 1 + 3;
@ -703,7 +698,7 @@ mod tests {
use std::io;
use crate::crypto::cipher::PlainMessage;
use crate::msgs::message::Message;
use crate::msgs::message::{BorrowedPlainMessage, Message};
use super::*;

View File

@ -1,6 +1,6 @@
use crate::enums::ContentType;
use crate::enums::ProtocolVersion;
use crate::msgs::message::{BorrowedPlainMessage, PlainMessage};
use crate::msgs::message::{OutboundMessage, PlainMessage};
use crate::Error;
pub(crate) const MAX_FRAGMENT_LEN: usize = 16384;
pub(crate) const PACKET_OVERHEAD: usize = 1 + 2 + 2;
@ -26,7 +26,7 @@ impl MessageFragmenter {
pub fn fragment_message<'a>(
&self,
msg: &'a PlainMessage,
) -> impl Iterator<Item = BorrowedPlainMessage<'a>> + 'a {
) -> impl Iterator<Item = OutboundMessage<'a>> + 'a {
self.fragment_slice(msg.typ, msg.version, msg.payload.bytes())
}
@ -37,13 +37,13 @@ impl MessageFragmenter {
typ: ContentType,
version: ProtocolVersion,
payload: &'a [u8],
) -> impl ExactSizeIterator<Item = BorrowedPlainMessage<'a>> {
) -> impl ExactSizeIterator<Item = OutboundMessage<'a>> {
payload
.chunks(self.max_frag)
.map(move |c| BorrowedPlainMessage {
.map(move |payload| OutboundMessage {
typ,
version,
payload: c,
payload,
})
}
@ -71,10 +71,10 @@ mod tests {
use crate::enums::ContentType;
use crate::enums::ProtocolVersion;
use crate::msgs::base::Payload;
use crate::msgs::message::{BorrowedPlainMessage, PlainMessage};
use crate::msgs::message::{BorrowedPlainMessage, OutboundMessage, PlainMessage};
fn msg_eq(
m: &BorrowedPlainMessage,
m: &OutboundMessage,
total_len: usize,
typ: &ContentType,
version: &ProtocolVersion,

View File

@ -3,18 +3,15 @@ use crate::enums::{AlertDescription, ContentType, HandshakeType};
use crate::error::{Error, InvalidMessage, PeerMisbehaved};
use crate::internal::record_layer::RecordLayer;
use crate::msgs::alert::AlertMessagePayload;
use crate::msgs::base::Payload;
use crate::msgs::base::{BorrowedPayload, Payload};
use crate::msgs::ccs::ChangeCipherSpecPayload;
use crate::msgs::codec::{Codec, Reader};
use crate::msgs::codec::{Codec, Reader, ReaderMut};
use crate::msgs::enums::AlertLevel;
use crate::msgs::fragmenter::MAX_FRAGMENT_LEN;
use crate::msgs::handshake::HandshakeMessagePayload;
use alloc::vec::Vec;
use super::base::BorrowedPayload;
use super::codec::ReaderMut;
#[derive(Debug)]
pub enum MessagePayload<'a> {
Alert(AlertMessagePayload),
@ -197,11 +194,11 @@ pub struct BorrowedOpaqueMessage<'a> {
}
impl<'a> BorrowedOpaqueMessage<'a> {
/// Force conversion into a plaintext message.
/// Force conversion into an inbound plaintext message.
///
/// See [`OpaqueMessage::into_plain_message`] for more information
pub fn into_plain_message(self) -> BorrowedPlainMessage<'a> {
BorrowedPlainMessage {
pub fn into_inbound_message(self) -> InboundMessage<'a> {
InboundMessage {
typ: self.typ,
version: self.version,
payload: self.payload.into_inner(),
@ -212,7 +209,7 @@ impl<'a> BorrowedOpaqueMessage<'a> {
///
/// Returns an error if the message (pre-unpadding) is too long, or the padding is invalid,
/// or the message (post-unpadding) is too long.
pub fn into_tls13_unpadded_message(mut self) -> Result<BorrowedPlainMessage<'a>, Error> {
pub fn into_tls13_unpadded_message(mut self) -> Result<InboundMessage<'a>, Error> {
let payload = &mut self.payload;
if payload.len() > MAX_FRAGMENT_LEN + 1 {
@ -229,7 +226,7 @@ impl<'a> BorrowedOpaqueMessage<'a> {
}
self.version = ProtocolVersion::TLSv1_3;
Ok(self.into_plain_message())
Ok(self.into_inbound_message())
}
pub(crate) fn read(r: &mut ReaderMut<'a>) -> Result<Self, MessageError> {
@ -338,8 +335,16 @@ impl PlainMessage {
}
}
pub fn borrow(&self) -> BorrowedPlainMessage<'_> {
BorrowedPlainMessage {
pub fn borrow_inbound(&self) -> InboundMessage<'_> {
InboundMessage {
version: self.version,
typ: self.typ,
payload: self.payload.bytes(),
}
}
pub fn borrow_outbound(&self) -> OutboundMessage<'_> {
OutboundMessage {
version: self.version,
typ: self.typ,
payload: self.payload.bytes(),
@ -406,10 +411,10 @@ impl TryFrom<PlainMessage> for Message<'static> {
///
/// A [`PlainMessage`] must contain plaintext content. Encrypted content should be stored in an
/// [`OpaqueMessage`] and decrypted before being stored into a [`PlainMessage`].
impl<'a> TryFrom<BorrowedPlainMessage<'a>> for Message<'a> {
impl<'a> TryFrom<InboundMessage<'a>> for Message<'a> {
type Error = Error;
fn try_from(plain: BorrowedPlainMessage<'a>) -> Result<Self, Self::Error> {
fn try_from(plain: InboundMessage<'a>) -> Result<Self, Self::Error> {
Ok(Self {
version: plain.version,
payload: MessagePayload::new(plain.typ, plain.version, plain.payload)?,
@ -419,44 +424,100 @@ impl<'a> TryFrom<BorrowedPlainMessage<'a>> for Message<'a> {
/// A TLS frame, named TLSPlaintext in the standard.
///
/// This type differs from `OpaqueMessage` because it borrows
/// its payload. You can make a `OpaqueMessage` from an
/// `BorrowMessage`, but this involves a copy.
/// This type borrows its decrypted payload from a `MessageDeframer`.
/// You can make a `OpaqueMessage` from an `InboundMessage`,
/// but this involves a copy.
///
/// This type also cannot decode its internals and
/// cannot be read/encoded; only `OpaqueMessage` can do that.
#[derive(Debug)]
pub struct BorrowedPlainMessage<'a> {
pub struct InboundMessage<'a> {
pub typ: ContentType,
pub version: ProtocolVersion,
pub payload: &'a [u8],
}
impl<'a> BorrowedPlainMessage<'a> {
pub fn to_unencrypted_opaque(&self) -> OpaqueMessage {
OpaqueMessage {
version: self.version,
typ: self.typ,
payload: Payload::Owned(self.payload.to_vec()),
}
impl BorrowedPlainMessage for InboundMessage<'_> {
fn payload_to_vec(&self) -> Vec<u8> {
self.payload.to_vec()
}
pub fn encoded_len(&self, record_layer: &RecordLayer) -> usize {
OpaqueMessage::HEADER_SIZE as usize + record_layer.encrypted_len(self.payload.len())
fn payload_len(&self) -> usize {
self.payload.len()
}
pub fn into_owned(self) -> PlainMessage {
let Self {
typ,
version,
payload,
} = self;
fn typ(&self) -> ContentType {
self.typ
}
fn version(&self) -> ProtocolVersion {
self.version
}
}
/// A TLS frame, named TLSPlaintext in the standard.
///
/// This type borrows its "to be encrypted" data from the client.
/// You can make a `OpaqueMessage` from an `OutboundMessage`,
/// but this involves a copy.
///
/// This type also cannot decode its internals and
/// cannot be read/encoded; only `OpaqueMessage` can do that.
#[derive(Debug)]
pub struct OutboundMessage<'a> {
pub typ: ContentType,
pub version: ProtocolVersion,
pub payload: &'a [u8],
}
impl BorrowedPlainMessage for OutboundMessage<'_> {
fn payload_to_vec(&self) -> Vec<u8> {
self.payload.to_vec()
}
fn payload_len(&self) -> usize {
self.payload.len()
}
fn typ(&self) -> ContentType {
self.typ
}
fn version(&self) -> ProtocolVersion {
self.version
}
}
/// Abstract both inbound and outbound variants of a plaintext message
pub trait BorrowedPlainMessage: Sized {
fn into_owned(self) -> PlainMessage {
PlainMessage {
typ,
version,
payload: Payload::new(payload),
version: self.version(),
typ: self.typ(),
payload: Payload::Owned(self.payload_to_vec()),
}
}
fn to_unencrypted_opaque(&self) -> OpaqueMessage {
OpaqueMessage {
version: self.version(),
typ: self.typ(),
payload: Payload::Owned(self.payload_to_vec()),
}
}
fn encoded_len(&self, record_layer: &RecordLayer) -> usize {
OpaqueMessage::HEADER_SIZE as usize + record_layer.encrypted_len(self.payload_len())
}
fn payload_to_vec(&self) -> Vec<u8>;
fn payload_len(&self) -> usize;
fn typ(&self) -> ContentType;
fn version(&self) -> ProtocolVersion;
}
#[derive(Debug)]

View File

@ -2,7 +2,7 @@ use core::num::NonZeroU64;
use crate::crypto::cipher::{BorrowedOpaqueMessage, MessageDecrypter, MessageEncrypter};
use crate::error::Error;
use crate::msgs::message::{BorrowedPlainMessage, OpaqueMessage};
use crate::msgs::message::{InboundMessage, OpaqueMessage, OutboundMessage};
#[cfg(feature = "logging")]
use crate::log::trace;
@ -67,7 +67,7 @@ impl RecordLayer {
if self.decrypt_state != DirectionState::Active {
return Ok(Some(Decrypted {
want_close_before_decrypt: false,
plaintext: encr.into_plain_message(),
plaintext: encr.into_inbound_message(),
}));
}
@ -108,7 +108,7 @@ impl RecordLayer {
///
/// `plain` is a TLS message we'd like to send. This function
/// panics if the requisite keying material hasn't been established yet.
pub(crate) fn encrypt_outgoing(&mut self, plain: BorrowedPlainMessage) -> OpaqueMessage {
pub(crate) fn encrypt_outgoing(&mut self, plain: OutboundMessage) -> OpaqueMessage {
debug_assert!(self.encrypt_state == DirectionState::Active);
assert!(!self.encrypt_exhausted());
let seq = self.write_seq;
@ -242,7 +242,7 @@ pub(crate) struct Decrypted<'a> {
/// Whether the peer appears to be getting close to encrypting too many messages with this key.
pub(crate) want_close_before_decrypt: bool,
/// The decrypted message.
pub(crate) plaintext: BorrowedPlainMessage<'a>,
pub(crate) plaintext: InboundMessage<'a>,
}
#[cfg(test)]
@ -259,8 +259,8 @@ mod tests {
&mut self,
m: BorrowedOpaqueMessage<'a>,
_: u64,
) -> Result<BorrowedPlainMessage<'a>, Error> {
Ok(m.into_plain_message())
) -> Result<InboundMessage<'a>, Error> {
Ok(m.into_inbound_message())
}
}

View File

@ -27,7 +27,9 @@ use rustls::internal::msgs::base::Payload;
use rustls::internal::msgs::codec::Codec;
use rustls::internal::msgs::enums::AlertLevel;
use rustls::internal::msgs::handshake::{ClientExtension, HandshakePayload};
use rustls::internal::msgs::message::{Message, MessagePayload, PlainMessage};
use rustls::internal::msgs::message::{
BorrowedPlainMessage, Message, MessagePayload, PlainMessage,
};
use rustls::server::{ClientHello, ParsedCertificate, ResolvesServerCert};
use rustls::SupportedCipherSuite;
use rustls::{
@ -745,11 +747,11 @@ fn test_tls13_valid_early_plaintext_alert() {
// * The negotiated protocol version is TLS 1.3.
server
.read_tls(&mut io::Cursor::new(
<Message as Into<PlainMessage>>::into(Message::build_alert(
PlainMessage::from(Message::build_alert(
AlertLevel::Fatal,
AlertDescription::UnknownCA,
))
.borrow()
.borrow_inbound()
.to_unencrypted_opaque()
.encode(),
))
@ -797,11 +799,11 @@ fn test_tls13_late_plaintext_alert() {
// Inject a plaintext alert from the client. The server should attempt to decrypt this message.
server
.read_tls(&mut io::Cursor::new(
<Message as Into<PlainMessage>>::into(Message::build_alert(
PlainMessage::from(Message::build_alert(
AlertLevel::Fatal,
AlertDescription::UnknownCA,
))
.borrow()
.borrow_inbound()
.to_unencrypted_opaque()
.encode(),
))