mirror of https://github.com/ctz/rustls
test app-data transmission
This commit is contained in:
parent
aec738ffd4
commit
3eacd0cb05
|
@ -3,66 +3,228 @@ use std::sync::Arc;
|
|||
|
||||
use rustls::client::{ClientConnectionData, UnbufferedClientConnection};
|
||||
use rustls::server::{ServerConnectionData, UnbufferedServerConnection};
|
||||
use rustls::unbuffered::{ConnectionState, UnbufferedConnectionCommon, UnbufferedStatus};
|
||||
use rustls::unbuffered::{
|
||||
ConnectionState, WriteTraffic, UnbufferedConnectionCommon, UnbufferedStatus,
|
||||
};
|
||||
|
||||
use crate::common::*;
|
||||
|
||||
mod common;
|
||||
|
||||
const MAX_ITERATIONS: usize = 100;
|
||||
|
||||
#[test]
|
||||
fn handshake() {
|
||||
for version in rustls::ALL_VERSIONS {
|
||||
let server_config = make_server_config(KeyType::Rsa);
|
||||
let client_config = make_client_config_with_versions(KeyType::Rsa, &[version]);
|
||||
|
||||
let mut client =
|
||||
UnbufferedClientConnection::new(Arc::new(client_config), server_name("localhost"))
|
||||
.unwrap();
|
||||
let mut server = UnbufferedServerConnection::new(Arc::new(server_config)).unwrap();
|
||||
let (mut client, mut server) = make_connection_pair(version);
|
||||
let mut buffers = BothBuffers::default();
|
||||
|
||||
let mut count = 0;
|
||||
let mut client_handshake_done = false;
|
||||
let mut server_handshake_done = false;
|
||||
while !client_handshake_done || !server_handshake_done {
|
||||
match advance_client(&mut client, &mut buffers.client) {
|
||||
match advance_client(&mut client, &mut buffers.client, NO_ACTIONS) {
|
||||
State::EncodedTlsData => {}
|
||||
State::TransmitTlsData => buffers.client_send(),
|
||||
State::TransmitTlsData {
|
||||
sent_app_data: false,
|
||||
} => buffers.client_send(),
|
||||
State::BlockedHandshake => buffers.server_send(),
|
||||
State::WriteTraffic => client_handshake_done = true,
|
||||
State::WriteTraffic {
|
||||
sent_app_data: false,
|
||||
} => client_handshake_done = true,
|
||||
state => unreachable!("{state:?}"),
|
||||
}
|
||||
|
||||
match advance_server(&mut server, &mut buffers.server) {
|
||||
match advance_server(&mut server, &mut buffers.server, NO_ACTIONS) {
|
||||
State::EncodedTlsData => {}
|
||||
State::TransmitTlsData => buffers.server_send(),
|
||||
State::TransmitTlsData {
|
||||
sent_app_data: false,
|
||||
} => buffers.server_send(),
|
||||
State::BlockedHandshake => buffers.client_send(),
|
||||
State::WriteTraffic => server_handshake_done = true,
|
||||
State::WriteTraffic {
|
||||
sent_app_data: false,
|
||||
} => server_handshake_done = true,
|
||||
state => unreachable!("{state:?}"),
|
||||
}
|
||||
|
||||
count += 1;
|
||||
|
||||
assert!(count <= 100, "handshake {version:?} was not completed");
|
||||
assert!(
|
||||
count <= MAX_ITERATIONS,
|
||||
"handshake {version:?} was not completed"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_data_client_to_server() {
|
||||
let expected: &[_] = b"hello";
|
||||
for version in rustls::ALL_VERSIONS {
|
||||
eprintln!("{version:?}");
|
||||
|
||||
let (mut client, mut server) = make_connection_pair(version);
|
||||
let mut buffers = BothBuffers::default();
|
||||
|
||||
let mut client_actions = Actions {
|
||||
app_data_to_send: Some(expected),
|
||||
};
|
||||
let mut received_app_data = vec![];
|
||||
let mut count = 0;
|
||||
let mut client_handshake_done = false;
|
||||
let mut server_handshake_done = false;
|
||||
while !client_handshake_done || !server_handshake_done {
|
||||
match advance_client(&mut client, &mut buffers.client, client_actions) {
|
||||
State::EncodedTlsData => {}
|
||||
State::TransmitTlsData { sent_app_data } => {
|
||||
buffers.client_send();
|
||||
|
||||
if sent_app_data {
|
||||
client_actions.app_data_to_send = None;
|
||||
}
|
||||
}
|
||||
State::BlockedHandshake => buffers.server_send(),
|
||||
State::WriteTraffic { sent_app_data } => {
|
||||
if sent_app_data {
|
||||
buffers.client_send();
|
||||
client_actions.app_data_to_send = None;
|
||||
}
|
||||
|
||||
client_handshake_done = true
|
||||
}
|
||||
state => unreachable!("{state:?}"),
|
||||
}
|
||||
|
||||
match advance_server(&mut server, &mut buffers.server, NO_ACTIONS) {
|
||||
State::EncodedTlsData => {}
|
||||
State::TransmitTlsData {
|
||||
sent_app_data: false,
|
||||
} => buffers.server_send(),
|
||||
State::BlockedHandshake => buffers.client_send(),
|
||||
State::ReceivedAppData { records } => {
|
||||
received_app_data.extend(records);
|
||||
}
|
||||
State::WriteTraffic {
|
||||
sent_app_data: false,
|
||||
} => server_handshake_done = true,
|
||||
state => unreachable!("{state:?}"),
|
||||
}
|
||||
|
||||
count += 1;
|
||||
|
||||
assert!(
|
||||
count <= MAX_ITERATIONS,
|
||||
"handshake {version:?} was not completed"
|
||||
);
|
||||
}
|
||||
|
||||
assert!(client_handshake_done);
|
||||
assert!(server_handshake_done);
|
||||
|
||||
assert!(client_actions
|
||||
.app_data_to_send
|
||||
.is_none());
|
||||
assert_eq!([expected], received_app_data.as_slice());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn app_data_server_to_client() {
|
||||
let expected: &[_] = b"hello";
|
||||
for version in rustls::ALL_VERSIONS {
|
||||
eprintln!("{version:?}");
|
||||
|
||||
let (mut client, mut server) = make_connection_pair(version);
|
||||
let mut buffers = BothBuffers::default();
|
||||
|
||||
let mut server_actions = Actions {
|
||||
app_data_to_send: Some(expected),
|
||||
};
|
||||
let mut received_app_data = vec![];
|
||||
let mut count = 0;
|
||||
let mut client_handshake_done = false;
|
||||
let mut server_handshake_done = false;
|
||||
while !client_handshake_done || !server_handshake_done {
|
||||
match advance_client(&mut client, &mut buffers.client, NO_ACTIONS) {
|
||||
State::EncodedTlsData => {}
|
||||
State::TransmitTlsData {
|
||||
sent_app_data: false,
|
||||
} => buffers.client_send(),
|
||||
State::BlockedHandshake => buffers.server_send(),
|
||||
State::WriteTraffic {
|
||||
sent_app_data: false,
|
||||
} => client_handshake_done = true,
|
||||
State::ReceivedAppData { records } => {
|
||||
received_app_data.extend(records);
|
||||
}
|
||||
state => unreachable!("{state:?}"),
|
||||
}
|
||||
|
||||
match advance_server(&mut server, &mut buffers.server, server_actions) {
|
||||
State::EncodedTlsData => {}
|
||||
State::TransmitTlsData { sent_app_data } => {
|
||||
buffers.server_send();
|
||||
if sent_app_data {
|
||||
server_actions.app_data_to_send = None;
|
||||
}
|
||||
}
|
||||
State::BlockedHandshake => buffers.client_send(),
|
||||
State::ReceivedAppData { records } => {
|
||||
received_app_data.extend(records);
|
||||
}
|
||||
// server does not need to reach this state to send data to the client
|
||||
State::WriteTraffic {
|
||||
sent_app_data: false,
|
||||
} => server_handshake_done = true,
|
||||
state => unreachable!("{state:?}"),
|
||||
}
|
||||
|
||||
count += 1;
|
||||
|
||||
assert!(
|
||||
count <= MAX_ITERATIONS,
|
||||
"handshake {version:?} was not completed"
|
||||
);
|
||||
}
|
||||
|
||||
assert!(client_handshake_done);
|
||||
assert!(server_handshake_done);
|
||||
|
||||
assert!(server_actions
|
||||
.app_data_to_send
|
||||
.is_none());
|
||||
assert_eq!([expected], received_app_data.as_slice());
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum State {
|
||||
EncodedTlsData,
|
||||
TransmitTlsData { sent_app_data: bool },
|
||||
BlockedHandshake,
|
||||
WriteTraffic,
|
||||
TransmitTlsData,
|
||||
ReceivedAppData { records: Vec<Vec<u8>> },
|
||||
WriteTraffic { sent_app_data: bool },
|
||||
}
|
||||
|
||||
const NO_ACTIONS: Actions = Actions {
|
||||
app_data_to_send: None,
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct Actions<'a> {
|
||||
app_data_to_send: Option<&'a [u8]>,
|
||||
}
|
||||
|
||||
fn advance_client(
|
||||
conn: &mut UnbufferedConnectionCommon<ClientConnectionData>,
|
||||
buffers: &mut Buffers,
|
||||
actions: Actions,
|
||||
) -> State {
|
||||
let UnbufferedStatus { discard, state } = conn
|
||||
.process_tls_records(buffers.incoming.filled())
|
||||
.unwrap();
|
||||
|
||||
let state = handle_state(state, &mut buffers.outgoing);
|
||||
let state = handle_state(state, &mut buffers.outgoing, actions);
|
||||
buffers.incoming.discard(discard);
|
||||
|
||||
state
|
||||
|
@ -71,18 +233,23 @@ fn advance_client(
|
|||
fn advance_server(
|
||||
conn: &mut UnbufferedConnectionCommon<ServerConnectionData>,
|
||||
buffers: &mut Buffers,
|
||||
actions: Actions,
|
||||
) -> State {
|
||||
let UnbufferedStatus { discard, state } = conn
|
||||
.process_tls_records(buffers.incoming.filled())
|
||||
.unwrap();
|
||||
|
||||
let state = handle_state(state, &mut buffers.outgoing);
|
||||
let state = handle_state(state, &mut buffers.outgoing, actions);
|
||||
buffers.incoming.discard(discard);
|
||||
|
||||
state
|
||||
}
|
||||
|
||||
fn handle_state<Data>(state: ConnectionState<'_, '_, Data>, outgoing: &mut Buffer) -> State {
|
||||
fn handle_state<Data>(
|
||||
state: ConnectionState<'_, '_, Data>,
|
||||
outgoing: &mut Buffer,
|
||||
actions: Actions,
|
||||
) -> State {
|
||||
match state {
|
||||
ConnectionState::EncodeTlsData(mut state) => {
|
||||
let written = state
|
||||
|
@ -93,21 +260,54 @@ fn handle_state<Data>(state: ConnectionState<'_, '_, Data>, outgoing: &mut Buffe
|
|||
State::EncodedTlsData
|
||||
}
|
||||
|
||||
ConnectionState::TransmitTlsData(state) => {
|
||||
ConnectionState::TransmitTlsData(mut state) => {
|
||||
let mut sent_app_data = false;
|
||||
if let Some(app_data) = actions.app_data_to_send {
|
||||
if let Some(mut state) = state.may_encrypt_app_data() {
|
||||
encrypt(&mut state, app_data, outgoing);
|
||||
sent_app_data = true;
|
||||
}
|
||||
}
|
||||
|
||||
// this should be called *after* the data has been transmitted but it's easier to
|
||||
// do it in reverse
|
||||
state.done();
|
||||
State::TransmitTlsData
|
||||
State::TransmitTlsData { sent_app_data }
|
||||
}
|
||||
|
||||
ConnectionState::BlockedHandshake { .. } => State::BlockedHandshake,
|
||||
|
||||
ConnectionState::WriteTraffic(_) => State::WriteTraffic,
|
||||
ConnectionState::WriteTraffic(mut state) => {
|
||||
let mut sent_app_data = false;
|
||||
if let Some(app_data) = actions.app_data_to_send {
|
||||
encrypt(&mut state, app_data, outgoing);
|
||||
sent_app_data = true;
|
||||
}
|
||||
|
||||
State::WriteTraffic { sent_app_data }
|
||||
}
|
||||
|
||||
ConnectionState::ReadTraffic(mut state) => {
|
||||
let mut records = vec![];
|
||||
|
||||
while let Some(res) = state.next_record() {
|
||||
records.push(res.unwrap().payload.to_vec());
|
||||
}
|
||||
|
||||
State::ReceivedAppData { records }
|
||||
}
|
||||
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn encrypt<Data>(state: &mut WriteTraffic<'_, Data>, app_data: &[u8], outgoing: &mut Buffer) {
|
||||
let written = state
|
||||
.encrypt(app_data, outgoing.unfilled())
|
||||
.unwrap();
|
||||
outgoing.advance(written);
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct BothBuffers {
|
||||
client: Buffers,
|
||||
|
@ -191,3 +391,15 @@ impl Buffer {
|
|||
&mut self.inner[self.used..]
|
||||
}
|
||||
}
|
||||
|
||||
fn make_connection_pair(
|
||||
version: &'static rustls::SupportedProtocolVersion,
|
||||
) -> (UnbufferedClientConnection, UnbufferedServerConnection) {
|
||||
let server_config = make_server_config(KeyType::Rsa);
|
||||
let client_config = make_client_config_with_versions(KeyType::Rsa, &[version]);
|
||||
|
||||
let client =
|
||||
UnbufferedClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap();
|
||||
let server = UnbufferedServerConnection::new(Arc::new(server_config)).unwrap();
|
||||
(client, server)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue