diff --git a/rustls/src/msgs/handshake.rs b/rustls/src/msgs/handshake.rs index b99f285f..9ddea086 100644 --- a/rustls/src/msgs/handshake.rs +++ b/rustls/src/msgs/handshake.rs @@ -301,7 +301,7 @@ impl Codec for ServerName { let payload = match typ { ServerNameType::HostName => ServerNamePayload::read_hostname(r)?, - _ => ServerNamePayload::Unknown(Payload::read(r)?), + _ => ServerNamePayload::Unknown(Payload::read(r).unwrap()), }; Some(ServerName { @@ -654,7 +654,8 @@ impl Codec for ClientExtension { } ExtensionType::SessionTicket => { if sub.any_left() { - ClientExtension::SessionTicketOffer(Payload::read(&mut sub)?) + let contents = Payload::read(&mut sub).unwrap(); + ClientExtension::SessionTicketOffer(contents) } else { ClientExtension::SessionTicketRequest } @@ -2224,7 +2225,7 @@ impl HandshakeMessagePayload { HandshakePayload::ServerHelloDone } HandshakeType::ClientKeyExchange => { - HandshakePayload::ClientKeyExchange(Payload::read(&mut sub)?) + HandshakePayload::ClientKeyExchange(Payload::read(&mut sub).unwrap()) } HandshakeType::CertificateRequest if vers == ProtocolVersion::TLSv1_3 => { let p = CertificateRequestPayloadTLS13::read(&mut sub)?; @@ -2252,7 +2253,7 @@ impl HandshakeMessagePayload { HandshakePayload::KeyUpdate(KeyUpdateRequest::read(&mut sub)?) } HandshakeType::Finished => { - HandshakePayload::Finished(Payload::read(&mut sub)?) + HandshakePayload::Finished(Payload::read(&mut sub).unwrap()) } HandshakeType::CertificateStatus => { HandshakePayload::CertificateStatus(CertificateStatus::read(&mut sub)?) @@ -2265,7 +2266,7 @@ impl HandshakeMessagePayload { // not legal on wire return None; } - _ => HandshakePayload::Unknown(Payload::read(&mut sub)?), + _ => HandshakePayload::Unknown(Payload::read(&mut sub).unwrap()), }; if sub.any_left() { diff --git a/rustls/src/msgs/handshake_test.rs b/rustls/src/msgs/handshake_test.rs index 4c246b86..87bd3245 100644 --- a/rustls/src/msgs/handshake_test.rs +++ b/rustls/src/msgs/handshake_test.rs @@ -370,10 +370,16 @@ fn get_sample_clienthellopayload() -> ClientHelloPayload { ClientExtension::SupportedVersions(vec![ ProtocolVersion::TLSv1_3 ]), ClientExtension::KeyShare(vec![ KeyShareEntry::new(NamedGroup::X25519, &[1, 2, 3]) ]), ClientExtension::PresharedKeyModes(vec![ PSKKeyExchangeMode::PSK_DHE_KE ]), - ClientExtension::PresharedKey( - PresharedKeyOffer::new(PresharedKeyIdentity::new(vec![3, 4, 5], 123456), - vec![1, 2, 3]) - ), + ClientExtension::PresharedKey(PresharedKeyOffer { + identities: vec![ + PresharedKeyIdentity::new(vec![3, 4, 5], 123456), + PresharedKeyIdentity::new(vec![6, 7, 8], 7891011), + ], + binders: vec![ + PresharedKeyBinder::new(vec![1, 2, 3]), + PresharedKeyBinder::new(vec![3, 4, 5]), + ] + }), ClientExtension::Cookie(PayloadU16(vec![1, 2, 3])), ClientExtension::ExtendedMasterSecretRequest, ClientExtension::CertificateStatusRequest(CertificateStatusRequest::build_ocsp()), @@ -409,6 +415,44 @@ fn client_has_duplicate_extensions_works() { assert!(!chp.has_duplicate_extension()); } +#[test] +fn test_truncated_psk_offer() { + let ext = ClientExtension::PresharedKey(PresharedKeyOffer { + identities: vec![ + PresharedKeyIdentity::new(vec![3, 4, 5], 123456), + ], + binders: vec![ + PresharedKeyBinder::new(vec![1, 2, 3]), + ] + }); + + let mut enc = ext.get_encoding(); + println!("testing {:?} enc {:?}", ext, enc); + for l in 0..enc.len() { + if l == 9 { + continue; + } + put_u16(l as u16, &mut enc[4..]); + let rc = ClientExtension::read_bytes(&enc); + assert!(rc.is_none()); + } +} + +#[test] +fn test_truncated_client_hello_is_detected() { + let ch = get_sample_clienthellopayload(); + let enc = ch.get_encoding(); + println!("testing {:?} enc {:?}", ch, enc); + + for l in 0..enc.len() { + println!("len {:?} enc {:?}", l, &enc[..l]); + if l == 41 { + continue; // where extensions are empty + } + assert!(ClientHelloPayload::read_bytes(&enc[..l]).is_none()); + } +} + #[test] fn test_truncated_client_extension_is_detected() { let chp = get_sample_clienthellopayload(); @@ -851,9 +895,8 @@ fn get_sample_certificatestatus() -> CertificateStatus { } } -#[test] -fn can_roundtrip_all_tls12_handshake_payloads() { - let hms = [ +fn get_all_tls12_handshake_payloads() -> Vec { + vec![ HandshakeMessagePayload { typ: HandshakeType::HelloRequest, payload: HandshakePayload::HelloRequest, @@ -922,13 +965,16 @@ fn can_roundtrip_all_tls12_handshake_payloads() { typ: HandshakeType::Unknown(99), payload: HandshakePayload::Unknown(Payload(vec![ 1, 2, 3 ])), }, - ]; + ] +} - for ref hm in hms.iter() { +#[test] +fn can_roundtrip_all_tls12_handshake_payloads() { + for ref hm in get_all_tls12_handshake_payloads().iter() { println!("{:?}", hm.typ); let bytes = hm.get_encoding(); let mut rd = Reader::init(&bytes); - let other = HandshakeMessagePayload::read_version(&mut rd, ProtocolVersion::TLSv1_2) + let other = HandshakeMessagePayload::read(&mut rd) .unwrap(); assert_eq!(rd.any_left(), false); assert_eq!(hm.get_encoding(), other.get_encoding()); @@ -939,8 +985,39 @@ fn can_roundtrip_all_tls12_handshake_payloads() { } #[test] -fn can_roundtrip_all_tls13_handshake_payloads() { - let hms = [ +fn can_detect_truncation_of_all_tls12_handshake_payloads() { + for hm in get_all_tls12_handshake_payloads().iter() { + let mut enc = hm.get_encoding(); + println!("test {:?} enc {:?}", hm, enc); + + // outer truncation + for l in 0..enc.len() { + assert!(HandshakeMessagePayload::read_bytes(&enc[..l]).is_none()) + } + + // inner truncation + for l in 0..enc.len()-4 { + put_u24(l as u32, &mut enc[1..]); + println!(" check len {:?} enc {:?}", l, enc); + + match (hm.typ, l) { + (HandshakeType::ClientHello, 41) | + (HandshakeType::ServerHello, 38) | + (HandshakeType::ServerKeyExchange, _) | + (HandshakeType::ClientKeyExchange, _) | + (HandshakeType::Finished, _) | + (HandshakeType::Unknown(_), _) => continue, + _ => {} + }; + + assert!(HandshakeMessagePayload::read_version(&mut Reader::init(&enc), ProtocolVersion::TLSv1_2).is_none()); + assert!(HandshakeMessagePayload::read_bytes(&enc).is_none()); + } + } +} + +fn get_all_tls13_handshake_payloads() -> Vec { + vec![ HandshakeMessagePayload { typ: HandshakeType::HelloRequest, payload: HandshakePayload::HelloRequest, @@ -973,6 +1050,11 @@ fn can_roundtrip_all_tls13_handshake_payloads() { typ: HandshakeType::CertificateRequest, payload: HandshakePayload::CertificateRequestTLS13(get_sample_certificaterequestpayloadtls13()), }, + HandshakeMessagePayload { + typ: HandshakeType::CertificateVerify, + payload: HandshakePayload::CertificateVerify(DigitallySignedStruct::new(SignatureScheme::ECDSA_NISTP256_SHA256, + vec![ 1, 2, 3 ])) + }, HandshakeMessagePayload { typ: HandshakeType::ServerHelloDone, payload: HandshakePayload::ServerHelloDone, @@ -1009,9 +1091,12 @@ fn can_roundtrip_all_tls13_handshake_payloads() { typ: HandshakeType::Unknown(99), payload: HandshakePayload::Unknown(Payload(vec![ 1, 2, 3 ])), }, - ]; + ] +} - for ref hm in hms.iter() { +#[test] +fn can_roundtrip_all_tls13_handshake_payloads() { + for ref hm in get_all_tls13_handshake_payloads().iter() { println!("{:?}", hm.typ); let bytes = hm.get_encoding(); let mut rd = Reader::init(&bytes); @@ -1025,3 +1110,51 @@ fn can_roundtrip_all_tls13_handshake_payloads() { println!("{:?}", other); } } + +fn put_u24(u: u32, b: &mut [u8]) { + b[0] = (u >> 16) as u8; + b[1] = (u >> 8) as u8; + b[2] = u as u8; +} + +#[test] +fn can_detect_truncation_of_all_tls13_handshake_payloads() { + for hm in get_all_tls13_handshake_payloads().iter() { + let mut enc = hm.get_encoding(); + println!("test {:?} enc {:?}", hm, enc); + + // outer truncation + for l in 0..enc.len() { + assert!(HandshakeMessagePayload::read_bytes(&enc[..l]).is_none()) + } + + // inner truncation + for l in 0..enc.len()-4 { + put_u24(l as u32, &mut enc[1..]); + println!(" check len {:?} enc {:?}", l, enc); + + match (hm.typ, l) { + (HandshakeType::ClientHello, 41) | + (HandshakeType::ServerHello, 38) | + (HandshakeType::ServerKeyExchange, _) | + (HandshakeType::ClientKeyExchange, _) | + (HandshakeType::Finished, _) | + (HandshakeType::Unknown(_), _) => continue, + _ => {} + }; + + assert!(HandshakeMessagePayload::read_version(&mut Reader::init(&enc), ProtocolVersion::TLSv1_3).is_none()); + } + } +} + +#[test] +fn cannot_read_messagehash_from_network() { + let mh = HandshakeMessagePayload { + typ: HandshakeType::MessageHash, + payload: HandshakePayload::MessageHash(Payload::new(vec![ 1, 2, 3 ])), + }; + println!("mh {:?}", mh); + let enc = mh.get_encoding(); + assert!(HandshakeMessagePayload::read_bytes(&enc).is_none()); +}