test app-data transmission

This commit is contained in:
Jorge Aparicio 2023-11-13 17:36:43 +01:00 committed by Joe Birr-Pixton
parent aec738ffd4
commit 3eacd0cb05
1 changed files with 235 additions and 23 deletions

View File

@ -3,66 +3,228 @@ use std::sync::Arc;
use rustls::client::{ClientConnectionData, UnbufferedClientConnection};
use rustls::server::{ServerConnectionData, UnbufferedServerConnection};
use rustls::unbuffered::{ConnectionState, UnbufferedConnectionCommon, UnbufferedStatus};
use rustls::unbuffered::{
ConnectionState, WriteTraffic, UnbufferedConnectionCommon, UnbufferedStatus,
};
use crate::common::*;
mod common;
const MAX_ITERATIONS: usize = 100;
#[test]
fn handshake() {
for version in rustls::ALL_VERSIONS {
let server_config = make_server_config(KeyType::Rsa);
let client_config = make_client_config_with_versions(KeyType::Rsa, &[version]);
let mut client =
UnbufferedClientConnection::new(Arc::new(client_config), server_name("localhost"))
.unwrap();
let mut server = UnbufferedServerConnection::new(Arc::new(server_config)).unwrap();
let (mut client, mut server) = make_connection_pair(version);
let mut buffers = BothBuffers::default();
let mut count = 0;
let mut client_handshake_done = false;
let mut server_handshake_done = false;
while !client_handshake_done || !server_handshake_done {
match advance_client(&mut client, &mut buffers.client) {
match advance_client(&mut client, &mut buffers.client, NO_ACTIONS) {
State::EncodedTlsData => {}
State::TransmitTlsData => buffers.client_send(),
State::TransmitTlsData {
sent_app_data: false,
} => buffers.client_send(),
State::BlockedHandshake => buffers.server_send(),
State::WriteTraffic => client_handshake_done = true,
State::WriteTraffic {
sent_app_data: false,
} => client_handshake_done = true,
state => unreachable!("{state:?}"),
}
match advance_server(&mut server, &mut buffers.server) {
match advance_server(&mut server, &mut buffers.server, NO_ACTIONS) {
State::EncodedTlsData => {}
State::TransmitTlsData => buffers.server_send(),
State::TransmitTlsData {
sent_app_data: false,
} => buffers.server_send(),
State::BlockedHandshake => buffers.client_send(),
State::WriteTraffic => server_handshake_done = true,
State::WriteTraffic {
sent_app_data: false,
} => server_handshake_done = true,
state => unreachable!("{state:?}"),
}
count += 1;
assert!(count <= 100, "handshake {version:?} was not completed");
assert!(
count <= MAX_ITERATIONS,
"handshake {version:?} was not completed"
);
}
}
}
#[test]
fn app_data_client_to_server() {
let expected: &[_] = b"hello";
for version in rustls::ALL_VERSIONS {
eprintln!("{version:?}");
let (mut client, mut server) = make_connection_pair(version);
let mut buffers = BothBuffers::default();
let mut client_actions = Actions {
app_data_to_send: Some(expected),
};
let mut received_app_data = vec![];
let mut count = 0;
let mut client_handshake_done = false;
let mut server_handshake_done = false;
while !client_handshake_done || !server_handshake_done {
match advance_client(&mut client, &mut buffers.client, client_actions) {
State::EncodedTlsData => {}
State::TransmitTlsData { sent_app_data } => {
buffers.client_send();
if sent_app_data {
client_actions.app_data_to_send = None;
}
}
State::BlockedHandshake => buffers.server_send(),
State::WriteTraffic { sent_app_data } => {
if sent_app_data {
buffers.client_send();
client_actions.app_data_to_send = None;
}
client_handshake_done = true
}
state => unreachable!("{state:?}"),
}
match advance_server(&mut server, &mut buffers.server, NO_ACTIONS) {
State::EncodedTlsData => {}
State::TransmitTlsData {
sent_app_data: false,
} => buffers.server_send(),
State::BlockedHandshake => buffers.client_send(),
State::ReceivedAppData { records } => {
received_app_data.extend(records);
}
State::WriteTraffic {
sent_app_data: false,
} => server_handshake_done = true,
state => unreachable!("{state:?}"),
}
count += 1;
assert!(
count <= MAX_ITERATIONS,
"handshake {version:?} was not completed"
);
}
assert!(client_handshake_done);
assert!(server_handshake_done);
assert!(client_actions
.app_data_to_send
.is_none());
assert_eq!([expected], received_app_data.as_slice());
}
}
#[test]
fn app_data_server_to_client() {
let expected: &[_] = b"hello";
for version in rustls::ALL_VERSIONS {
eprintln!("{version:?}");
let (mut client, mut server) = make_connection_pair(version);
let mut buffers = BothBuffers::default();
let mut server_actions = Actions {
app_data_to_send: Some(expected),
};
let mut received_app_data = vec![];
let mut count = 0;
let mut client_handshake_done = false;
let mut server_handshake_done = false;
while !client_handshake_done || !server_handshake_done {
match advance_client(&mut client, &mut buffers.client, NO_ACTIONS) {
State::EncodedTlsData => {}
State::TransmitTlsData {
sent_app_data: false,
} => buffers.client_send(),
State::BlockedHandshake => buffers.server_send(),
State::WriteTraffic {
sent_app_data: false,
} => client_handshake_done = true,
State::ReceivedAppData { records } => {
received_app_data.extend(records);
}
state => unreachable!("{state:?}"),
}
match advance_server(&mut server, &mut buffers.server, server_actions) {
State::EncodedTlsData => {}
State::TransmitTlsData { sent_app_data } => {
buffers.server_send();
if sent_app_data {
server_actions.app_data_to_send = None;
}
}
State::BlockedHandshake => buffers.client_send(),
State::ReceivedAppData { records } => {
received_app_data.extend(records);
}
// server does not need to reach this state to send data to the client
State::WriteTraffic {
sent_app_data: false,
} => server_handshake_done = true,
state => unreachable!("{state:?}"),
}
count += 1;
assert!(
count <= MAX_ITERATIONS,
"handshake {version:?} was not completed"
);
}
assert!(client_handshake_done);
assert!(server_handshake_done);
assert!(server_actions
.app_data_to_send
.is_none());
assert_eq!([expected], received_app_data.as_slice());
}
}
#[derive(Debug)]
enum State {
EncodedTlsData,
TransmitTlsData { sent_app_data: bool },
BlockedHandshake,
WriteTraffic,
TransmitTlsData,
ReceivedAppData { records: Vec<Vec<u8>> },
WriteTraffic { sent_app_data: bool },
}
const NO_ACTIONS: Actions = Actions {
app_data_to_send: None,
};
#[derive(Clone, Copy, Debug)]
struct Actions<'a> {
app_data_to_send: Option<&'a [u8]>,
}
fn advance_client(
conn: &mut UnbufferedConnectionCommon<ClientConnectionData>,
buffers: &mut Buffers,
actions: Actions,
) -> State {
let UnbufferedStatus { discard, state } = conn
.process_tls_records(buffers.incoming.filled())
.unwrap();
let state = handle_state(state, &mut buffers.outgoing);
let state = handle_state(state, &mut buffers.outgoing, actions);
buffers.incoming.discard(discard);
state
@ -71,18 +233,23 @@ fn advance_client(
fn advance_server(
conn: &mut UnbufferedConnectionCommon<ServerConnectionData>,
buffers: &mut Buffers,
actions: Actions,
) -> State {
let UnbufferedStatus { discard, state } = conn
.process_tls_records(buffers.incoming.filled())
.unwrap();
let state = handle_state(state, &mut buffers.outgoing);
let state = handle_state(state, &mut buffers.outgoing, actions);
buffers.incoming.discard(discard);
state
}
fn handle_state<Data>(state: ConnectionState<'_, '_, Data>, outgoing: &mut Buffer) -> State {
fn handle_state<Data>(
state: ConnectionState<'_, '_, Data>,
outgoing: &mut Buffer,
actions: Actions,
) -> State {
match state {
ConnectionState::EncodeTlsData(mut state) => {
let written = state
@ -93,21 +260,54 @@ fn handle_state<Data>(state: ConnectionState<'_, '_, Data>, outgoing: &mut Buffe
State::EncodedTlsData
}
ConnectionState::TransmitTlsData(state) => {
ConnectionState::TransmitTlsData(mut state) => {
let mut sent_app_data = false;
if let Some(app_data) = actions.app_data_to_send {
if let Some(mut state) = state.may_encrypt_app_data() {
encrypt(&mut state, app_data, outgoing);
sent_app_data = true;
}
}
// this should be called *after* the data has been transmitted but it's easier to
// do it in reverse
state.done();
State::TransmitTlsData
State::TransmitTlsData { sent_app_data }
}
ConnectionState::BlockedHandshake { .. } => State::BlockedHandshake,
ConnectionState::WriteTraffic(_) => State::WriteTraffic,
ConnectionState::WriteTraffic(mut state) => {
let mut sent_app_data = false;
if let Some(app_data) = actions.app_data_to_send {
encrypt(&mut state, app_data, outgoing);
sent_app_data = true;
}
State::WriteTraffic { sent_app_data }
}
ConnectionState::ReadTraffic(mut state) => {
let mut records = vec![];
while let Some(res) = state.next_record() {
records.push(res.unwrap().payload.to_vec());
}
State::ReceivedAppData { records }
}
_ => unreachable!(),
}
}
fn encrypt<Data>(state: &mut WriteTraffic<'_, Data>, app_data: &[u8], outgoing: &mut Buffer) {
let written = state
.encrypt(app_data, outgoing.unfilled())
.unwrap();
outgoing.advance(written);
}
#[derive(Default)]
struct BothBuffers {
client: Buffers,
@ -191,3 +391,15 @@ impl Buffer {
&mut self.inner[self.used..]
}
}
fn make_connection_pair(
version: &'static rustls::SupportedProtocolVersion,
) -> (UnbufferedClientConnection, UnbufferedServerConnection) {
let server_config = make_server_config(KeyType::Rsa);
let client_config = make_client_config_with_versions(KeyType::Rsa, &[version]);
let client =
UnbufferedClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap();
let server = UnbufferedServerConnection::new(Arc::new(server_config)).unwrap();
(client, server)
}