add a Borrowed variant to Payload

Co-authored-by: Jorge Aparicio <jorge.aparicio@ferrous-systems.com>
This commit is contained in:
Christian Poveda 2023-12-08 14:24:15 -05:00 committed by Daniel McCarney
parent d2b95ae772
commit d8abdb3e0a
16 changed files with 112 additions and 85 deletions

View File

@ -27,7 +27,7 @@ fuzz_target!(|data: &[u8]| {
Message::try_from(PlainMessage {
typ: msg.typ,
version: msg.version,
payload: Payload(msg.payload.to_vec()),
payload: Payload::Owned(msg.payload.to_vec()),
})
.ok();
}

View File

@ -1009,15 +1009,15 @@ impl State<ClientConnectionData> for ExpectFinished {
// Constant-time verification of this is relatively unimportant: they only
// get one chance. But it can't hurt.
let _fin_verified = match ConstantTimeEq::ct_eq(&expect_verify_data[..], &finished.0).into()
{
true => verify::FinishedMessageVerified::assertion(),
false => {
return Err(cx
.common
.send_fatal_alert(AlertDescription::DecryptError, Error::DecryptError));
}
};
let _fin_verified =
match ConstantTimeEq::ct_eq(&expect_verify_data[..], finished.bytes()).into() {
true => verify::FinishedMessageVerified::assertion(),
false => {
return Err(cx
.common
.send_fatal_alert(AlertDescription::DecryptError, Error::DecryptError));
}
};
// Hash this message too.
st.transcript.add_message(&m);

View File

@ -827,7 +827,8 @@ impl State<ClientConnectionData> for ExpectFinished {
.key_schedule
.sign_server_finish(&handshake_hash);
let fin = match ConstantTimeEq::ct_eq(expect_verify_data.as_ref(), &finished.0).into() {
let fin = match ConstantTimeEq::ct_eq(expect_verify_data.as_ref(), finished.bytes()).into()
{
true => verify::FinishedMessageVerified::assertion(),
false => {
return Err(cx

View File

@ -428,7 +428,8 @@ impl CommonState {
}
pub(crate) fn take_received_plaintext(&mut self, bytes: Payload) {
self.received_plaintext.append(bytes.0);
self.received_plaintext
.append(bytes.into_vec());
}
#[cfg(feature = "tls12")]

View File

@ -340,7 +340,7 @@ fn is_valid_ccs(msg: &PlainMessage) -> bool {
// We passthrough ChangeCipherSpec messages in the deframer without decrypting them.
// Note: this is prior to the record layer, so is unencrypted. See
// third paragraph of section 5 in RFC8446.
msg.typ == ContentType::ChangeCipherSpec && msg.payload.0 == [0x01]
msg.typ == ContentType::ChangeCipherSpec && msg.payload.bytes() == [0x01]
}
/// Interface shared by client and server connections.

View File

@ -36,7 +36,7 @@ impl HandshakeHashBuffer {
pub(crate) fn add_message(&mut self, m: &Message) {
if let MessagePayload::Handshake { encoded, .. } = &m.payload {
self.buffer
.extend_from_slice(&encoded.0);
.extend_from_slice(encoded.bytes());
}
}
@ -98,7 +98,7 @@ impl HandshakeHash {
/// Hash/buffer a handshake message.
pub(crate) fn add_message(&mut self, m: &Message) -> &mut Self {
if let MessagePayload::Handshake { encoded, .. } = &m.payload {
self.update_raw(&encoded.0);
self.update_raw(encoded.bytes());
}
self
}

View File

@ -13,29 +13,47 @@ use super::codec::ReaderMut;
/// An externally length'd payload
#[derive(Clone, Eq, PartialEq)]
pub struct Payload(pub Vec<u8>);
pub enum Payload<'a> {
Borrowed(&'a [u8]),
Owned(Vec<u8>),
}
impl Codec<'_> for Payload {
impl<'a> Codec<'a> for Payload<'static> {
fn encode(&self, bytes: &mut Vec<u8>) {
bytes.extend_from_slice(&self.0);
bytes.extend_from_slice(self.bytes());
}
fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
Ok(Self::read(r))
}
}
impl Payload {
impl<'a> Payload<'a> {
pub fn bytes(&self) -> &[u8] {
match self {
Self::Borrowed(bytes) => bytes,
Self::Owned(bytes) => bytes,
}
}
pub fn into_vec(self) -> Vec<u8> {
match self {
Self::Borrowed(bytes) => bytes.to_vec(),
Self::Owned(bytes) => bytes,
}
}
}
impl Payload<'static> {
pub fn new(bytes: impl Into<Vec<u8>>) -> Self {
Self(bytes.into())
Self::Owned(bytes.into())
}
pub fn empty() -> Self {
Self::new(Vec::new())
Self::Borrowed(&[])
}
pub fn read(r: &mut Reader) -> Self {
Self(r.rest().to_vec())
Self::Owned(r.rest().to_vec())
}
}
@ -106,9 +124,9 @@ impl<'a> Codec<'_> for CertificateDer<'a> {
}
}
impl fmt::Debug for Payload {
impl fmt::Debug for Payload<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
hex(f, &self.0)
hex(f, self.bytes())
}
}

View File

@ -865,7 +865,7 @@ mod tests {
let mut rl = RecordLayer::new();
let m = d.pop_message(&mut rl, None);
assert_eq!(m.typ, ContentType::ApplicationData);
assert_eq!(m.payload.0.len(), 0);
assert_eq!(m.payload.bytes().len(), 0);
assert!(!d.has_pending());
assert!(d.last_error.is_none());
}

View File

@ -27,7 +27,7 @@ impl MessageFragmenter {
&self,
msg: &'a PlainMessage,
) -> impl Iterator<Item = BorrowedPlainMessage<'a>> + 'a {
self.fragment_slice(msg.typ, msg.version, &msg.payload.0)
self.fragment_slice(msg.typ, msg.version, msg.payload.bytes())
}
/// Enqueue borrowed fragments of (version, typ, payload) which

View File

@ -185,7 +185,7 @@ impl SessionId {
#[derive(Clone, Debug, PartialEq)]
pub struct UnknownExtension {
pub(crate) typ: ExtensionType,
pub(crate) payload: Payload,
pub(crate) payload: Payload<'static>,
}
impl UnknownExtension {
@ -214,7 +214,7 @@ impl TlsListElement for SignatureScheme {
#[derive(Clone, Debug)]
pub(crate) enum ServerNamePayload {
HostName(DnsName<'static>),
Unknown(Payload),
Unknown(Payload<'static>),
}
impl ServerNamePayload {
@ -483,7 +483,7 @@ impl Codec<'_> for OcspCertificateStatusRequest {
#[derive(Clone, Debug)]
pub enum CertificateStatusRequest {
Ocsp(OcspCertificateStatusRequest),
Unknown((CertificateStatusType, Payload)),
Unknown((CertificateStatusType, Payload<'static>)),
}
impl Codec<'_> for CertificateStatusRequest {
@ -684,7 +684,7 @@ impl ClientExtension {
#[derive(Clone, Debug)]
pub enum ClientSessionTicket {
Request,
Offer(Payload),
Offer(Payload<'static>),
}
#[derive(Clone, Debug)]
@ -1588,7 +1588,7 @@ impl Codec<'_> for EcdheServerKeyExchange {
#[derive(Debug)]
pub enum ServerKeyExchangePayload {
Ecdhe(EcdheServerKeyExchange),
Unknown(Payload),
Unknown(Payload<'static>),
}
impl Codec<'_> for ServerKeyExchangePayload {
@ -1613,7 +1613,7 @@ impl ServerKeyExchangePayload {
kxa: KeyExchangeAlgorithm,
) -> Option<EcdheServerKeyExchange> {
if let Self::Unknown(ref unk) = *self {
let mut rd = Reader::init(&unk.0);
let mut rd = Reader::init(unk.bytes());
let result = match kxa {
KeyExchangeAlgorithm::ECDHE => EcdheServerKeyExchange::read(&mut rd),
@ -2076,15 +2076,15 @@ pub enum HandshakePayload {
CertificateVerify(DigitallySignedStruct),
ServerHelloDone,
EndOfEarlyData,
ClientKeyExchange(Payload),
ClientKeyExchange(Payload<'static>),
NewSessionTicket(NewSessionTicketPayload),
NewSessionTicketTls13(NewSessionTicketPayloadTls13),
EncryptedExtensions(Vec<ServerExtension>),
KeyUpdate(KeyUpdateRequest),
Finished(Payload),
Finished(Payload<'static>),
CertificateStatus(CertificateStatus),
MessageHash(Payload),
Unknown(Payload),
MessageHash(Payload<'static>),
Unknown(Payload<'static>),
}
impl HandshakePayload {

View File

@ -370,7 +370,7 @@ fn get_sample_clienthellopayload() -> ClientHelloPayload {
ClientExtension::SignatureAlgorithms(vec![SignatureScheme::ECDSA_NISTP256_SHA256]),
ClientExtension::make_sni(&DnsName::try_from("hello").unwrap()),
ClientExtension::SessionTicket(ClientSessionTicket::Request),
ClientExtension::SessionTicket(ClientSessionTicket::Offer(Payload(vec![]))),
ClientExtension::SessionTicket(ClientSessionTicket::Offer(Payload::Borrowed(&[]))),
ClientExtension::Protocols(vec![ProtocolName::from(vec![0])]),
ClientExtension::SupportedVersions(vec![ProtocolVersion::TLSv1_3]),
ClientExtension::KeyShare(vec![KeyShareEntry::new(NamedGroup::X25519, &[1, 2, 3])]),
@ -391,7 +391,7 @@ fn get_sample_clienthellopayload() -> ClientHelloPayload {
ClientExtension::TransportParameters(vec![1, 2, 3]),
ClientExtension::Unknown(UnknownExtension {
typ: ExtensionType::Unknown(12345),
payload: Payload(vec![1, 2, 3]),
payload: Payload::Borrowed(&[1, 2, 3]),
}),
],
}
@ -497,7 +497,7 @@ fn test_client_extension_getter(typ: ExtensionType, getter: fn(&ClientHelloPaylo
chp.extensions = vec![ClientExtension::Unknown(UnknownExtension {
typ,
payload: Payload(vec![]),
payload: Payload::Borrowed(&[]),
})];
assert!(!getter(&chp));
}
@ -612,7 +612,7 @@ fn test_helloretry_extension_getter(typ: ExtensionType, getter: fn(&HelloRetryRe
hrr.extensions = vec![HelloRetryExtension::Unknown(UnknownExtension {
typ,
payload: Payload(vec![]),
payload: Payload::Borrowed(&[]),
})];
assert!(!getter(&hrr));
}
@ -681,7 +681,7 @@ fn test_server_extension_getter(typ: ExtensionType, getter: fn(&ServerHelloPaylo
shp.extensions = vec![ServerExtension::Unknown(UnknownExtension {
typ,
payload: Payload(vec![]),
payload: Payload::Borrowed(&[]),
})];
assert!(!getter(&shp));
}
@ -724,7 +724,7 @@ fn test_cert_extension_getter(typ: ExtensionType, getter: fn(&CertificateEntry)
ce.exts = vec![CertificateExtension::Unknown(UnknownExtension {
typ,
payload: Payload(vec![]),
payload: Payload::Borrowed(&[]),
})];
assert!(!getter(&ce));
}
@ -757,7 +757,7 @@ fn get_sample_serverhellopayload() -> ServerHelloPayload {
ServerExtension::TransportParameters(vec![1, 2, 3]),
ServerExtension::Unknown(UnknownExtension {
typ: ExtensionType::Unknown(12345),
payload: Payload(vec![1, 2, 3]),
payload: Payload::Borrowed(&[1, 2, 3]),
}),
],
}
@ -784,7 +784,7 @@ fn get_sample_helloretryrequest() -> HelloRetryRequest {
HelloRetryExtension::SupportedVersions(ProtocolVersion::TLSv1_2),
HelloRetryExtension::Unknown(UnknownExtension {
typ: ExtensionType::Unknown(12345),
payload: Payload(vec![1, 2, 3]),
payload: Payload::Borrowed(&[1, 2, 3]),
}),
],
}
@ -801,7 +801,7 @@ fn get_sample_certificatepayloadtls13() -> CertificatePayloadTls13 {
}),
CertificateExtension::Unknown(UnknownExtension {
typ: ExtensionType::Unknown(12345),
payload: Payload(vec![1, 2, 3]),
payload: Payload::Borrowed(&[1, 2, 3]),
}),
],
}],
@ -822,7 +822,7 @@ fn get_sample_serverkeyexchangepayload_ecdhe() -> ServerKeyExchangePayload {
}
fn get_sample_serverkeyexchangepayload_unknown() -> ServerKeyExchangePayload {
ServerKeyExchangePayload::Unknown(Payload(vec![1, 2, 3]))
ServerKeyExchangePayload::Unknown(Payload::Borrowed(&[1, 2, 3]))
}
fn get_sample_certificaterequestpayload() -> CertificateRequestPayload {
@ -841,7 +841,7 @@ fn get_sample_certificaterequestpayloadtls13() -> CertificateRequestPayloadTls13
CertReqExtension::AuthorityNames(vec![DistinguishedName::from(vec![1, 2, 3])]),
CertReqExtension::Unknown(UnknownExtension {
typ: ExtensionType::Unknown(12345),
payload: Payload(vec![1, 2, 3]),
payload: Payload::Borrowed(&[1, 2, 3]),
}),
],
}
@ -862,7 +862,7 @@ fn get_sample_newsessionticketpayloadtls13() -> NewSessionTicketPayloadTls13 {
ticket: PayloadU16(vec![4, 5, 6]),
exts: vec![NewSessionTicketExtension::Unknown(UnknownExtension {
typ: ExtensionType::Unknown(12345),
payload: Payload(vec![1, 2, 3]),
payload: Payload::Borrowed(&[1, 2, 3]),
})],
}
}
@ -923,7 +923,7 @@ fn get_all_tls12_handshake_payloads() -> Vec<HandshakeMessagePayload> {
},
HandshakeMessagePayload {
typ: HandshakeType::ClientKeyExchange,
payload: HandshakePayload::ClientKeyExchange(Payload(vec![1, 2, 3])),
payload: HandshakePayload::ClientKeyExchange(Payload::Borrowed(&[1, 2, 3])),
},
HandshakeMessagePayload {
typ: HandshakeType::NewSessionTicket,
@ -943,7 +943,7 @@ fn get_all_tls12_handshake_payloads() -> Vec<HandshakeMessagePayload> {
},
HandshakeMessagePayload {
typ: HandshakeType::Finished,
payload: HandshakePayload::Finished(Payload(vec![1, 2, 3])),
payload: HandshakePayload::Finished(Payload::Borrowed(&[1, 2, 3])),
},
HandshakeMessagePayload {
typ: HandshakeType::CertificateStatus,
@ -951,7 +951,7 @@ fn get_all_tls12_handshake_payloads() -> Vec<HandshakeMessagePayload> {
},
HandshakeMessagePayload {
typ: HandshakeType::Unknown(99),
payload: HandshakePayload::Unknown(Payload(vec![1, 2, 3])),
payload: HandshakePayload::Unknown(Payload::Borrowed(&[1, 2, 3])),
},
]
}
@ -1060,7 +1060,7 @@ fn get_all_tls13_handshake_payloads() -> Vec<HandshakeMessagePayload> {
},
HandshakeMessagePayload {
typ: HandshakeType::ClientKeyExchange,
payload: HandshakePayload::ClientKeyExchange(Payload(vec![1, 2, 3])),
payload: HandshakePayload::ClientKeyExchange(Payload::Borrowed(&[1, 2, 3])),
},
HandshakeMessagePayload {
typ: HandshakeType::NewSessionTicket,
@ -1082,7 +1082,7 @@ fn get_all_tls13_handshake_payloads() -> Vec<HandshakeMessagePayload> {
},
HandshakeMessagePayload {
typ: HandshakeType::Finished,
payload: HandshakePayload::Finished(Payload(vec![1, 2, 3])),
payload: HandshakePayload::Finished(Payload::Borrowed(&[1, 2, 3])),
},
HandshakeMessagePayload {
typ: HandshakeType::CertificateStatus,
@ -1090,7 +1090,7 @@ fn get_all_tls13_handshake_payloads() -> Vec<HandshakeMessagePayload> {
},
HandshakeMessagePayload {
typ: HandshakeType::Unknown(99),
payload: HandshakePayload::Unknown(Payload(vec![1, 2, 3])),
payload: HandshakePayload::Unknown(Payload::Borrowed(&[1, 2, 3])),
},
]
}

View File

@ -20,17 +20,17 @@ pub enum MessagePayload {
Alert(AlertMessagePayload),
Handshake {
parsed: HandshakeMessagePayload,
encoded: Payload,
encoded: Payload<'static>,
},
ChangeCipherSpec(ChangeCipherSpecPayload),
ApplicationData(Payload),
ApplicationData(Payload<'static>),
}
impl MessagePayload {
pub fn encode(&self, bytes: &mut Vec<u8>) {
match self {
Self::Alert(x) => x.encode(bytes),
Self::Handshake { encoded, .. } => bytes.extend(&encoded.0),
Self::Handshake { encoded, .. } => bytes.extend(encoded.bytes()),
Self::ChangeCipherSpec(x) => x.encode(bytes),
Self::ApplicationData(x) => x.encode(bytes),
}
@ -46,9 +46,9 @@ impl MessagePayload {
pub fn new(
typ: ContentType,
vers: ProtocolVersion,
payload: Payload,
payload: Payload<'static>,
) -> Result<Self, InvalidMessage> {
let mut r = Reader::init(&payload.0);
let mut r = Reader::init(payload.bytes());
match typ {
ContentType::ApplicationData => Ok(Self::ApplicationData(payload)),
ContentType::Alert => AlertMessagePayload::read(&mut r).map(MessagePayload::Alert),
@ -90,7 +90,7 @@ impl MessagePayload {
pub struct OpaqueMessage {
pub typ: ContentType,
pub version: ProtocolVersion,
payload: Payload,
payload: Payload<'static>,
}
impl OpaqueMessage {
@ -107,12 +107,15 @@ impl OpaqueMessage {
/// Access the message payload as a slice.
pub fn payload(&self) -> &[u8] {
&self.payload.0
self.payload.bytes()
}
/// Access the message payload as a mutable `Vec<u8>`.
pub fn payload_mut(&mut self) -> &mut Vec<u8> {
&mut self.payload.0
match &mut self.payload {
Payload::Borrowed(_) => unreachable!("due to how constructor works"),
Payload::Owned(bytes) => bytes,
}
}
/// `MessageError` allows callers to distinguish between valid prefixes (might
@ -136,7 +139,7 @@ impl OpaqueMessage {
let mut buf = Vec::new();
self.typ.encode(&mut buf);
self.version.encode(&mut buf);
(self.payload.0.len() as u16).encode(&mut buf);
(self.payload.bytes().len() as u16).encode(&mut buf);
self.payload.encode(&mut buf);
buf
}
@ -291,7 +294,7 @@ impl From<Message> for PlainMessage {
_ => {
let mut buf = Vec::new();
msg.payload.encode(&mut buf);
Payload(buf)
Payload::Owned(buf)
}
};
@ -311,7 +314,7 @@ impl From<Message> for PlainMessage {
pub struct PlainMessage {
pub typ: ContentType,
pub version: ProtocolVersion,
pub payload: Payload,
pub payload: Payload<'static>,
}
impl PlainMessage {
@ -327,7 +330,7 @@ impl PlainMessage {
BorrowedPlainMessage {
version: self.version,
typ: self.typ,
payload: &self.payload.0,
payload: self.payload.bytes(),
}
}
}
@ -402,7 +405,7 @@ impl<'a> BorrowedPlainMessage<'a> {
OpaqueMessage {
version: self.version,
typ: self.typ,
payload: Payload(self.payload.to_vec()),
payload: Payload::Owned(self.payload.to_vec()),
}
}

View File

@ -100,7 +100,7 @@ fn construct_all_types() {
#[test]
fn debug_payload() {
assert_eq!("01020304", format!("{:?}", Payload(vec![1, 2, 3, 4])));
assert_eq!("01020304", format!("{:?}", Payload::new(vec![1, 2, 3, 4])));
assert_eq!("01020304", format!("{:?}", PayloadU8(vec![1, 2, 3, 4])));
assert_eq!("01020304", format!("{:?}", PayloadU16(vec![1, 2, 3, 4])));
assert_eq!("01020304", format!("{:?}", PayloadU24(vec![1, 2, 3, 4])));

View File

@ -846,10 +846,10 @@ impl EarlyDataState {
}
pub(super) fn take_received_plaintext(&mut self, bytes: Payload) -> bool {
let available = bytes.0.len();
let available = bytes.bytes().len();
match self {
Self::Accepted(ref mut received) if received.apply_limit(available) == available => {
received.append(bytes.0);
received.append(bytes.into_vec());
true
}
_ => false,

View File

@ -139,7 +139,10 @@ mod client_hello {
.and_then(|ticket| {
ticket_received = true;
debug!("Ticket received");
let data = self.config.ticketer.decrypt(&ticket.0);
let data = self
.config
.ticketer
.decrypt(ticket.bytes());
if data.is_none() {
debug!("Ticket didn't decrypt");
}
@ -602,7 +605,7 @@ impl State<ServerConnectionData> for ExpectClientKx {
// Complete key agreement, and set up encryption with the
// resulting premaster secret.
let peer_kx_params =
tls12::decode_ecdh_params::<ClientEcdhParams>(cx.common, &client_kx.0)?;
tls12::decode_ecdh_params::<ClientEcdhParams>(cx.common, client_kx.bytes())?;
let secrets = ConnectionSecrets::from_key_exchange(
self.server_kx,
&peer_kx_params.public.0,
@ -859,15 +862,15 @@ impl State<ServerConnectionData> for ExpectFinished {
let vh = self.transcript.get_current_hash();
let expect_verify_data = self.secrets.client_verify_data(&vh);
let _fin_verified = match ConstantTimeEq::ct_eq(&expect_verify_data[..], &finished.0).into()
{
true => verify::FinishedMessageVerified::assertion(),
false => {
return Err(cx
.common
.send_fatal_alert(AlertDescription::DecryptError, Error::DecryptError));
}
};
let _fin_verified =
match ConstantTimeEq::ct_eq(&expect_verify_data[..], finished.bytes()).into() {
true => verify::FinishedMessageVerified::assertion(),
false => {
return Err(cx
.common
.send_fatal_alert(AlertDescription::DecryptError, Error::DecryptError));
}
};
// Save connection, perhaps
if !self.resuming && !self.session_id.is_empty() {

View File

@ -856,8 +856,8 @@ impl State<ServerConnectionData> for ExpectAndSkipRejectedEarlyData {
* up to the configured max_early_data_size."
* (RFC8446, 14.2.10) */
if let MessagePayload::ApplicationData(ref skip_data) = m.payload {
if skip_data.0.len() <= self.skip_data_left {
self.skip_data_left -= skip_data.0.len();
if skip_data.bytes().len() <= self.skip_data_left {
self.skip_data_left -= skip_data.bytes().len();
return Ok(self);
}
}
@ -1158,7 +1158,8 @@ impl State<ServerConnectionData> for ExpectFinished {
.key_schedule
.sign_client_finish(&handshake_hash, cx.common);
let fin = match ConstantTimeEq::ct_eq(expect_verify_data.as_ref(), &finished.0[..]).into() {
let fin = match ConstantTimeEq::ct_eq(expect_verify_data.as_ref(), finished.bytes()).into()
{
true => verify::FinishedMessageVerified::assertion(),
false => {
return Err(cx