Minor refactors in preparation for FFDHE work

This commit is contained in:
Arash Sahebolamri 2024-02-05 11:32:15 -08:00 committed by Joe Birr-Pixton
parent d89d84f655
commit cb91090a60
5 changed files with 62 additions and 61 deletions

View File

@ -444,7 +444,7 @@ impl State<ClientConnectionData> for ExpectServerKx<'_> {
)?;
self.transcript.add_message(&m);
let ecdhe = opaque_kx
let kx = opaque_kx
.unwrap_given_kxa(self.suite.kx)
.ok_or_else(|| {
cx.common.send_fatal_alert(
@ -455,12 +455,12 @@ impl State<ClientConnectionData> for ExpectServerKx<'_> {
// Save the signature and signed parameters for later verification.
let mut kx_params = Vec::new();
ecdhe.params.encode(&mut kx_params);
let server_kx = ServerKxDetails::new(kx_params, ecdhe.dss);
kx.params.encode(&mut kx_params);
let server_kx = ServerKxDetails::new(kx_params, kx.dss);
#[cfg_attr(not(feature = "logging"), allow(unused_variables))]
{
debug!("ECDHE curve is {:?}", ecdhe.params.curve_params);
debug!("ECDHE curve is {:?}", kx.params.curve_params);
}
Ok(Box::new(ExpectServerDoneOrCertReq {
@ -894,7 +894,7 @@ impl State<ClientConnectionData> for ExpectServerDone<'_> {
// 5a.
let ecdh_params =
tls12::decode_ecdh_params::<ServerEcdhParams>(cx.common, &st.server_kx.kx_params)?;
tls12::decode_kx_params::<ServerEcdhParams>(cx.common, &st.server_kx.kx_params)?;
let named_group = ecdh_params.curve_params.named_group;
let skxg = match st.config.find_kx_group(named_group) {
Some(skxg) => skxg,

View File

@ -428,12 +428,12 @@ mod client_hello {
let kx = selected_group
.start()
.map_err(|_| Error::FailedToGetRandomBytes)?;
let secdh = ServerEcdhParams::new(&*kx);
let kx_params = ServerEcdhParams::new(&*kx);
let mut msg = Vec::new();
msg.extend(randoms.client);
msg.extend(randoms.server);
secdh.encode(&mut msg);
kx_params.encode(&mut msg);
let signer = signing_key
.choose_scheme(&sigschemes)
@ -442,7 +442,7 @@ mod client_hello {
let sig = signer.sign(&msg)?;
let skx = ServerKeyExchangePayload::Ecdhe(EcdheServerKeyExchange {
params: secdh,
params: kx_params,
dss: DigitallySignedStruct::new(sigscheme, sig),
});
@ -628,7 +628,7 @@ impl State<ServerConnectionData> for ExpectClientKx<'_> {
// Complete key agreement, and set up encryption with the
// resulting premaster secret.
let peer_kx_params =
tls12::decode_ecdh_params::<ClientEcdhParams>(cx.common, client_kx.bytes())?;
tls12::decode_kx_params::<ClientEcdhParams>(cx.common, client_kx.bytes())?;
let secrets = ConnectionSecrets::from_key_exchange(
self.server_kx,
&peer_kx_params.public.0,

View File

@ -320,7 +320,7 @@ fn join_randoms(first: &[u8; 32], second: &[u8; 32]) -> [u8; 64] {
type MessageCipherPair = (Box<dyn MessageDecrypter>, Box<dyn MessageEncrypter>);
pub(crate) fn decode_ecdh_params<'a, T: Codec<'a>>(
pub(crate) fn decode_kx_params<'a, T: Codec<'a>>(
common: &mut CommonState,
kx_params: &'a [u8],
) -> Result<T, Error> {
@ -353,12 +353,12 @@ mod tests {
server_buf.push(34);
let mut common = CommonState::new(Side::Client);
assert!(decode_ecdh_params::<ServerEcdhParams>(&mut common, &server_buf).is_err());
assert!(decode_kx_params::<ServerEcdhParams>(&mut common, &server_buf).is_err());
}
#[test]
fn client_ecdhe_invalid() {
let mut common = CommonState::new(Side::Server);
assert!(decode_ecdh_params::<ClientEcdhParams>(&mut common, &[34]).is_err());
assert!(decode_kx_params::<ClientEcdhParams>(&mut common, &[34]).is_err());
}
}

View File

@ -2828,55 +2828,6 @@ fn test_tls13_exporter_maximum_output_length() {
);
}
fn do_suite_test(
client_config: ClientConfig,
server_config: ServerConfig,
expect_suite: SupportedCipherSuite,
expect_version: ProtocolVersion,
) {
println!(
"do_suite_test {:?} {:?}",
expect_version,
expect_suite.suite()
);
let (mut client, mut server) = make_pair_for_configs(client_config, server_config);
assert_eq!(None, client.negotiated_cipher_suite());
assert_eq!(None, server.negotiated_cipher_suite());
assert_eq!(None, client.protocol_version());
assert_eq!(None, server.protocol_version());
assert!(client.is_handshaking());
assert!(server.is_handshaking());
transfer(&mut client, &mut server);
server.process_new_packets().unwrap();
assert!(client.is_handshaking());
assert!(server.is_handshaking());
assert_eq!(None, client.protocol_version());
assert_eq!(Some(expect_version), server.protocol_version());
assert_eq!(None, client.negotiated_cipher_suite());
assert_eq!(Some(expect_suite), server.negotiated_cipher_suite());
transfer(&mut server, &mut client);
client.process_new_packets().unwrap();
assert_eq!(Some(expect_suite), client.negotiated_cipher_suite());
assert_eq!(Some(expect_suite), server.negotiated_cipher_suite());
transfer(&mut client, &mut server);
server.process_new_packets().unwrap();
transfer(&mut server, &mut client);
client.process_new_packets().unwrap();
assert!(!client.is_handshaking());
assert!(!server.is_handshaking());
assert_eq!(Some(expect_version), client.protocol_version());
assert_eq!(Some(expect_version), server.protocol_version());
assert_eq!(Some(expect_suite), client.negotiated_cipher_suite());
assert_eq!(Some(expect_suite), server.negotiated_cipher_suite());
}
fn find_suite(suite: CipherSuite) -> SupportedCipherSuite {
for scs in provider::ALL_CIPHER_SUITES
.iter()

View File

@ -17,6 +17,7 @@ use rustls::Error;
use rustls::RootCertStore;
use rustls::{ClientConfig, ClientConnection};
use rustls::{ConnectionCommon, ServerConfig, ServerConnection, SideData};
use rustls::{ProtocolVersion, SupportedCipherSuite};
#[cfg(all(any(not(feature = "ring"), feature = "fips"), feature = "aws_lc_rs"))]
pub use rustls::crypto::aws_lc_rs as provider;
@ -705,3 +706,52 @@ impl io::Read for FailsReads {
Err(io::Error::from(self.errkind))
}
}
pub fn do_suite_test(
client_config: ClientConfig,
server_config: ServerConfig,
expect_suite: SupportedCipherSuite,
expect_version: ProtocolVersion,
) {
println!(
"do_suite_test {:?} {:?}",
expect_version,
expect_suite.suite()
);
let (mut client, mut server) = make_pair_for_configs(client_config, server_config);
assert_eq!(None, client.negotiated_cipher_suite());
assert_eq!(None, server.negotiated_cipher_suite());
assert_eq!(None, client.protocol_version());
assert_eq!(None, server.protocol_version());
assert!(client.is_handshaking());
assert!(server.is_handshaking());
transfer(&mut client, &mut server);
server.process_new_packets().unwrap();
assert!(client.is_handshaking());
assert!(server.is_handshaking());
assert_eq!(None, client.protocol_version());
assert_eq!(Some(expect_version), server.protocol_version());
assert_eq!(None, client.negotiated_cipher_suite());
assert_eq!(Some(expect_suite), server.negotiated_cipher_suite());
transfer(&mut server, &mut client);
client.process_new_packets().unwrap();
assert_eq!(Some(expect_suite), client.negotiated_cipher_suite());
assert_eq!(Some(expect_suite), server.negotiated_cipher_suite());
transfer(&mut client, &mut server);
server.process_new_packets().unwrap();
transfer(&mut server, &mut client);
client.process_new_packets().unwrap();
assert!(!client.is_handshaking());
assert!(!server.is_handshaking());
assert_eq!(Some(expect_version), client.protocol_version());
assert_eq!(Some(expect_version), server.protocol_version());
assert_eq!(Some(expect_suite), client.negotiated_cipher_suite());
assert_eq!(Some(expect_suite), server.negotiated_cipher_suite());
}