Add `permessage-deflate` support
This commit is contained in:
parent
e1033afd95
commit
edb2377540
|
@ -1,2 +1,4 @@
|
|||
target
|
||||
Cargo.lock
|
||||
autobahn/client/
|
||||
autobahn/server/
|
||||
|
|
|
@ -10,5 +10,6 @@ before_script:
|
|||
|
||||
script:
|
||||
- cargo test --release
|
||||
- cargo test --release --features=deflate
|
||||
- echo "Running Autobahn TestSuite for client" && ./scripts/autobahn-client.sh
|
||||
- echo "Running Autobahn TestSuite for server" && ./scripts/autobahn-server.sh
|
||||
|
|
14
Cargo.toml
14
Cargo.toml
|
@ -25,6 +25,15 @@ native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
|
|||
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
|
||||
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
|
||||
__rustls-tls = ["rustls", "webpki"]
|
||||
deflate = ["flate2"]
|
||||
|
||||
[[example]]
|
||||
name = "autobahn-client"
|
||||
required-features = ["deflate"]
|
||||
|
||||
[[example]]
|
||||
name = "autobahn-server"
|
||||
required-features = ["deflate"]
|
||||
|
||||
[dependencies]
|
||||
data-encoding = { version = "2", optional = true }
|
||||
|
@ -38,6 +47,11 @@ sha1 = { version = "0.10", optional = true }
|
|||
thiserror = "1.0.23"
|
||||
url = { version = "2.1.0", optional = true }
|
||||
utf-8 = "0.7.5"
|
||||
headers = { git = "https://github.com/kazk/headers", branch = "sec-websocket-extensions" }
|
||||
|
||||
[dependencies.flate2]
|
||||
optional = true
|
||||
version = "1.0"
|
||||
|
||||
[dependencies.native-tls-crate]
|
||||
optional = true
|
||||
|
|
|
@ -72,8 +72,6 @@ Choose the one that is appropriate for your needs.
|
|||
By default **no TLS feature is activated**, so make sure you use one of the TLS features,
|
||||
otherwise you won't be able to communicate with the TLS endpoints.
|
||||
|
||||
There is no support for permessage-deflate at the moment, but the PRs are welcome :wink:
|
||||
|
||||
Testing
|
||||
-------
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,7 +1,10 @@
|
|||
use log::*;
|
||||
use url::Url;
|
||||
|
||||
use tungstenite::{connect, Error, Message, Result};
|
||||
use tungstenite::{
|
||||
client::connect_with_config, connect, extensions::DeflateConfig, protocol::WebSocketConfig,
|
||||
Error, Message, Result,
|
||||
};
|
||||
|
||||
const AGENT: &str = "Tungstenite";
|
||||
|
||||
|
@ -24,7 +27,14 @@ fn run_test(case: u32) -> Result<()> {
|
|||
info!("Running test case {}", case);
|
||||
let case_url =
|
||||
Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap();
|
||||
let (mut socket, _) = connect(case_url)?;
|
||||
let (mut socket, _) = connect_with_config(
|
||||
case_url,
|
||||
Some(WebSocketConfig {
|
||||
compression: Some(DeflateConfig::default()),
|
||||
..WebSocketConfig::default()
|
||||
}),
|
||||
3,
|
||||
)?;
|
||||
loop {
|
||||
match socket.read_message()? {
|
||||
msg @ Message::Text(_) | msg @ Message::Binary(_) => {
|
||||
|
|
|
@ -4,7 +4,10 @@ use std::{
|
|||
};
|
||||
|
||||
use log::*;
|
||||
use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result};
|
||||
use tungstenite::{
|
||||
accept_with_config, extensions::DeflateConfig, handshake::HandshakeRole,
|
||||
protocol::WebSocketConfig, Error, HandshakeError, Message, Result,
|
||||
};
|
||||
|
||||
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
|
||||
match err {
|
||||
|
@ -14,7 +17,14 @@ fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
|
|||
}
|
||||
|
||||
fn handle_client(stream: TcpStream) -> Result<()> {
|
||||
let mut socket = accept(stream).map_err(must_not_block)?;
|
||||
let mut socket = accept_with_config(
|
||||
stream,
|
||||
Some(WebSocketConfig {
|
||||
compression: Some(DeflateConfig::default()),
|
||||
..WebSocketConfig::default()
|
||||
}),
|
||||
)
|
||||
.map_err(must_not_block)?;
|
||||
info!("Running test");
|
||||
loop {
|
||||
match socket.read_message()? {
|
||||
|
|
|
@ -35,6 +35,8 @@ fn main() {
|
|||
// rare cases where it is necessary to integrate with existing/legacy
|
||||
// clients which are sending unmasked frames
|
||||
accept_unmasked_frames: true,
|
||||
#[cfg(feature = "deflate")]
|
||||
compression: None,
|
||||
});
|
||||
|
||||
let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap();
|
||||
|
|
|
@ -32,5 +32,5 @@ docker run -d --rm \
|
|||
wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json'
|
||||
|
||||
sleep 3
|
||||
cargo run --release --example autobahn-client
|
||||
cargo run --release --example autobahn-client --features=deflate
|
||||
test_diff
|
||||
|
|
|
@ -22,7 +22,7 @@ function test_diff() {
|
|||
fi
|
||||
}
|
||||
|
||||
cargo run --release --example autobahn-server & WSSERVER_PID=$!
|
||||
cargo run --release --example autobahn-server --features=deflate & WSSERVER_PID=$!
|
||||
sleep 3
|
||||
|
||||
docker run --rm \
|
||||
|
|
19
src/error.rs
19
src/error.rs
|
@ -70,6 +70,10 @@ pub enum Error {
|
|||
#[error("HTTP format error: {0}")]
|
||||
#[cfg(feature = "handshake")]
|
||||
HttpFormat(#[from] http::Error),
|
||||
/// Error from `permessage-deflate` extension.
|
||||
#[cfg(feature = "deflate")]
|
||||
#[error("Deflate error: {0}")]
|
||||
Deflate(#[from] crate::extensions::DeflateError),
|
||||
}
|
||||
|
||||
impl From<str::Utf8Error> for Error {
|
||||
|
@ -206,6 +210,9 @@ pub enum ProtocolError {
|
|||
/// Control frames must not be fragmented.
|
||||
#[error("Fragmented control frame")]
|
||||
FragmentedControlFrame,
|
||||
/// Control frames must not be compressed.
|
||||
#[error("Compressed control frame")]
|
||||
CompressedControlFrame,
|
||||
/// Control frames must have a payload of 125 bytes or less.
|
||||
#[error("Control frame too big (payload must be 125 bytes or less)")]
|
||||
ControlFrameTooBig,
|
||||
|
@ -218,6 +225,9 @@ pub enum ProtocolError {
|
|||
/// Received a continue frame despite there being nothing to continue.
|
||||
#[error("Continue frame but nothing to continue")]
|
||||
UnexpectedContinueFrame,
|
||||
/// Received a compressed continue frame.
|
||||
#[error("Continue frame must not have compress bit set")]
|
||||
CompressedContinueFrame,
|
||||
/// Received data while waiting for more fragments.
|
||||
#[error("While waiting for more fragments received: {0}")]
|
||||
ExpectedFragment(Data),
|
||||
|
@ -230,6 +240,15 @@ pub enum ProtocolError {
|
|||
/// The payload for the closing frame is invalid.
|
||||
#[error("Invalid close sequence")]
|
||||
InvalidCloseSequence,
|
||||
/// The negotiation response included an extension not offered.
|
||||
#[error("Extension negotiation response had invalid extension: {0}")]
|
||||
InvalidExtension(String),
|
||||
/// The negotiation response included an extension more than once.
|
||||
#[error("Extension negotiation response had conflicting extension: {0}")]
|
||||
ExtensionConflict(String),
|
||||
/// The `Sec-WebSocket-Extensions` header is invalid.
|
||||
#[error("Invalid \"Sec-WebSocket-Extensions\" header")]
|
||||
InvalidExtensionsHeader,
|
||||
}
|
||||
|
||||
/// Indicates the specific type/cause of URL error.
|
||||
|
|
|
@ -0,0 +1,442 @@
|
|||
use std::convert::TryFrom;
|
||||
|
||||
use bytes::BytesMut;
|
||||
use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status};
|
||||
use headers::WebsocketExtension;
|
||||
use http::HeaderValue;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::protocol::Role;
|
||||
|
||||
const PER_MESSAGE_DEFLATE: &str = "permessage-deflate";
|
||||
const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
|
||||
const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover";
|
||||
const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits";
|
||||
const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits";
|
||||
|
||||
const TRAILER: [u8; 4] = [0x00, 0x00, 0xff, 0xff];
|
||||
|
||||
/// Errors from `permessage-deflate` extension.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum DeflateError {
|
||||
/// Compress failed
|
||||
#[error("Failed to compress")]
|
||||
Compress(#[source] std::io::Error),
|
||||
/// Decompress failed
|
||||
#[error("Failed to decompress")]
|
||||
Decompress(#[source] std::io::Error),
|
||||
|
||||
/// Extension negotiation failed.
|
||||
#[error("Extension negotiation failed")]
|
||||
Negotiation(#[source] NegotiationError),
|
||||
}
|
||||
|
||||
/// Errors from `permessage-deflate` extension negotiation.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum NegotiationError {
|
||||
/// Unknown parameter in a negotiation response.
|
||||
#[error("Unknown parameter in a negotiation response: {0}")]
|
||||
UnknownParameter(String),
|
||||
/// Duplicate parameter in a negotiation response.
|
||||
#[error("Duplicate parameter in a negotiation response: {0}")]
|
||||
DuplicateParameter(String),
|
||||
/// Received `client_max_window_bits` in a negotiation response for an offer without it.
|
||||
#[error("Received client_max_window_bits in a negotiation response for an offer without it")]
|
||||
UnexpectedClientMaxWindowBits,
|
||||
/// Received unsupported `server_max_window_bits` in a negotiation response.
|
||||
#[error("Received unsupported server_max_window_bits in a negotiation response")]
|
||||
ServerMaxWindowBitsNotSupported,
|
||||
/// Invalid `client_max_window_bits` value in a negotiation response.
|
||||
#[error("Invalid client_max_window_bits value in a negotiation response: {0}")]
|
||||
InvalidClientMaxWindowBitsValue(String),
|
||||
/// Invalid `server_max_window_bits` value in a negotiation response.
|
||||
#[error("Invalid server_max_window_bits value in a negotiation response: {0}")]
|
||||
InvalidServerMaxWindowBitsValue(String),
|
||||
/// Missing `server_max_window_bits` value in a negotiation response.
|
||||
#[error("Missing server_max_window_bits value in a negotiation response")]
|
||||
MissingServerMaxWindowBitsValue,
|
||||
}
|
||||
|
||||
// Parameters `server_max_window_bits` and `client_max_window_bits` are not supported for now
|
||||
// because custom window size requires `flate2/zlib` feature.
|
||||
/// Configurations for `permessage-deflate` Per-Message Compression Extension.
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub struct DeflateConfig {
|
||||
/// Compression level.
|
||||
pub compression: Compression,
|
||||
/// Request the peer server not to use context takeover.
|
||||
pub server_no_context_takeover: bool,
|
||||
/// Hint that context takeover is not used.
|
||||
pub client_no_context_takeover: bool,
|
||||
}
|
||||
|
||||
impl DeflateConfig {
|
||||
pub(crate) fn name(&self) -> &str {
|
||||
PER_MESSAGE_DEFLATE
|
||||
}
|
||||
|
||||
/// Value for `Sec-WebSocket-Extensions` request header.
|
||||
pub(crate) fn generate_offer(&self) -> WebsocketExtension {
|
||||
let mut offers = Vec::new();
|
||||
if self.server_no_context_takeover {
|
||||
offers.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
|
||||
}
|
||||
|
||||
// > a client informs the peer server of a hint that even if the server doesn't include the
|
||||
// > "client_no_context_takeover" extension parameter in the corresponding
|
||||
// > extension negotiation response to the offer, the client is not going
|
||||
// > to use context takeover.
|
||||
// > https://www.rfc-editor.org/rfc/rfc7692#section-7.1.1.2
|
||||
if self.client_no_context_takeover {
|
||||
offers.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
|
||||
}
|
||||
to_header_value(&offers)
|
||||
}
|
||||
|
||||
/// Returns negotiation response based on offers and `DeflateContext` to manage per message compression.
|
||||
pub(crate) fn accept_offer(
|
||||
&self,
|
||||
offers: &headers::SecWebsocketExtensions,
|
||||
) -> Option<(WebsocketExtension, DeflateContext)> {
|
||||
// Accept the first valid offer for `permessage-deflate`.
|
||||
// A server MUST decline an extension negotiation offer for this
|
||||
// extension if any of the following conditions are met:
|
||||
// 1. The negotiation offer contains an extension parameter not defined for use in an offer.
|
||||
// 2. The negotiation offer contains an extension parameter with an invalid value.
|
||||
// 3. The negotiation offer contains multiple extension parameters with the same name.
|
||||
// 4. The server doesn't support the offered configuration.
|
||||
offers.iter().find_map(|extension| {
|
||||
if let Some(params) = (extension.name() == self.name()).then(|| extension.params()) {
|
||||
let mut config =
|
||||
DeflateConfig { compression: self.compression, ..DeflateConfig::default() };
|
||||
let mut agreed = Vec::new();
|
||||
let mut seen_server_no_context_takeover = false;
|
||||
let mut seen_client_no_context_takeover = false;
|
||||
let mut seen_client_max_window_bits = false;
|
||||
for (key, val) in params {
|
||||
match key {
|
||||
SERVER_NO_CONTEXT_TAKEOVER => {
|
||||
// Invalid offer with multiple params with same name is declined.
|
||||
if seen_server_no_context_takeover {
|
||||
return None;
|
||||
}
|
||||
seen_server_no_context_takeover = true;
|
||||
config.server_no_context_takeover = true;
|
||||
agreed.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
|
||||
}
|
||||
|
||||
CLIENT_NO_CONTEXT_TAKEOVER => {
|
||||
// Invalid offer with multiple params with same name is declined.
|
||||
if seen_client_no_context_takeover {
|
||||
return None;
|
||||
}
|
||||
seen_client_no_context_takeover = true;
|
||||
config.client_no_context_takeover = true;
|
||||
agreed.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
|
||||
}
|
||||
|
||||
// Max window bits are not supported at the moment.
|
||||
SERVER_MAX_WINDOW_BITS => {
|
||||
// Decline offer with invalid parameter value.
|
||||
// `server_max_window_bits` requires a value in range [8, 15].
|
||||
if let Some(bits) = val {
|
||||
if !is_valid_max_window_bits(bits) {
|
||||
return None;
|
||||
}
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
|
||||
// A server declines an extension negotiation offer with this parameter
|
||||
// if the server doesn't support it.
|
||||
return None;
|
||||
}
|
||||
|
||||
// Not supported, but server may ignore and accept the offer.
|
||||
CLIENT_MAX_WINDOW_BITS => {
|
||||
// Decline offer with invalid parameter value.
|
||||
// `client_max_window_bits` requires a value in range [8, 15] or no value.
|
||||
if let Some(bits) = val {
|
||||
if !is_valid_max_window_bits(bits) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Invalid offer with multiple params with same name is declined.
|
||||
if seen_client_max_window_bits {
|
||||
return None;
|
||||
}
|
||||
seen_client_max_window_bits = true;
|
||||
}
|
||||
|
||||
// Offer with unknown parameter MUST be declined.
|
||||
_ => {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some((to_header_value(&agreed), DeflateContext::new(Role::Server, config)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn accept_response<'a>(
|
||||
&'a self,
|
||||
agreed: impl Iterator<Item = (&'a str, Option<&'a str>)>,
|
||||
) -> Result<DeflateContext, DeflateError> {
|
||||
let mut config = DeflateConfig {
|
||||
compression: self.compression,
|
||||
// If this was hinted in the offer, the client won't use context takeover
|
||||
// even if the response doesn't include it.
|
||||
// See `generate_offer`.
|
||||
client_no_context_takeover: self.client_no_context_takeover,
|
||||
..DeflateConfig::default()
|
||||
};
|
||||
let mut seen_server_no_context_takeover = false;
|
||||
let mut seen_client_no_context_takeover = false;
|
||||
// A client MUST _Fail the WebSocket Connection_ if the peer server
|
||||
// accepted an extension negotiation offer for this extension with an
|
||||
// extension negotiation response meeting any of the following
|
||||
// conditions:
|
||||
// 1. The negotiation response contains an extension parameter not defined for use in a response.
|
||||
// 2. The negotiation response contains an extension parameter with an invalid value.
|
||||
// 3. The negotiation response contains multiple extension parameters with the same name.
|
||||
// 4. The client does not support the configuration that the response represents.
|
||||
for (key, val) in agreed {
|
||||
match key {
|
||||
SERVER_NO_CONTEXT_TAKEOVER => {
|
||||
// Fail the connection when the response contains multiple parameters with the same name.
|
||||
if seen_server_no_context_takeover {
|
||||
return Err(DeflateError::Negotiation(
|
||||
NegotiationError::DuplicateParameter(key.to_owned()),
|
||||
));
|
||||
}
|
||||
seen_server_no_context_takeover = true;
|
||||
// A server MAY include the "server_no_context_takeover" extension
|
||||
// parameter in an extension negotiation response even if the extension
|
||||
// negotiation offer being accepted by the extension negotiation
|
||||
// response didn't include the "server_no_context_takeover" extension
|
||||
// parameter.
|
||||
config.server_no_context_takeover = true;
|
||||
}
|
||||
|
||||
CLIENT_NO_CONTEXT_TAKEOVER => {
|
||||
// Fail the connection when the response contains multiple parameters with the same name.
|
||||
if seen_client_no_context_takeover {
|
||||
return Err(DeflateError::Negotiation(
|
||||
NegotiationError::DuplicateParameter(key.to_owned()),
|
||||
));
|
||||
}
|
||||
seen_client_no_context_takeover = true;
|
||||
// The server may include this parameter in the response and the client MUST support it.
|
||||
config.client_no_context_takeover = true;
|
||||
}
|
||||
|
||||
SERVER_MAX_WINDOW_BITS => {
|
||||
// Fail the connection when the response contains a parameter with invalid value.
|
||||
if let Some(bits) = val {
|
||||
if !is_valid_max_window_bits(bits) {
|
||||
return Err(DeflateError::Negotiation(
|
||||
NegotiationError::InvalidServerMaxWindowBitsValue(bits.to_owned()),
|
||||
));
|
||||
}
|
||||
} else {
|
||||
return Err(DeflateError::Negotiation(
|
||||
NegotiationError::MissingServerMaxWindowBitsValue,
|
||||
));
|
||||
}
|
||||
|
||||
// A server may include the "server_max_window_bits" extension parameter
|
||||
// in an extension negotiation response even if the extension
|
||||
// negotiation offer being accepted by the response didn't include the
|
||||
// "server_max_window_bits" extension parameter.
|
||||
//
|
||||
// However, but we need to fail the connection because we don't support it (condition 4).
|
||||
return Err(DeflateError::Negotiation(
|
||||
NegotiationError::ServerMaxWindowBitsNotSupported,
|
||||
));
|
||||
}
|
||||
|
||||
CLIENT_MAX_WINDOW_BITS => {
|
||||
// Fail the connection when the response contains a parameter with invalid value.
|
||||
if let Some(bits) = val {
|
||||
if !is_valid_max_window_bits(bits) {
|
||||
return Err(DeflateError::Negotiation(
|
||||
NegotiationError::InvalidClientMaxWindowBitsValue(bits.to_owned()),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Fail the connection because the parameter is invalid when the client didn't offer.
|
||||
//
|
||||
// If a received extension negotiation offer doesn't have the
|
||||
// "client_max_window_bits" extension parameter, the corresponding
|
||||
// extension negotiation response to the offer MUST NOT include the
|
||||
// "client_max_window_bits" extension parameter.
|
||||
return Err(DeflateError::Negotiation(
|
||||
NegotiationError::UnexpectedClientMaxWindowBits,
|
||||
));
|
||||
}
|
||||
|
||||
// Response with unknown parameter MUST fail the WebSocket connection.
|
||||
_ => {
|
||||
return Err(DeflateError::Negotiation(NegotiationError::UnknownParameter(
|
||||
key.to_owned(),
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(DeflateContext::new(Role::Client, config))
|
||||
}
|
||||
}
|
||||
|
||||
// A valid `client_max_window_bits` is no value or an integer in range `[8, 15]` without leading zeros.
|
||||
// A valid `server_max_window_bits` is an integer in range `[8, 15]` without leading zeros.
|
||||
fn is_valid_max_window_bits(bits: &str) -> bool {
|
||||
// Note that values from `headers::SecWebSocketExtensions` is unquoted.
|
||||
matches!(bits, "8" | "9" | "10" | "11" | "12" | "13" | "14" | "15")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::is_valid_max_window_bits;
|
||||
|
||||
#[test]
|
||||
fn valid_max_window_bits() {
|
||||
for bits in 8..=15 {
|
||||
assert!(is_valid_max_window_bits(&bits.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_max_window_bits() {
|
||||
assert!(!is_valid_max_window_bits(""));
|
||||
assert!(!is_valid_max_window_bits("0"));
|
||||
assert!(!is_valid_max_window_bits("08"));
|
||||
assert!(!is_valid_max_window_bits("+8"));
|
||||
assert!(!is_valid_max_window_bits("-8"));
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Manages per message compression using DEFLATE.
|
||||
pub struct DeflateContext {
|
||||
role: Role,
|
||||
config: DeflateConfig,
|
||||
compressor: Compress,
|
||||
decompressor: Decompress,
|
||||
}
|
||||
|
||||
impl DeflateContext {
|
||||
fn new(role: Role, config: DeflateConfig) -> Self {
|
||||
DeflateContext {
|
||||
role,
|
||||
config,
|
||||
compressor: Compress::new(config.compression, false),
|
||||
decompressor: Decompress::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
fn own_context_takeover(&self) -> bool {
|
||||
match self.role {
|
||||
Role::Server => !self.config.server_no_context_takeover,
|
||||
Role::Client => !self.config.client_no_context_takeover,
|
||||
}
|
||||
}
|
||||
|
||||
fn peer_context_takeover(&self) -> bool {
|
||||
match self.role {
|
||||
Role::Server => !self.config.client_no_context_takeover,
|
||||
Role::Client => !self.config.server_no_context_takeover,
|
||||
}
|
||||
}
|
||||
|
||||
// Compress the data of message.
|
||||
pub(crate) fn compress(&mut self, data: &[u8]) -> Result<Vec<u8>, DeflateError> {
|
||||
// https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1
|
||||
// 1. Compress all the octets of the payload of the message using DEFLATE.
|
||||
let mut output = Vec::with_capacity(data.len());
|
||||
let before_in = self.compressor.total_in() as usize;
|
||||
while (self.compressor.total_in() as usize) - before_in < data.len() {
|
||||
let offset = (self.compressor.total_in() as usize) - before_in;
|
||||
match self
|
||||
.compressor
|
||||
.compress_vec(&data[offset..], &mut output, FlushCompress::None)
|
||||
.map_err(|e| DeflateError::Compress(e.into()))?
|
||||
{
|
||||
Status::Ok => continue,
|
||||
Status::BufError => output.reserve(4096),
|
||||
Status::StreamEnd => break,
|
||||
}
|
||||
}
|
||||
// 2. If the resulting data does not end with an empty DEFLATE block
|
||||
// with no compression (the "BTYPE" bits are set to 00), append an
|
||||
// empty DEFLATE block with no compression to the tail end.
|
||||
while !output.ends_with(&TRAILER) {
|
||||
output.reserve(5);
|
||||
match self
|
||||
.compressor
|
||||
.compress_vec(&[], &mut output, FlushCompress::Sync)
|
||||
.map_err(|e| DeflateError::Compress(e.into()))?
|
||||
{
|
||||
Status::Ok | Status::BufError => continue,
|
||||
Status::StreamEnd => break,
|
||||
}
|
||||
}
|
||||
// 3. Remove 4 octets (that are 0x00 0x00 0xff 0xff) from the tail end.
|
||||
// After this step, the last octet of the compressed data contains
|
||||
// (possibly part of) the DEFLATE header bits with the "BTYPE" bits
|
||||
// set to 00.
|
||||
output.truncate(output.len() - 4);
|
||||
|
||||
if !self.own_context_takeover() {
|
||||
self.compressor.reset();
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub(crate) fn decompress(
|
||||
&mut self,
|
||||
mut data: Vec<u8>,
|
||||
is_final: bool,
|
||||
) -> Result<Vec<u8>, DeflateError> {
|
||||
if is_final {
|
||||
data.extend_from_slice(&TRAILER);
|
||||
}
|
||||
|
||||
let before_in = self.decompressor.total_in() as usize;
|
||||
let mut output = Vec::with_capacity(2 * data.len());
|
||||
loop {
|
||||
let offset = (self.decompressor.total_in() as usize) - before_in;
|
||||
match self
|
||||
.decompressor
|
||||
.decompress_vec(&data[offset..], &mut output, FlushDecompress::None)
|
||||
.map_err(|e| DeflateError::Decompress(e.into()))?
|
||||
{
|
||||
Status::Ok => output.reserve(2 * output.len()),
|
||||
Status::BufError | Status::StreamEnd => break,
|
||||
}
|
||||
}
|
||||
|
||||
if is_final && !self.peer_context_takeover() {
|
||||
self.decompressor.reset(false);
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
fn to_header_value(params: &[HeaderValue]) -> WebsocketExtension {
|
||||
let mut buf = BytesMut::from(PER_MESSAGE_DEFLATE.as_bytes());
|
||||
for param in params {
|
||||
buf.extend_from_slice(b"; ");
|
||||
buf.extend_from_slice(param.as_bytes());
|
||||
}
|
||||
let header = HeaderValue::from_maybe_shared(buf.freeze())
|
||||
.expect("semicolon separated HeaderValue is valid");
|
||||
WebsocketExtension::try_from(header).expect("valid extension")
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
//! [Per-Message Compression Extensions][rfc7692]
|
||||
//!
|
||||
//! [rfc7692]: https://tools.ietf.org/html/rfc7692
|
||||
pub mod deflate;
|
|
@ -0,0 +1,18 @@
|
|||
//! WebSocket extensions.
|
||||
// Only `permessage-deflate` is supported at the moment.
|
||||
|
||||
#[cfg(feature = "deflate")]
|
||||
mod compression;
|
||||
#[cfg(feature = "deflate")]
|
||||
use compression::deflate::DeflateContext;
|
||||
#[cfg(feature = "deflate")]
|
||||
pub use compression::deflate::{DeflateConfig, DeflateError};
|
||||
|
||||
/// Container for configured extensions.
|
||||
#[derive(Debug, Default)]
|
||||
#[allow(missing_copy_implementations)]
|
||||
pub struct Extensions {
|
||||
// Per-Message Compression. Only `permessage-deflate` is supported.
|
||||
#[cfg(feature = "deflate")]
|
||||
pub(crate) compression: Option<DeflateContext>,
|
||||
}
|
|
@ -5,6 +5,7 @@ use std::{
|
|||
marker::PhantomData,
|
||||
};
|
||||
|
||||
use headers::{HeaderMapExt, SecWebsocketExtensions};
|
||||
use http::{
|
||||
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
|
||||
};
|
||||
|
@ -19,6 +20,7 @@ use super::{
|
|||
};
|
||||
use crate::{
|
||||
error::{Error, ProtocolError, Result, UrlError},
|
||||
extensions::Extensions,
|
||||
protocol::{Role, WebSocket, WebSocketConfig},
|
||||
};
|
||||
|
||||
|
@ -56,7 +58,7 @@ impl<S: Read + Write> ClientHandshake<S> {
|
|||
|
||||
// Convert and verify the `http::Request` and turn it into the request as per RFC.
|
||||
// Also extract the key from it (it must be present in a correct request).
|
||||
let (request, key) = generate_request(request)?;
|
||||
let (request, key) = generate_request(request, &config)?;
|
||||
|
||||
let machine = HandshakeMachine::start_write(stream, request);
|
||||
|
||||
|
@ -83,18 +85,24 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
|
|||
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
|
||||
}
|
||||
StageResult::DoneReading { stream, result, tail } => {
|
||||
let result = match self.verify_data.verify_response(result) {
|
||||
Ok(r) => r,
|
||||
Err(Error::Http(mut e)) => {
|
||||
*e.body_mut() = Some(tail);
|
||||
return Err(Error::Http(e))
|
||||
},
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
let (result, extensions) =
|
||||
match self.verify_data.verify_response(result, &self.config) {
|
||||
Ok(r) => r,
|
||||
Err(Error::Http(mut e)) => {
|
||||
*e.body_mut() = Some(tail);
|
||||
return Err(Error::Http(e));
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
debug!("Client handshake done.");
|
||||
let websocket =
|
||||
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
|
||||
let websocket = WebSocket::from_partially_read_with_extensions(
|
||||
stream,
|
||||
tail,
|
||||
Role::Client,
|
||||
self.config,
|
||||
extensions,
|
||||
);
|
||||
ProcessingResult::Done((websocket, result))
|
||||
}
|
||||
})
|
||||
|
@ -102,7 +110,10 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
|
|||
}
|
||||
|
||||
/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
|
||||
pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
|
||||
pub fn generate_request(
|
||||
mut request: Request,
|
||||
config: &Option<WebSocketConfig>,
|
||||
) -> Result<(Vec<u8>, String)> {
|
||||
let mut req = Vec::new();
|
||||
write!(
|
||||
req,
|
||||
|
@ -173,6 +184,9 @@ pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
|
|||
writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap();
|
||||
}
|
||||
|
||||
if let Some(offers) = config.and_then(|c| c.generate_offers()) {
|
||||
writeln!(req, "Sec-WebSocket-Extensions: {}\r", offers.to_value().to_str()?).unwrap();
|
||||
}
|
||||
writeln!(req, "\r").unwrap();
|
||||
trace!("Request: {:?}", String::from_utf8_lossy(&req));
|
||||
Ok((req, key))
|
||||
|
@ -186,7 +200,11 @@ struct VerifyData {
|
|||
}
|
||||
|
||||
impl VerifyData {
|
||||
pub fn verify_response(&self, response: Response) -> Result<Response> {
|
||||
pub fn verify_response(
|
||||
&self,
|
||||
response: Response,
|
||||
_config: &Option<WebSocketConfig>,
|
||||
) -> Result<(Response, Option<Extensions>)> {
|
||||
// 1. If the status code received from the server is not 101, the
|
||||
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
|
||||
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
|
||||
|
@ -231,7 +249,14 @@ impl VerifyData {
|
|||
// that was not present in the client's handshake (the server has
|
||||
// indicated an extension not requested by the client), the client
|
||||
// MUST _Fail the WebSocket Connection_. (RFC 6455)
|
||||
// TODO
|
||||
let extensions = if let Some(agreed) = headers
|
||||
.typed_try_get::<SecWebsocketExtensions>()
|
||||
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
|
||||
{
|
||||
verify_extensions(&agreed, _config)?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
|
||||
// and this header field indicates the use of a subprotocol that was
|
||||
|
@ -240,10 +265,49 @@ impl VerifyData {
|
|||
// the WebSocket Connection_. (RFC 6455)
|
||||
// TODO
|
||||
|
||||
Ok(response)
|
||||
Ok((response, extensions))
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_extensions(
|
||||
agreed_extensions: &headers::SecWebsocketExtensions,
|
||||
_config: &Option<WebSocketConfig>,
|
||||
) -> Result<Option<Extensions>> {
|
||||
#[cfg(feature = "deflate")]
|
||||
{
|
||||
if let Some(compression) = _config.and_then(|c| c.compression) {
|
||||
let mut extensions = None;
|
||||
for extension in agreed_extensions.iter() {
|
||||
// > If a server gives an invalid response, such as accepting a PMCE that the client did not offer,
|
||||
// > the client MUST _Fail the WebSocket Connection_.
|
||||
if extension.name() != compression.name() {
|
||||
return Err(Error::Protocol(ProtocolError::InvalidExtension(
|
||||
extension.name().to_string(),
|
||||
)));
|
||||
}
|
||||
|
||||
// Already had PMCE configured
|
||||
if extensions.is_some() {
|
||||
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
|
||||
extension.name().to_string(),
|
||||
)));
|
||||
}
|
||||
|
||||
extensions = Some(Extensions {
|
||||
compression: Some(compression.accept_response(extension.params())?),
|
||||
});
|
||||
}
|
||||
return Ok(extensions);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(extension) = agreed_extensions.iter().next() {
|
||||
// The client didn't request anything, but got something
|
||||
return Err(Error::Protocol(ProtocolError::InvalidExtension(extension.name().to_string())));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
impl TryParse for Response {
|
||||
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
|
||||
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
|
||||
|
@ -286,6 +350,8 @@ pub fn generate_key() -> String {
|
|||
mod tests {
|
||||
use super::{super::machine::TryParse, generate_key, generate_request, Response};
|
||||
use crate::client::IntoClientRequest;
|
||||
#[cfg(feature = "deflate")]
|
||||
use crate::{extensions::DeflateConfig, protocol::WebSocketConfig};
|
||||
|
||||
#[test]
|
||||
fn random_keys() {
|
||||
|
@ -322,7 +388,7 @@ mod tests {
|
|||
#[test]
|
||||
fn request_formatting() {
|
||||
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
|
||||
let (request, key) = generate_request(request).unwrap();
|
||||
let (request, key) = generate_request(request, &None).unwrap();
|
||||
let correct = construct_expected("localhost", &key);
|
||||
assert_eq!(&request[..], &correct[..]);
|
||||
}
|
||||
|
@ -330,7 +396,7 @@ mod tests {
|
|||
#[test]
|
||||
fn request_formatting_with_host() {
|
||||
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
|
||||
let (request, key) = generate_request(request).unwrap();
|
||||
let (request, key) = generate_request(request, &None).unwrap();
|
||||
let correct = construct_expected("localhost:9001", &key);
|
||||
assert_eq!(&request[..], &correct[..]);
|
||||
}
|
||||
|
@ -338,11 +404,40 @@ mod tests {
|
|||
#[test]
|
||||
fn request_formatting_with_at() {
|
||||
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
|
||||
let (request, key) = generate_request(request).unwrap();
|
||||
let (request, key) = generate_request(request, &None).unwrap();
|
||||
let correct = construct_expected("localhost:9001", &key);
|
||||
assert_eq!(&request[..], &correct[..]);
|
||||
}
|
||||
|
||||
#[cfg(feature = "deflate")]
|
||||
#[test]
|
||||
fn request_with_compression() {
|
||||
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
|
||||
let (request, key) = generate_request(
|
||||
request,
|
||||
&Some(WebSocketConfig {
|
||||
compression: Some(DeflateConfig::default()),
|
||||
..WebSocketConfig::default()
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
let correct = format!(
|
||||
"\
|
||||
GET /getCaseCount HTTP/1.1\r\n\
|
||||
Host: {host}\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Sec-WebSocket-Version: 13\r\n\
|
||||
Sec-WebSocket-Key: {key}\r\n\
|
||||
Sec-WebSocket-Extensions: permessage-deflate\r\n\
|
||||
\r\n",
|
||||
host = "localhost",
|
||||
key = key
|
||||
)
|
||||
.into_bytes();
|
||||
assert_eq!(&request[..], &correct[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_parsing() {
|
||||
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
|
||||
|
@ -354,6 +449,6 @@ mod tests {
|
|||
#[test]
|
||||
fn invalid_custom_request() {
|
||||
let request = http::Request::builder().method("GET").body(()).unwrap();
|
||||
assert!(generate_request(request).is_err());
|
||||
assert!(generate_request(request, &None).is_err());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ use std::{
|
|||
result::Result as StdResult,
|
||||
};
|
||||
|
||||
use headers::{HeaderMapExt, SecWebsocketExtensions};
|
||||
use http::{
|
||||
response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
|
||||
};
|
||||
|
@ -20,6 +21,7 @@ use super::{
|
|||
};
|
||||
use crate::{
|
||||
error::{Error, ProtocolError, Result},
|
||||
extensions::Extensions,
|
||||
protocol::{Role, WebSocket, WebSocketConfig},
|
||||
};
|
||||
|
||||
|
@ -202,6 +204,8 @@ pub struct ServerHandshake<S, C> {
|
|||
config: Option<WebSocketConfig>,
|
||||
/// Error code/flag. If set, an error will be returned after sending response to the client.
|
||||
error_response: Option<ErrorResponse>,
|
||||
// Negotiated extension context for server.
|
||||
extensions: Option<Extensions>,
|
||||
/// Internal stream type.
|
||||
_marker: PhantomData<S>,
|
||||
}
|
||||
|
@ -219,6 +223,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
|
|||
callback: Some(callback),
|
||||
config,
|
||||
error_response: None,
|
||||
extensions: None,
|
||||
_marker: PhantomData,
|
||||
},
|
||||
}
|
||||
|
@ -240,7 +245,19 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
|
|||
return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
|
||||
}
|
||||
|
||||
let response = create_response(&result)?;
|
||||
let mut response = create_response(&result)?;
|
||||
if let Some(config) = &self.config {
|
||||
if let Some((agreed, extensions)) = result
|
||||
.headers()
|
||||
.typed_try_get::<SecWebsocketExtensions>()
|
||||
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
|
||||
.and_then(|values| config.accept_offers(&values))
|
||||
{
|
||||
response.headers_mut().typed_insert(agreed);
|
||||
self.extensions = Some(extensions);
|
||||
}
|
||||
}
|
||||
|
||||
let callback_result = if let Some(callback) = self.callback.take() {
|
||||
callback.on_request(&result, response)
|
||||
} else {
|
||||
|
@ -283,7 +300,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
|
|||
return Err(Error::Http(http::Response::from_parts(parts, body)));
|
||||
} else {
|
||||
debug!("Server handshake done.");
|
||||
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
|
||||
let websocket = WebSocket::from_raw_socket_with_extensions(
|
||||
stream,
|
||||
Role::Server,
|
||||
self.config,
|
||||
self.extensions.take(),
|
||||
);
|
||||
ProcessingResult::Done(websocket)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ pub mod buffer;
|
|||
#[cfg(feature = "handshake")]
|
||||
pub mod client;
|
||||
pub mod error;
|
||||
pub mod extensions;
|
||||
#[cfg(feature = "handshake")]
|
||||
pub mod handshake;
|
||||
pub mod protocol;
|
||||
|
|
|
@ -311,6 +311,18 @@ impl Frame {
|
|||
Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
|
||||
}
|
||||
|
||||
/// Create a new compressed data frame.
|
||||
#[inline]
|
||||
#[cfg(feature = "deflate")]
|
||||
pub(crate) fn compressed_message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
|
||||
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
|
||||
|
||||
Frame {
|
||||
header: FrameHeader { is_final, opcode, rsv1: true, ..FrameHeader::default() },
|
||||
payload: data,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new Pong control frame.
|
||||
#[inline]
|
||||
pub fn pong(data: Vec<u8>) -> Frame {
|
||||
|
|
|
@ -84,6 +84,8 @@ use self::string_collect::StringCollector;
|
|||
#[derive(Debug)]
|
||||
pub struct IncompleteMessage {
|
||||
collector: IncompleteMessageCollector,
|
||||
#[cfg(feature = "deflate")]
|
||||
compressed: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -94,6 +96,7 @@ enum IncompleteMessageCollector {
|
|||
|
||||
impl IncompleteMessage {
|
||||
/// Create new.
|
||||
#[cfg(not(feature = "deflate"))]
|
||||
pub fn new(message_type: IncompleteMessageType) -> Self {
|
||||
IncompleteMessage {
|
||||
collector: match message_type {
|
||||
|
@ -105,6 +108,25 @@ impl IncompleteMessage {
|
|||
}
|
||||
}
|
||||
|
||||
/// Create new.
|
||||
#[cfg(feature = "deflate")]
|
||||
pub fn new(message_type: IncompleteMessageType, compressed: bool) -> Self {
|
||||
IncompleteMessage {
|
||||
collector: match message_type {
|
||||
IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
|
||||
IncompleteMessageType::Text => {
|
||||
IncompleteMessageCollector::Text(StringCollector::new())
|
||||
}
|
||||
},
|
||||
compressed,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "deflate")]
|
||||
pub fn compressed(&self) -> bool {
|
||||
self.compressed
|
||||
}
|
||||
|
||||
/// Get the current filled size of the buffer.
|
||||
pub fn len(&self) -> usize {
|
||||
match self.collector {
|
||||
|
|
|
@ -22,6 +22,7 @@ use self::{
|
|||
};
|
||||
use crate::{
|
||||
error::{Error, ProtocolError, Result},
|
||||
extensions::Extensions,
|
||||
util::NonBlockingResult,
|
||||
};
|
||||
|
||||
|
@ -56,6 +57,9 @@ pub struct WebSocketConfig {
|
|||
/// some popular libraries that are sending unmasked frames, ignoring the RFC.
|
||||
/// By default this option is set to `false`, i.e. according to RFC 6455.
|
||||
pub accept_unmasked_frames: bool,
|
||||
/// Optional configuration for Per-Message Compression Extension.
|
||||
#[cfg(feature = "deflate")]
|
||||
pub compression: Option<crate::extensions::DeflateConfig>,
|
||||
}
|
||||
|
||||
impl Default for WebSocketConfig {
|
||||
|
@ -65,6 +69,64 @@ impl Default for WebSocketConfig {
|
|||
max_message_size: Some(64 << 20),
|
||||
max_frame_size: Some(16 << 20),
|
||||
accept_unmasked_frames: false,
|
||||
#[cfg(feature = "deflate")]
|
||||
compression: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WebSocketConfig {
|
||||
// Generate extension negotiation offers for configured extensions.
|
||||
// Only `permessage-deflate` is supported at the moment.
|
||||
pub(crate) fn generate_offers(&self) -> Option<headers::SecWebsocketExtensions> {
|
||||
#[cfg(feature = "deflate")]
|
||||
{
|
||||
let mut offers = Vec::new();
|
||||
if let Some(compression) = self.compression.map(|c| c.generate_offer()) {
|
||||
offers.push(compression);
|
||||
}
|
||||
if offers.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(headers::SecWebsocketExtensions::new(offers))
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "deflate"))]
|
||||
{
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// This can be used with `WebSocket::from_raw_socket_with_extensions` for integration.
|
||||
/// Returns negotiation response based on offers and `Extensions` to manage extensions.
|
||||
pub fn accept_offers(
|
||||
&self,
|
||||
_offers: &headers::SecWebsocketExtensions,
|
||||
) -> Option<(headers::SecWebsocketExtensions, Extensions)> {
|
||||
#[cfg(feature = "deflate")]
|
||||
{
|
||||
// To support more extensions, store extension context in `Extensions` and
|
||||
// concatenate negotiation responses from each extension.
|
||||
let mut agreed_extensions = Vec::new();
|
||||
let mut extensions = Extensions::default();
|
||||
|
||||
if let Some(compression) = &self.compression {
|
||||
if let Some((agreed, compression)) = compression.accept_offer(_offers) {
|
||||
agreed_extensions.push(agreed);
|
||||
extensions.compression = Some(compression);
|
||||
}
|
||||
}
|
||||
|
||||
if agreed_extensions.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some((headers::SecWebsocketExtensions::new(agreed_extensions), extensions))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "deflate"))]
|
||||
{
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -91,6 +153,18 @@ impl<Stream> WebSocket<Stream> {
|
|||
WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
|
||||
}
|
||||
|
||||
/// Convert a raw socket into a WebSocket without performing a handshake.
|
||||
pub fn from_raw_socket_with_extensions(
|
||||
stream: Stream,
|
||||
role: Role,
|
||||
config: Option<WebSocketConfig>,
|
||||
extensions: Option<Extensions>,
|
||||
) -> Self {
|
||||
let mut context = WebSocketContext::new(role, config);
|
||||
context.extensions = extensions;
|
||||
WebSocket { socket: stream, context }
|
||||
}
|
||||
|
||||
/// Convert a raw socket into a WebSocket without performing a handshake.
|
||||
///
|
||||
/// Call this function if you're using Tungstenite as a part of a web framework
|
||||
|
@ -108,6 +182,21 @@ impl<Stream> WebSocket<Stream> {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_partially_read_with_extensions(
|
||||
stream: Stream,
|
||||
part: Vec<u8>,
|
||||
role: Role,
|
||||
config: Option<WebSocketConfig>,
|
||||
extensions: Option<Extensions>,
|
||||
) -> Self {
|
||||
WebSocket {
|
||||
socket: stream,
|
||||
context: WebSocketContext::from_partially_read_with_extensions(
|
||||
part, role, config, extensions,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a shared reference to the inner stream.
|
||||
pub fn get_ref(&self) -> &Stream {
|
||||
&self.socket
|
||||
|
@ -241,6 +330,8 @@ pub struct WebSocketContext {
|
|||
pong: Option<Frame>,
|
||||
/// The configuration for the websocket session.
|
||||
config: WebSocketConfig,
|
||||
// Container for extensions.
|
||||
pub(crate) extensions: Option<Extensions>,
|
||||
}
|
||||
|
||||
impl WebSocketContext {
|
||||
|
@ -254,6 +345,7 @@ impl WebSocketContext {
|
|||
send_queue: VecDeque::new(),
|
||||
pong: None,
|
||||
config: config.unwrap_or_default(),
|
||||
extensions: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -265,6 +357,19 @@ impl WebSocketContext {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_partially_read_with_extensions(
|
||||
part: Vec<u8>,
|
||||
role: Role,
|
||||
config: Option<WebSocketConfig>,
|
||||
extensions: Option<Extensions>,
|
||||
) -> Self {
|
||||
WebSocketContext {
|
||||
frame: FrameCodec::from_partially_read(part),
|
||||
extensions,
|
||||
..WebSocketContext::new(role, config)
|
||||
}
|
||||
}
|
||||
|
||||
/// Change the configuration.
|
||||
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
|
||||
set_func(&mut self.config)
|
||||
|
@ -348,8 +453,8 @@ impl WebSocketContext {
|
|||
}
|
||||
|
||||
let frame = match message {
|
||||
Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true),
|
||||
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
|
||||
Message::Text(data) => self.prepare_data_frame(data.into(), OpData::Text)?,
|
||||
Message::Binary(data) => self.prepare_data_frame(data, OpData::Binary)?,
|
||||
Message::Ping(data) => Frame::ping(data),
|
||||
Message::Pong(data) => {
|
||||
self.pong = Some(Frame::pong(data));
|
||||
|
@ -363,6 +468,17 @@ impl WebSocketContext {
|
|||
self.write_pending(stream)
|
||||
}
|
||||
|
||||
fn prepare_data_frame(&mut self, data: Vec<u8>, opdata: OpData) -> Result<Frame> {
|
||||
debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind");
|
||||
let opcode = OpCode::Data(opdata);
|
||||
let is_final = true;
|
||||
#[cfg(feature = "deflate")]
|
||||
if let Some(pmce) = self.extensions.as_mut().and_then(|e| e.compression.as_mut()) {
|
||||
return Ok(Frame::compressed_message(pmce.compress(&data)?, opcode, is_final));
|
||||
}
|
||||
Ok(Frame::message(data, opcode, is_final))
|
||||
}
|
||||
|
||||
/// Flush the pending send queue.
|
||||
pub fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()>
|
||||
where
|
||||
|
@ -439,12 +555,14 @@ impl WebSocketContext {
|
|||
// the negotiated extensions defines the meaning of such a nonzero
|
||||
// value, the receiving endpoint MUST _Fail the WebSocket
|
||||
// Connection_.
|
||||
{
|
||||
let is_compressed = {
|
||||
let hdr = frame.header();
|
||||
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
|
||||
if (hdr.rsv1 && !self.has_compression()) || hdr.rsv2 || hdr.rsv3 {
|
||||
return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
|
||||
}
|
||||
}
|
||||
|
||||
hdr.rsv1
|
||||
};
|
||||
|
||||
match self.role {
|
||||
Role::Server => {
|
||||
|
@ -479,6 +597,10 @@ impl WebSocketContext {
|
|||
_ if frame.payload().len() > 125 => {
|
||||
Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
|
||||
}
|
||||
// Control frames must not have compress bit.
|
||||
_ if is_compressed => {
|
||||
Err(Error::Protocol(ProtocolError::CompressedControlFrame))
|
||||
}
|
||||
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
|
||||
OpCtl::Reserved(i) => {
|
||||
Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i)))
|
||||
|
@ -499,39 +621,34 @@ impl WebSocketContext {
|
|||
let fin = frame.header().is_final;
|
||||
match data {
|
||||
OpData::Continue => {
|
||||
if let Some(ref mut msg) = self.incomplete {
|
||||
msg.extend(frame.into_data(), self.config.max_message_size)?;
|
||||
} else {
|
||||
if self.incomplete.is_some() && is_compressed {
|
||||
return Err(Error::Protocol(
|
||||
ProtocolError::UnexpectedContinueFrame,
|
||||
ProtocolError::CompressedContinueFrame,
|
||||
));
|
||||
}
|
||||
if fin {
|
||||
Ok(Some(self.incomplete.take().unwrap().complete()?))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
let msg = self
|
||||
.incomplete
|
||||
.take()
|
||||
.ok_or(Error::Protocol(ProtocolError::UnexpectedContinueFrame))?;
|
||||
self.extend_incomplete(msg, frame.into_data(), fin)
|
||||
}
|
||||
|
||||
c if self.incomplete.is_some() => {
|
||||
Err(Error::Protocol(ProtocolError::ExpectedFragment(c)))
|
||||
}
|
||||
|
||||
OpData::Text | OpData::Binary => {
|
||||
let msg = {
|
||||
let message_type = match data {
|
||||
OpData::Text => IncompleteMessageType::Text,
|
||||
OpData::Binary => IncompleteMessageType::Binary,
|
||||
_ => panic!("Bug: message is not text nor binary"),
|
||||
};
|
||||
let mut m = IncompleteMessage::new(message_type);
|
||||
m.extend(frame.into_data(), self.config.max_message_size)?;
|
||||
m
|
||||
let message_type = match data {
|
||||
OpData::Text => IncompleteMessageType::Text,
|
||||
OpData::Binary => IncompleteMessageType::Binary,
|
||||
_ => panic!("Bug: message is not text nor binary"),
|
||||
};
|
||||
if fin {
|
||||
Ok(Some(msg.complete()?))
|
||||
} else {
|
||||
self.incomplete = Some(msg);
|
||||
Ok(None)
|
||||
}
|
||||
#[cfg(feature = "deflate")]
|
||||
let msg = IncompleteMessage::new(message_type, is_compressed);
|
||||
#[cfg(not(feature = "deflate"))]
|
||||
let msg = IncompleteMessage::new(message_type);
|
||||
self.extend_incomplete(msg, frame.into_data(), fin)
|
||||
}
|
||||
OpData::Reserved(i) => {
|
||||
Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i)))
|
||||
|
@ -550,6 +667,32 @@ impl WebSocketContext {
|
|||
}
|
||||
}
|
||||
|
||||
fn extend_incomplete(
|
||||
&mut self,
|
||||
mut msg: IncompleteMessage,
|
||||
data: Vec<u8>,
|
||||
is_final: bool,
|
||||
) -> Result<Option<Message>> {
|
||||
#[cfg(feature = "deflate")]
|
||||
let data = if msg.compressed() {
|
||||
// `msg.compressed()` is only true when compression is enabled so it's safe to unwrap
|
||||
self.extensions
|
||||
.as_mut()
|
||||
.and_then(|x| x.compression.as_mut())
|
||||
.unwrap()
|
||||
.decompress(data, is_final)?
|
||||
} else {
|
||||
data
|
||||
};
|
||||
msg.extend(data, self.config.max_message_size)?;
|
||||
if is_final {
|
||||
Ok(Some(msg.complete()?))
|
||||
} else {
|
||||
self.incomplete = Some(msg);
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Received a close frame. Tells if we need to return a close frame to the user.
|
||||
#[allow(clippy::option_option)]
|
||||
fn do_close<'t>(&mut self, close: Option<CloseFrame<'t>>) -> Option<Option<CloseFrame<'t>>> {
|
||||
|
@ -605,6 +748,17 @@ impl WebSocketContext {
|
|||
trace!("Sending frame: {:?}", frame);
|
||||
self.frame.write_frame(stream, frame).check_connection_reset(self.state)
|
||||
}
|
||||
|
||||
fn has_compression(&self) -> bool {
|
||||
#[cfg(feature = "deflate")]
|
||||
{
|
||||
self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some()
|
||||
}
|
||||
#[cfg(not(feature = "deflate"))]
|
||||
{
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The current connection state.
|
||||
|
|
Loading…
Reference in New Issue