Revert "Add `permessage-deflate` support"
This reverts commit edb2377540
.
See https://github.com/snapview/tungstenite-rs/pull/328#issuecomment-1480349206
This commit is contained in:
parent
edb2377540
commit
42b8797e8b
|
@ -1,4 +1,2 @@
|
||||||
target
|
target
|
||||||
Cargo.lock
|
Cargo.lock
|
||||||
autobahn/client/
|
|
||||||
autobahn/server/
|
|
||||||
|
|
|
@ -10,6 +10,5 @@ before_script:
|
||||||
|
|
||||||
script:
|
script:
|
||||||
- cargo test --release
|
- cargo test --release
|
||||||
- cargo test --release --features=deflate
|
|
||||||
- echo "Running Autobahn TestSuite for client" && ./scripts/autobahn-client.sh
|
- echo "Running Autobahn TestSuite for client" && ./scripts/autobahn-client.sh
|
||||||
- echo "Running Autobahn TestSuite for server" && ./scripts/autobahn-server.sh
|
- echo "Running Autobahn TestSuite for server" && ./scripts/autobahn-server.sh
|
||||||
|
|
14
Cargo.toml
14
Cargo.toml
|
@ -25,15 +25,6 @@ native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
|
||||||
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
|
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
|
||||||
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
|
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
|
||||||
__rustls-tls = ["rustls", "webpki"]
|
__rustls-tls = ["rustls", "webpki"]
|
||||||
deflate = ["flate2"]
|
|
||||||
|
|
||||||
[[example]]
|
|
||||||
name = "autobahn-client"
|
|
||||||
required-features = ["deflate"]
|
|
||||||
|
|
||||||
[[example]]
|
|
||||||
name = "autobahn-server"
|
|
||||||
required-features = ["deflate"]
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
data-encoding = { version = "2", optional = true }
|
data-encoding = { version = "2", optional = true }
|
||||||
|
@ -47,11 +38,6 @@ sha1 = { version = "0.10", optional = true }
|
||||||
thiserror = "1.0.23"
|
thiserror = "1.0.23"
|
||||||
url = { version = "2.1.0", optional = true }
|
url = { version = "2.1.0", optional = true }
|
||||||
utf-8 = "0.7.5"
|
utf-8 = "0.7.5"
|
||||||
headers = { git = "https://github.com/kazk/headers", branch = "sec-websocket-extensions" }
|
|
||||||
|
|
||||||
[dependencies.flate2]
|
|
||||||
optional = true
|
|
||||||
version = "1.0"
|
|
||||||
|
|
||||||
[dependencies.native-tls-crate]
|
[dependencies.native-tls-crate]
|
||||||
optional = true
|
optional = true
|
||||||
|
|
|
@ -72,6 +72,8 @@ Choose the one that is appropriate for your needs.
|
||||||
By default **no TLS feature is activated**, so make sure you use one of the TLS features,
|
By default **no TLS feature is activated**, so make sure you use one of the TLS features,
|
||||||
otherwise you won't be able to communicate with the TLS endpoints.
|
otherwise you won't be able to communicate with the TLS endpoints.
|
||||||
|
|
||||||
|
There is no support for permessage-deflate at the moment, but the PRs are welcome :wink:
|
||||||
|
|
||||||
Testing
|
Testing
|
||||||
-------
|
-------
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,10 +1,7 @@
|
||||||
use log::*;
|
use log::*;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
use tungstenite::{
|
use tungstenite::{connect, Error, Message, Result};
|
||||||
client::connect_with_config, connect, extensions::DeflateConfig, protocol::WebSocketConfig,
|
|
||||||
Error, Message, Result,
|
|
||||||
};
|
|
||||||
|
|
||||||
const AGENT: &str = "Tungstenite";
|
const AGENT: &str = "Tungstenite";
|
||||||
|
|
||||||
|
@ -27,14 +24,7 @@ fn run_test(case: u32) -> Result<()> {
|
||||||
info!("Running test case {}", case);
|
info!("Running test case {}", case);
|
||||||
let case_url =
|
let case_url =
|
||||||
Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap();
|
Url::parse(&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)).unwrap();
|
||||||
let (mut socket, _) = connect_with_config(
|
let (mut socket, _) = connect(case_url)?;
|
||||||
case_url,
|
|
||||||
Some(WebSocketConfig {
|
|
||||||
compression: Some(DeflateConfig::default()),
|
|
||||||
..WebSocketConfig::default()
|
|
||||||
}),
|
|
||||||
3,
|
|
||||||
)?;
|
|
||||||
loop {
|
loop {
|
||||||
match socket.read_message()? {
|
match socket.read_message()? {
|
||||||
msg @ Message::Text(_) | msg @ Message::Binary(_) => {
|
msg @ Message::Text(_) | msg @ Message::Binary(_) => {
|
||||||
|
|
|
@ -4,10 +4,7 @@ use std::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use log::*;
|
use log::*;
|
||||||
use tungstenite::{
|
use tungstenite::{accept, handshake::HandshakeRole, Error, HandshakeError, Message, Result};
|
||||||
accept_with_config, extensions::DeflateConfig, handshake::HandshakeRole,
|
|
||||||
protocol::WebSocketConfig, Error, HandshakeError, Message, Result,
|
|
||||||
};
|
|
||||||
|
|
||||||
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
|
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
|
||||||
match err {
|
match err {
|
||||||
|
@ -17,14 +14,7 @@ fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_client(stream: TcpStream) -> Result<()> {
|
fn handle_client(stream: TcpStream) -> Result<()> {
|
||||||
let mut socket = accept_with_config(
|
let mut socket = accept(stream).map_err(must_not_block)?;
|
||||||
stream,
|
|
||||||
Some(WebSocketConfig {
|
|
||||||
compression: Some(DeflateConfig::default()),
|
|
||||||
..WebSocketConfig::default()
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.map_err(must_not_block)?;
|
|
||||||
info!("Running test");
|
info!("Running test");
|
||||||
loop {
|
loop {
|
||||||
match socket.read_message()? {
|
match socket.read_message()? {
|
||||||
|
|
|
@ -35,8 +35,6 @@ fn main() {
|
||||||
// rare cases where it is necessary to integrate with existing/legacy
|
// rare cases where it is necessary to integrate with existing/legacy
|
||||||
// clients which are sending unmasked frames
|
// clients which are sending unmasked frames
|
||||||
accept_unmasked_frames: true,
|
accept_unmasked_frames: true,
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
compression: None,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap();
|
let mut websocket = accept_hdr_with_config(stream.unwrap(), callback, config).unwrap();
|
||||||
|
|
|
@ -32,5 +32,5 @@ docker run -d --rm \
|
||||||
wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json'
|
wstest -m fuzzingserver -s 'autobahn/fuzzingserver.json'
|
||||||
|
|
||||||
sleep 3
|
sleep 3
|
||||||
cargo run --release --example autobahn-client --features=deflate
|
cargo run --release --example autobahn-client
|
||||||
test_diff
|
test_diff
|
||||||
|
|
|
@ -22,7 +22,7 @@ function test_diff() {
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
cargo run --release --example autobahn-server --features=deflate & WSSERVER_PID=$!
|
cargo run --release --example autobahn-server & WSSERVER_PID=$!
|
||||||
sleep 3
|
sleep 3
|
||||||
|
|
||||||
docker run --rm \
|
docker run --rm \
|
||||||
|
|
19
src/error.rs
19
src/error.rs
|
@ -70,10 +70,6 @@ pub enum Error {
|
||||||
#[error("HTTP format error: {0}")]
|
#[error("HTTP format error: {0}")]
|
||||||
#[cfg(feature = "handshake")]
|
#[cfg(feature = "handshake")]
|
||||||
HttpFormat(#[from] http::Error),
|
HttpFormat(#[from] http::Error),
|
||||||
/// Error from `permessage-deflate` extension.
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
#[error("Deflate error: {0}")]
|
|
||||||
Deflate(#[from] crate::extensions::DeflateError),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<str::Utf8Error> for Error {
|
impl From<str::Utf8Error> for Error {
|
||||||
|
@ -210,9 +206,6 @@ pub enum ProtocolError {
|
||||||
/// Control frames must not be fragmented.
|
/// Control frames must not be fragmented.
|
||||||
#[error("Fragmented control frame")]
|
#[error("Fragmented control frame")]
|
||||||
FragmentedControlFrame,
|
FragmentedControlFrame,
|
||||||
/// Control frames must not be compressed.
|
|
||||||
#[error("Compressed control frame")]
|
|
||||||
CompressedControlFrame,
|
|
||||||
/// Control frames must have a payload of 125 bytes or less.
|
/// Control frames must have a payload of 125 bytes or less.
|
||||||
#[error("Control frame too big (payload must be 125 bytes or less)")]
|
#[error("Control frame too big (payload must be 125 bytes or less)")]
|
||||||
ControlFrameTooBig,
|
ControlFrameTooBig,
|
||||||
|
@ -225,9 +218,6 @@ pub enum ProtocolError {
|
||||||
/// Received a continue frame despite there being nothing to continue.
|
/// Received a continue frame despite there being nothing to continue.
|
||||||
#[error("Continue frame but nothing to continue")]
|
#[error("Continue frame but nothing to continue")]
|
||||||
UnexpectedContinueFrame,
|
UnexpectedContinueFrame,
|
||||||
/// Received a compressed continue frame.
|
|
||||||
#[error("Continue frame must not have compress bit set")]
|
|
||||||
CompressedContinueFrame,
|
|
||||||
/// Received data while waiting for more fragments.
|
/// Received data while waiting for more fragments.
|
||||||
#[error("While waiting for more fragments received: {0}")]
|
#[error("While waiting for more fragments received: {0}")]
|
||||||
ExpectedFragment(Data),
|
ExpectedFragment(Data),
|
||||||
|
@ -240,15 +230,6 @@ pub enum ProtocolError {
|
||||||
/// The payload for the closing frame is invalid.
|
/// The payload for the closing frame is invalid.
|
||||||
#[error("Invalid close sequence")]
|
#[error("Invalid close sequence")]
|
||||||
InvalidCloseSequence,
|
InvalidCloseSequence,
|
||||||
/// The negotiation response included an extension not offered.
|
|
||||||
#[error("Extension negotiation response had invalid extension: {0}")]
|
|
||||||
InvalidExtension(String),
|
|
||||||
/// The negotiation response included an extension more than once.
|
|
||||||
#[error("Extension negotiation response had conflicting extension: {0}")]
|
|
||||||
ExtensionConflict(String),
|
|
||||||
/// The `Sec-WebSocket-Extensions` header is invalid.
|
|
||||||
#[error("Invalid \"Sec-WebSocket-Extensions\" header")]
|
|
||||||
InvalidExtensionsHeader,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Indicates the specific type/cause of URL error.
|
/// Indicates the specific type/cause of URL error.
|
||||||
|
|
|
@ -1,442 +0,0 @@
|
||||||
use std::convert::TryFrom;
|
|
||||||
|
|
||||||
use bytes::BytesMut;
|
|
||||||
use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status};
|
|
||||||
use headers::WebsocketExtension;
|
|
||||||
use http::HeaderValue;
|
|
||||||
use thiserror::Error;
|
|
||||||
|
|
||||||
use crate::protocol::Role;
|
|
||||||
|
|
||||||
const PER_MESSAGE_DEFLATE: &str = "permessage-deflate";
|
|
||||||
const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
|
|
||||||
const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover";
|
|
||||||
const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits";
|
|
||||||
const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits";
|
|
||||||
|
|
||||||
const TRAILER: [u8; 4] = [0x00, 0x00, 0xff, 0xff];
|
|
||||||
|
|
||||||
/// Errors from `permessage-deflate` extension.
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
pub enum DeflateError {
|
|
||||||
/// Compress failed
|
|
||||||
#[error("Failed to compress")]
|
|
||||||
Compress(#[source] std::io::Error),
|
|
||||||
/// Decompress failed
|
|
||||||
#[error("Failed to decompress")]
|
|
||||||
Decompress(#[source] std::io::Error),
|
|
||||||
|
|
||||||
/// Extension negotiation failed.
|
|
||||||
#[error("Extension negotiation failed")]
|
|
||||||
Negotiation(#[source] NegotiationError),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Errors from `permessage-deflate` extension negotiation.
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
pub enum NegotiationError {
|
|
||||||
/// Unknown parameter in a negotiation response.
|
|
||||||
#[error("Unknown parameter in a negotiation response: {0}")]
|
|
||||||
UnknownParameter(String),
|
|
||||||
/// Duplicate parameter in a negotiation response.
|
|
||||||
#[error("Duplicate parameter in a negotiation response: {0}")]
|
|
||||||
DuplicateParameter(String),
|
|
||||||
/// Received `client_max_window_bits` in a negotiation response for an offer without it.
|
|
||||||
#[error("Received client_max_window_bits in a negotiation response for an offer without it")]
|
|
||||||
UnexpectedClientMaxWindowBits,
|
|
||||||
/// Received unsupported `server_max_window_bits` in a negotiation response.
|
|
||||||
#[error("Received unsupported server_max_window_bits in a negotiation response")]
|
|
||||||
ServerMaxWindowBitsNotSupported,
|
|
||||||
/// Invalid `client_max_window_bits` value in a negotiation response.
|
|
||||||
#[error("Invalid client_max_window_bits value in a negotiation response: {0}")]
|
|
||||||
InvalidClientMaxWindowBitsValue(String),
|
|
||||||
/// Invalid `server_max_window_bits` value in a negotiation response.
|
|
||||||
#[error("Invalid server_max_window_bits value in a negotiation response: {0}")]
|
|
||||||
InvalidServerMaxWindowBitsValue(String),
|
|
||||||
/// Missing `server_max_window_bits` value in a negotiation response.
|
|
||||||
#[error("Missing server_max_window_bits value in a negotiation response")]
|
|
||||||
MissingServerMaxWindowBitsValue,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parameters `server_max_window_bits` and `client_max_window_bits` are not supported for now
|
|
||||||
// because custom window size requires `flate2/zlib` feature.
|
|
||||||
/// Configurations for `permessage-deflate` Per-Message Compression Extension.
|
|
||||||
#[derive(Clone, Copy, Debug, Default)]
|
|
||||||
pub struct DeflateConfig {
|
|
||||||
/// Compression level.
|
|
||||||
pub compression: Compression,
|
|
||||||
/// Request the peer server not to use context takeover.
|
|
||||||
pub server_no_context_takeover: bool,
|
|
||||||
/// Hint that context takeover is not used.
|
|
||||||
pub client_no_context_takeover: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DeflateConfig {
|
|
||||||
pub(crate) fn name(&self) -> &str {
|
|
||||||
PER_MESSAGE_DEFLATE
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Value for `Sec-WebSocket-Extensions` request header.
|
|
||||||
pub(crate) fn generate_offer(&self) -> WebsocketExtension {
|
|
||||||
let mut offers = Vec::new();
|
|
||||||
if self.server_no_context_takeover {
|
|
||||||
offers.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
|
|
||||||
}
|
|
||||||
|
|
||||||
// > a client informs the peer server of a hint that even if the server doesn't include the
|
|
||||||
// > "client_no_context_takeover" extension parameter in the corresponding
|
|
||||||
// > extension negotiation response to the offer, the client is not going
|
|
||||||
// > to use context takeover.
|
|
||||||
// > https://www.rfc-editor.org/rfc/rfc7692#section-7.1.1.2
|
|
||||||
if self.client_no_context_takeover {
|
|
||||||
offers.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
|
|
||||||
}
|
|
||||||
to_header_value(&offers)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns negotiation response based on offers and `DeflateContext` to manage per message compression.
|
|
||||||
pub(crate) fn accept_offer(
|
|
||||||
&self,
|
|
||||||
offers: &headers::SecWebsocketExtensions,
|
|
||||||
) -> Option<(WebsocketExtension, DeflateContext)> {
|
|
||||||
// Accept the first valid offer for `permessage-deflate`.
|
|
||||||
// A server MUST decline an extension negotiation offer for this
|
|
||||||
// extension if any of the following conditions are met:
|
|
||||||
// 1. The negotiation offer contains an extension parameter not defined for use in an offer.
|
|
||||||
// 2. The negotiation offer contains an extension parameter with an invalid value.
|
|
||||||
// 3. The negotiation offer contains multiple extension parameters with the same name.
|
|
||||||
// 4. The server doesn't support the offered configuration.
|
|
||||||
offers.iter().find_map(|extension| {
|
|
||||||
if let Some(params) = (extension.name() == self.name()).then(|| extension.params()) {
|
|
||||||
let mut config =
|
|
||||||
DeflateConfig { compression: self.compression, ..DeflateConfig::default() };
|
|
||||||
let mut agreed = Vec::new();
|
|
||||||
let mut seen_server_no_context_takeover = false;
|
|
||||||
let mut seen_client_no_context_takeover = false;
|
|
||||||
let mut seen_client_max_window_bits = false;
|
|
||||||
for (key, val) in params {
|
|
||||||
match key {
|
|
||||||
SERVER_NO_CONTEXT_TAKEOVER => {
|
|
||||||
// Invalid offer with multiple params with same name is declined.
|
|
||||||
if seen_server_no_context_takeover {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
seen_server_no_context_takeover = true;
|
|
||||||
config.server_no_context_takeover = true;
|
|
||||||
agreed.push(HeaderValue::from_static(SERVER_NO_CONTEXT_TAKEOVER));
|
|
||||||
}
|
|
||||||
|
|
||||||
CLIENT_NO_CONTEXT_TAKEOVER => {
|
|
||||||
// Invalid offer with multiple params with same name is declined.
|
|
||||||
if seen_client_no_context_takeover {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
seen_client_no_context_takeover = true;
|
|
||||||
config.client_no_context_takeover = true;
|
|
||||||
agreed.push(HeaderValue::from_static(CLIENT_NO_CONTEXT_TAKEOVER));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Max window bits are not supported at the moment.
|
|
||||||
SERVER_MAX_WINDOW_BITS => {
|
|
||||||
// Decline offer with invalid parameter value.
|
|
||||||
// `server_max_window_bits` requires a value in range [8, 15].
|
|
||||||
if let Some(bits) = val {
|
|
||||||
if !is_valid_max_window_bits(bits) {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
// A server declines an extension negotiation offer with this parameter
|
|
||||||
// if the server doesn't support it.
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Not supported, but server may ignore and accept the offer.
|
|
||||||
CLIENT_MAX_WINDOW_BITS => {
|
|
||||||
// Decline offer with invalid parameter value.
|
|
||||||
// `client_max_window_bits` requires a value in range [8, 15] or no value.
|
|
||||||
if let Some(bits) = val {
|
|
||||||
if !is_valid_max_window_bits(bits) {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Invalid offer with multiple params with same name is declined.
|
|
||||||
if seen_client_max_window_bits {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
seen_client_max_window_bits = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Offer with unknown parameter MUST be declined.
|
|
||||||
_ => {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Some((to_header_value(&agreed), DeflateContext::new(Role::Server, config)))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn accept_response<'a>(
|
|
||||||
&'a self,
|
|
||||||
agreed: impl Iterator<Item = (&'a str, Option<&'a str>)>,
|
|
||||||
) -> Result<DeflateContext, DeflateError> {
|
|
||||||
let mut config = DeflateConfig {
|
|
||||||
compression: self.compression,
|
|
||||||
// If this was hinted in the offer, the client won't use context takeover
|
|
||||||
// even if the response doesn't include it.
|
|
||||||
// See `generate_offer`.
|
|
||||||
client_no_context_takeover: self.client_no_context_takeover,
|
|
||||||
..DeflateConfig::default()
|
|
||||||
};
|
|
||||||
let mut seen_server_no_context_takeover = false;
|
|
||||||
let mut seen_client_no_context_takeover = false;
|
|
||||||
// A client MUST _Fail the WebSocket Connection_ if the peer server
|
|
||||||
// accepted an extension negotiation offer for this extension with an
|
|
||||||
// extension negotiation response meeting any of the following
|
|
||||||
// conditions:
|
|
||||||
// 1. The negotiation response contains an extension parameter not defined for use in a response.
|
|
||||||
// 2. The negotiation response contains an extension parameter with an invalid value.
|
|
||||||
// 3. The negotiation response contains multiple extension parameters with the same name.
|
|
||||||
// 4. The client does not support the configuration that the response represents.
|
|
||||||
for (key, val) in agreed {
|
|
||||||
match key {
|
|
||||||
SERVER_NO_CONTEXT_TAKEOVER => {
|
|
||||||
// Fail the connection when the response contains multiple parameters with the same name.
|
|
||||||
if seen_server_no_context_takeover {
|
|
||||||
return Err(DeflateError::Negotiation(
|
|
||||||
NegotiationError::DuplicateParameter(key.to_owned()),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
seen_server_no_context_takeover = true;
|
|
||||||
// A server MAY include the "server_no_context_takeover" extension
|
|
||||||
// parameter in an extension negotiation response even if the extension
|
|
||||||
// negotiation offer being accepted by the extension negotiation
|
|
||||||
// response didn't include the "server_no_context_takeover" extension
|
|
||||||
// parameter.
|
|
||||||
config.server_no_context_takeover = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
CLIENT_NO_CONTEXT_TAKEOVER => {
|
|
||||||
// Fail the connection when the response contains multiple parameters with the same name.
|
|
||||||
if seen_client_no_context_takeover {
|
|
||||||
return Err(DeflateError::Negotiation(
|
|
||||||
NegotiationError::DuplicateParameter(key.to_owned()),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
seen_client_no_context_takeover = true;
|
|
||||||
// The server may include this parameter in the response and the client MUST support it.
|
|
||||||
config.client_no_context_takeover = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
SERVER_MAX_WINDOW_BITS => {
|
|
||||||
// Fail the connection when the response contains a parameter with invalid value.
|
|
||||||
if let Some(bits) = val {
|
|
||||||
if !is_valid_max_window_bits(bits) {
|
|
||||||
return Err(DeflateError::Negotiation(
|
|
||||||
NegotiationError::InvalidServerMaxWindowBitsValue(bits.to_owned()),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return Err(DeflateError::Negotiation(
|
|
||||||
NegotiationError::MissingServerMaxWindowBitsValue,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// A server may include the "server_max_window_bits" extension parameter
|
|
||||||
// in an extension negotiation response even if the extension
|
|
||||||
// negotiation offer being accepted by the response didn't include the
|
|
||||||
// "server_max_window_bits" extension parameter.
|
|
||||||
//
|
|
||||||
// However, but we need to fail the connection because we don't support it (condition 4).
|
|
||||||
return Err(DeflateError::Negotiation(
|
|
||||||
NegotiationError::ServerMaxWindowBitsNotSupported,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
CLIENT_MAX_WINDOW_BITS => {
|
|
||||||
// Fail the connection when the response contains a parameter with invalid value.
|
|
||||||
if let Some(bits) = val {
|
|
||||||
if !is_valid_max_window_bits(bits) {
|
|
||||||
return Err(DeflateError::Negotiation(
|
|
||||||
NegotiationError::InvalidClientMaxWindowBitsValue(bits.to_owned()),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fail the connection because the parameter is invalid when the client didn't offer.
|
|
||||||
//
|
|
||||||
// If a received extension negotiation offer doesn't have the
|
|
||||||
// "client_max_window_bits" extension parameter, the corresponding
|
|
||||||
// extension negotiation response to the offer MUST NOT include the
|
|
||||||
// "client_max_window_bits" extension parameter.
|
|
||||||
return Err(DeflateError::Negotiation(
|
|
||||||
NegotiationError::UnexpectedClientMaxWindowBits,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Response with unknown parameter MUST fail the WebSocket connection.
|
|
||||||
_ => {
|
|
||||||
return Err(DeflateError::Negotiation(NegotiationError::UnknownParameter(
|
|
||||||
key.to_owned(),
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(DeflateContext::new(Role::Client, config))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// A valid `client_max_window_bits` is no value or an integer in range `[8, 15]` without leading zeros.
|
|
||||||
// A valid `server_max_window_bits` is an integer in range `[8, 15]` without leading zeros.
|
|
||||||
fn is_valid_max_window_bits(bits: &str) -> bool {
|
|
||||||
// Note that values from `headers::SecWebSocketExtensions` is unquoted.
|
|
||||||
matches!(bits, "8" | "9" | "10" | "11" | "12" | "13" | "14" | "15")
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::is_valid_max_window_bits;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn valid_max_window_bits() {
|
|
||||||
for bits in 8..=15 {
|
|
||||||
assert!(is_valid_max_window_bits(&bits.to_string()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn invalid_max_window_bits() {
|
|
||||||
assert!(!is_valid_max_window_bits(""));
|
|
||||||
assert!(!is_valid_max_window_bits("0"));
|
|
||||||
assert!(!is_valid_max_window_bits("08"));
|
|
||||||
assert!(!is_valid_max_window_bits("+8"));
|
|
||||||
assert!(!is_valid_max_window_bits("-8"));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
/// Manages per message compression using DEFLATE.
|
|
||||||
pub struct DeflateContext {
|
|
||||||
role: Role,
|
|
||||||
config: DeflateConfig,
|
|
||||||
compressor: Compress,
|
|
||||||
decompressor: Decompress,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DeflateContext {
|
|
||||||
fn new(role: Role, config: DeflateConfig) -> Self {
|
|
||||||
DeflateContext {
|
|
||||||
role,
|
|
||||||
config,
|
|
||||||
compressor: Compress::new(config.compression, false),
|
|
||||||
decompressor: Decompress::new(false),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn own_context_takeover(&self) -> bool {
|
|
||||||
match self.role {
|
|
||||||
Role::Server => !self.config.server_no_context_takeover,
|
|
||||||
Role::Client => !self.config.client_no_context_takeover,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn peer_context_takeover(&self) -> bool {
|
|
||||||
match self.role {
|
|
||||||
Role::Server => !self.config.client_no_context_takeover,
|
|
||||||
Role::Client => !self.config.server_no_context_takeover,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compress the data of message.
|
|
||||||
pub(crate) fn compress(&mut self, data: &[u8]) -> Result<Vec<u8>, DeflateError> {
|
|
||||||
// https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1
|
|
||||||
// 1. Compress all the octets of the payload of the message using DEFLATE.
|
|
||||||
let mut output = Vec::with_capacity(data.len());
|
|
||||||
let before_in = self.compressor.total_in() as usize;
|
|
||||||
while (self.compressor.total_in() as usize) - before_in < data.len() {
|
|
||||||
let offset = (self.compressor.total_in() as usize) - before_in;
|
|
||||||
match self
|
|
||||||
.compressor
|
|
||||||
.compress_vec(&data[offset..], &mut output, FlushCompress::None)
|
|
||||||
.map_err(|e| DeflateError::Compress(e.into()))?
|
|
||||||
{
|
|
||||||
Status::Ok => continue,
|
|
||||||
Status::BufError => output.reserve(4096),
|
|
||||||
Status::StreamEnd => break,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 2. If the resulting data does not end with an empty DEFLATE block
|
|
||||||
// with no compression (the "BTYPE" bits are set to 00), append an
|
|
||||||
// empty DEFLATE block with no compression to the tail end.
|
|
||||||
while !output.ends_with(&TRAILER) {
|
|
||||||
output.reserve(5);
|
|
||||||
match self
|
|
||||||
.compressor
|
|
||||||
.compress_vec(&[], &mut output, FlushCompress::Sync)
|
|
||||||
.map_err(|e| DeflateError::Compress(e.into()))?
|
|
||||||
{
|
|
||||||
Status::Ok | Status::BufError => continue,
|
|
||||||
Status::StreamEnd => break,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 3. Remove 4 octets (that are 0x00 0x00 0xff 0xff) from the tail end.
|
|
||||||
// After this step, the last octet of the compressed data contains
|
|
||||||
// (possibly part of) the DEFLATE header bits with the "BTYPE" bits
|
|
||||||
// set to 00.
|
|
||||||
output.truncate(output.len() - 4);
|
|
||||||
|
|
||||||
if !self.own_context_takeover() {
|
|
||||||
self.compressor.reset();
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(output)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn decompress(
|
|
||||||
&mut self,
|
|
||||||
mut data: Vec<u8>,
|
|
||||||
is_final: bool,
|
|
||||||
) -> Result<Vec<u8>, DeflateError> {
|
|
||||||
if is_final {
|
|
||||||
data.extend_from_slice(&TRAILER);
|
|
||||||
}
|
|
||||||
|
|
||||||
let before_in = self.decompressor.total_in() as usize;
|
|
||||||
let mut output = Vec::with_capacity(2 * data.len());
|
|
||||||
loop {
|
|
||||||
let offset = (self.decompressor.total_in() as usize) - before_in;
|
|
||||||
match self
|
|
||||||
.decompressor
|
|
||||||
.decompress_vec(&data[offset..], &mut output, FlushDecompress::None)
|
|
||||||
.map_err(|e| DeflateError::Decompress(e.into()))?
|
|
||||||
{
|
|
||||||
Status::Ok => output.reserve(2 * output.len()),
|
|
||||||
Status::BufError | Status::StreamEnd => break,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if is_final && !self.peer_context_takeover() {
|
|
||||||
self.decompressor.reset(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_header_value(params: &[HeaderValue]) -> WebsocketExtension {
|
|
||||||
let mut buf = BytesMut::from(PER_MESSAGE_DEFLATE.as_bytes());
|
|
||||||
for param in params {
|
|
||||||
buf.extend_from_slice(b"; ");
|
|
||||||
buf.extend_from_slice(param.as_bytes());
|
|
||||||
}
|
|
||||||
let header = HeaderValue::from_maybe_shared(buf.freeze())
|
|
||||||
.expect("semicolon separated HeaderValue is valid");
|
|
||||||
WebsocketExtension::try_from(header).expect("valid extension")
|
|
||||||
}
|
|
|
@ -1,4 +0,0 @@
|
||||||
//! [Per-Message Compression Extensions][rfc7692]
|
|
||||||
//!
|
|
||||||
//! [rfc7692]: https://tools.ietf.org/html/rfc7692
|
|
||||||
pub mod deflate;
|
|
|
@ -1,18 +0,0 @@
|
||||||
//! WebSocket extensions.
|
|
||||||
// Only `permessage-deflate` is supported at the moment.
|
|
||||||
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
mod compression;
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
use compression::deflate::DeflateContext;
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
pub use compression::deflate::{DeflateConfig, DeflateError};
|
|
||||||
|
|
||||||
/// Container for configured extensions.
|
|
||||||
#[derive(Debug, Default)]
|
|
||||||
#[allow(missing_copy_implementations)]
|
|
||||||
pub struct Extensions {
|
|
||||||
// Per-Message Compression. Only `permessage-deflate` is supported.
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
pub(crate) compression: Option<DeflateContext>,
|
|
||||||
}
|
|
|
@ -5,7 +5,6 @@ use std::{
|
||||||
marker::PhantomData,
|
marker::PhantomData,
|
||||||
};
|
};
|
||||||
|
|
||||||
use headers::{HeaderMapExt, SecWebsocketExtensions};
|
|
||||||
use http::{
|
use http::{
|
||||||
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
|
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
|
||||||
};
|
};
|
||||||
|
@ -20,7 +19,6 @@ use super::{
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
error::{Error, ProtocolError, Result, UrlError},
|
error::{Error, ProtocolError, Result, UrlError},
|
||||||
extensions::Extensions,
|
|
||||||
protocol::{Role, WebSocket, WebSocketConfig},
|
protocol::{Role, WebSocket, WebSocketConfig},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -58,7 +56,7 @@ impl<S: Read + Write> ClientHandshake<S> {
|
||||||
|
|
||||||
// Convert and verify the `http::Request` and turn it into the request as per RFC.
|
// Convert and verify the `http::Request` and turn it into the request as per RFC.
|
||||||
// Also extract the key from it (it must be present in a correct request).
|
// Also extract the key from it (it must be present in a correct request).
|
||||||
let (request, key) = generate_request(request, &config)?;
|
let (request, key) = generate_request(request)?;
|
||||||
|
|
||||||
let machine = HandshakeMachine::start_write(stream, request);
|
let machine = HandshakeMachine::start_write(stream, request);
|
||||||
|
|
||||||
|
@ -85,24 +83,18 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
|
||||||
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
|
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
|
||||||
}
|
}
|
||||||
StageResult::DoneReading { stream, result, tail } => {
|
StageResult::DoneReading { stream, result, tail } => {
|
||||||
let (result, extensions) =
|
let result = match self.verify_data.verify_response(result) {
|
||||||
match self.verify_data.verify_response(result, &self.config) {
|
Ok(r) => r,
|
||||||
Ok(r) => r,
|
Err(Error::Http(mut e)) => {
|
||||||
Err(Error::Http(mut e)) => {
|
*e.body_mut() = Some(tail);
|
||||||
*e.body_mut() = Some(tail);
|
return Err(Error::Http(e))
|
||||||
return Err(Error::Http(e));
|
},
|
||||||
}
|
Err(e) => return Err(e),
|
||||||
Err(e) => return Err(e),
|
};
|
||||||
};
|
|
||||||
|
|
||||||
debug!("Client handshake done.");
|
debug!("Client handshake done.");
|
||||||
let websocket = WebSocket::from_partially_read_with_extensions(
|
let websocket =
|
||||||
stream,
|
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
|
||||||
tail,
|
|
||||||
Role::Client,
|
|
||||||
self.config,
|
|
||||||
extensions,
|
|
||||||
);
|
|
||||||
ProcessingResult::Done((websocket, result))
|
ProcessingResult::Done((websocket, result))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -110,10 +102,7 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
|
/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
|
||||||
pub fn generate_request(
|
pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
|
||||||
mut request: Request,
|
|
||||||
config: &Option<WebSocketConfig>,
|
|
||||||
) -> Result<(Vec<u8>, String)> {
|
|
||||||
let mut req = Vec::new();
|
let mut req = Vec::new();
|
||||||
write!(
|
write!(
|
||||||
req,
|
req,
|
||||||
|
@ -184,9 +173,6 @@ pub fn generate_request(
|
||||||
writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap();
|
writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(offers) = config.and_then(|c| c.generate_offers()) {
|
|
||||||
writeln!(req, "Sec-WebSocket-Extensions: {}\r", offers.to_value().to_str()?).unwrap();
|
|
||||||
}
|
|
||||||
writeln!(req, "\r").unwrap();
|
writeln!(req, "\r").unwrap();
|
||||||
trace!("Request: {:?}", String::from_utf8_lossy(&req));
|
trace!("Request: {:?}", String::from_utf8_lossy(&req));
|
||||||
Ok((req, key))
|
Ok((req, key))
|
||||||
|
@ -200,11 +186,7 @@ struct VerifyData {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VerifyData {
|
impl VerifyData {
|
||||||
pub fn verify_response(
|
pub fn verify_response(&self, response: Response) -> Result<Response> {
|
||||||
&self,
|
|
||||||
response: Response,
|
|
||||||
_config: &Option<WebSocketConfig>,
|
|
||||||
) -> Result<(Response, Option<Extensions>)> {
|
|
||||||
// 1. If the status code received from the server is not 101, the
|
// 1. If the status code received from the server is not 101, the
|
||||||
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
|
// client handles the response per HTTP [RFC2616] procedures. (RFC 6455)
|
||||||
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
|
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
|
||||||
|
@ -249,14 +231,7 @@ impl VerifyData {
|
||||||
// that was not present in the client's handshake (the server has
|
// that was not present in the client's handshake (the server has
|
||||||
// indicated an extension not requested by the client), the client
|
// indicated an extension not requested by the client), the client
|
||||||
// MUST _Fail the WebSocket Connection_. (RFC 6455)
|
// MUST _Fail the WebSocket Connection_. (RFC 6455)
|
||||||
let extensions = if let Some(agreed) = headers
|
// TODO
|
||||||
.typed_try_get::<SecWebsocketExtensions>()
|
|
||||||
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
|
|
||||||
{
|
|
||||||
verify_extensions(&agreed, _config)?
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
|
// 6. If the response includes a |Sec-WebSocket-Protocol| header field
|
||||||
// and this header field indicates the use of a subprotocol that was
|
// and this header field indicates the use of a subprotocol that was
|
||||||
|
@ -265,49 +240,10 @@ impl VerifyData {
|
||||||
// the WebSocket Connection_. (RFC 6455)
|
// the WebSocket Connection_. (RFC 6455)
|
||||||
// TODO
|
// TODO
|
||||||
|
|
||||||
Ok((response, extensions))
|
Ok(response)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn verify_extensions(
|
|
||||||
agreed_extensions: &headers::SecWebsocketExtensions,
|
|
||||||
_config: &Option<WebSocketConfig>,
|
|
||||||
) -> Result<Option<Extensions>> {
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
{
|
|
||||||
if let Some(compression) = _config.and_then(|c| c.compression) {
|
|
||||||
let mut extensions = None;
|
|
||||||
for extension in agreed_extensions.iter() {
|
|
||||||
// > If a server gives an invalid response, such as accepting a PMCE that the client did not offer,
|
|
||||||
// > the client MUST _Fail the WebSocket Connection_.
|
|
||||||
if extension.name() != compression.name() {
|
|
||||||
return Err(Error::Protocol(ProtocolError::InvalidExtension(
|
|
||||||
extension.name().to_string(),
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Already had PMCE configured
|
|
||||||
if extensions.is_some() {
|
|
||||||
return Err(Error::Protocol(ProtocolError::ExtensionConflict(
|
|
||||||
extension.name().to_string(),
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
extensions = Some(Extensions {
|
|
||||||
compression: Some(compression.accept_response(extension.params())?),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return Ok(extensions);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(extension) = agreed_extensions.iter().next() {
|
|
||||||
// The client didn't request anything, but got something
|
|
||||||
return Err(Error::Protocol(ProtocolError::InvalidExtension(extension.name().to_string())));
|
|
||||||
}
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TryParse for Response {
|
impl TryParse for Response {
|
||||||
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
|
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
|
||||||
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
|
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
|
||||||
|
@ -350,8 +286,6 @@ pub fn generate_key() -> String {
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{super::machine::TryParse, generate_key, generate_request, Response};
|
use super::{super::machine::TryParse, generate_key, generate_request, Response};
|
||||||
use crate::client::IntoClientRequest;
|
use crate::client::IntoClientRequest;
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
use crate::{extensions::DeflateConfig, protocol::WebSocketConfig};
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn random_keys() {
|
fn random_keys() {
|
||||||
|
@ -388,7 +322,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn request_formatting() {
|
fn request_formatting() {
|
||||||
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
|
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
|
||||||
let (request, key) = generate_request(request, &None).unwrap();
|
let (request, key) = generate_request(request).unwrap();
|
||||||
let correct = construct_expected("localhost", &key);
|
let correct = construct_expected("localhost", &key);
|
||||||
assert_eq!(&request[..], &correct[..]);
|
assert_eq!(&request[..], &correct[..]);
|
||||||
}
|
}
|
||||||
|
@ -396,7 +330,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn request_formatting_with_host() {
|
fn request_formatting_with_host() {
|
||||||
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
|
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
|
||||||
let (request, key) = generate_request(request, &None).unwrap();
|
let (request, key) = generate_request(request).unwrap();
|
||||||
let correct = construct_expected("localhost:9001", &key);
|
let correct = construct_expected("localhost:9001", &key);
|
||||||
assert_eq!(&request[..], &correct[..]);
|
assert_eq!(&request[..], &correct[..]);
|
||||||
}
|
}
|
||||||
|
@ -404,40 +338,11 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn request_formatting_with_at() {
|
fn request_formatting_with_at() {
|
||||||
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
|
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
|
||||||
let (request, key) = generate_request(request, &None).unwrap();
|
let (request, key) = generate_request(request).unwrap();
|
||||||
let correct = construct_expected("localhost:9001", &key);
|
let correct = construct_expected("localhost:9001", &key);
|
||||||
assert_eq!(&request[..], &correct[..]);
|
assert_eq!(&request[..], &correct[..]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
#[test]
|
|
||||||
fn request_with_compression() {
|
|
||||||
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
|
|
||||||
let (request, key) = generate_request(
|
|
||||||
request,
|
|
||||||
&Some(WebSocketConfig {
|
|
||||||
compression: Some(DeflateConfig::default()),
|
|
||||||
..WebSocketConfig::default()
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
let correct = format!(
|
|
||||||
"\
|
|
||||||
GET /getCaseCount HTTP/1.1\r\n\
|
|
||||||
Host: {host}\r\n\
|
|
||||||
Connection: Upgrade\r\n\
|
|
||||||
Upgrade: websocket\r\n\
|
|
||||||
Sec-WebSocket-Version: 13\r\n\
|
|
||||||
Sec-WebSocket-Key: {key}\r\n\
|
|
||||||
Sec-WebSocket-Extensions: permessage-deflate\r\n\
|
|
||||||
\r\n",
|
|
||||||
host = "localhost",
|
|
||||||
key = key
|
|
||||||
)
|
|
||||||
.into_bytes();
|
|
||||||
assert_eq!(&request[..], &correct[..]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn response_parsing() {
|
fn response_parsing() {
|
||||||
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
|
const DATA: &[u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
|
||||||
|
@ -449,6 +354,6 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn invalid_custom_request() {
|
fn invalid_custom_request() {
|
||||||
let request = http::Request::builder().method("GET").body(()).unwrap();
|
let request = http::Request::builder().method("GET").body(()).unwrap();
|
||||||
assert!(generate_request(request, &None).is_err());
|
assert!(generate_request(request).is_err());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,6 @@ use std::{
|
||||||
result::Result as StdResult,
|
result::Result as StdResult,
|
||||||
};
|
};
|
||||||
|
|
||||||
use headers::{HeaderMapExt, SecWebsocketExtensions};
|
|
||||||
use http::{
|
use http::{
|
||||||
response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
|
response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
|
||||||
};
|
};
|
||||||
|
@ -21,7 +20,6 @@ use super::{
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
error::{Error, ProtocolError, Result},
|
error::{Error, ProtocolError, Result},
|
||||||
extensions::Extensions,
|
|
||||||
protocol::{Role, WebSocket, WebSocketConfig},
|
protocol::{Role, WebSocket, WebSocketConfig},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -204,8 +202,6 @@ pub struct ServerHandshake<S, C> {
|
||||||
config: Option<WebSocketConfig>,
|
config: Option<WebSocketConfig>,
|
||||||
/// Error code/flag. If set, an error will be returned after sending response to the client.
|
/// Error code/flag. If set, an error will be returned after sending response to the client.
|
||||||
error_response: Option<ErrorResponse>,
|
error_response: Option<ErrorResponse>,
|
||||||
// Negotiated extension context for server.
|
|
||||||
extensions: Option<Extensions>,
|
|
||||||
/// Internal stream type.
|
/// Internal stream type.
|
||||||
_marker: PhantomData<S>,
|
_marker: PhantomData<S>,
|
||||||
}
|
}
|
||||||
|
@ -223,7 +219,6 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
|
||||||
callback: Some(callback),
|
callback: Some(callback),
|
||||||
config,
|
config,
|
||||||
error_response: None,
|
error_response: None,
|
||||||
extensions: None,
|
|
||||||
_marker: PhantomData,
|
_marker: PhantomData,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -245,19 +240,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
|
||||||
return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
|
return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut response = create_response(&result)?;
|
let response = create_response(&result)?;
|
||||||
if let Some(config) = &self.config {
|
|
||||||
if let Some((agreed, extensions)) = result
|
|
||||||
.headers()
|
|
||||||
.typed_try_get::<SecWebsocketExtensions>()
|
|
||||||
.map_err(|_| Error::Protocol(ProtocolError::InvalidExtensionsHeader))?
|
|
||||||
.and_then(|values| config.accept_offers(&values))
|
|
||||||
{
|
|
||||||
response.headers_mut().typed_insert(agreed);
|
|
||||||
self.extensions = Some(extensions);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let callback_result = if let Some(callback) = self.callback.take() {
|
let callback_result = if let Some(callback) = self.callback.take() {
|
||||||
callback.on_request(&result, response)
|
callback.on_request(&result, response)
|
||||||
} else {
|
} else {
|
||||||
|
@ -300,12 +283,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
|
||||||
return Err(Error::Http(http::Response::from_parts(parts, body)));
|
return Err(Error::Http(http::Response::from_parts(parts, body)));
|
||||||
} else {
|
} else {
|
||||||
debug!("Server handshake done.");
|
debug!("Server handshake done.");
|
||||||
let websocket = WebSocket::from_raw_socket_with_extensions(
|
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
|
||||||
stream,
|
|
||||||
Role::Server,
|
|
||||||
self.config,
|
|
||||||
self.extensions.take(),
|
|
||||||
);
|
|
||||||
ProcessingResult::Done(websocket)
|
ProcessingResult::Done(websocket)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,6 @@ pub mod buffer;
|
||||||
#[cfg(feature = "handshake")]
|
#[cfg(feature = "handshake")]
|
||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod extensions;
|
|
||||||
#[cfg(feature = "handshake")]
|
#[cfg(feature = "handshake")]
|
||||||
pub mod handshake;
|
pub mod handshake;
|
||||||
pub mod protocol;
|
pub mod protocol;
|
||||||
|
|
|
@ -311,18 +311,6 @@ impl Frame {
|
||||||
Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
|
Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new compressed data frame.
|
|
||||||
#[inline]
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
pub(crate) fn compressed_message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
|
|
||||||
debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
|
|
||||||
|
|
||||||
Frame {
|
|
||||||
header: FrameHeader { is_final, opcode, rsv1: true, ..FrameHeader::default() },
|
|
||||||
payload: data,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a new Pong control frame.
|
/// Create a new Pong control frame.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn pong(data: Vec<u8>) -> Frame {
|
pub fn pong(data: Vec<u8>) -> Frame {
|
||||||
|
|
|
@ -84,8 +84,6 @@ use self::string_collect::StringCollector;
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct IncompleteMessage {
|
pub struct IncompleteMessage {
|
||||||
collector: IncompleteMessageCollector,
|
collector: IncompleteMessageCollector,
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
compressed: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -96,7 +94,6 @@ enum IncompleteMessageCollector {
|
||||||
|
|
||||||
impl IncompleteMessage {
|
impl IncompleteMessage {
|
||||||
/// Create new.
|
/// Create new.
|
||||||
#[cfg(not(feature = "deflate"))]
|
|
||||||
pub fn new(message_type: IncompleteMessageType) -> Self {
|
pub fn new(message_type: IncompleteMessageType) -> Self {
|
||||||
IncompleteMessage {
|
IncompleteMessage {
|
||||||
collector: match message_type {
|
collector: match message_type {
|
||||||
|
@ -108,25 +105,6 @@ impl IncompleteMessage {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create new.
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
pub fn new(message_type: IncompleteMessageType, compressed: bool) -> Self {
|
|
||||||
IncompleteMessage {
|
|
||||||
collector: match message_type {
|
|
||||||
IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
|
|
||||||
IncompleteMessageType::Text => {
|
|
||||||
IncompleteMessageCollector::Text(StringCollector::new())
|
|
||||||
}
|
|
||||||
},
|
|
||||||
compressed,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
pub fn compressed(&self) -> bool {
|
|
||||||
self.compressed
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the current filled size of the buffer.
|
/// Get the current filled size of the buffer.
|
||||||
pub fn len(&self) -> usize {
|
pub fn len(&self) -> usize {
|
||||||
match self.collector {
|
match self.collector {
|
||||||
|
|
|
@ -22,7 +22,6 @@ use self::{
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
error::{Error, ProtocolError, Result},
|
error::{Error, ProtocolError, Result},
|
||||||
extensions::Extensions,
|
|
||||||
util::NonBlockingResult,
|
util::NonBlockingResult,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -57,9 +56,6 @@ pub struct WebSocketConfig {
|
||||||
/// some popular libraries that are sending unmasked frames, ignoring the RFC.
|
/// some popular libraries that are sending unmasked frames, ignoring the RFC.
|
||||||
/// By default this option is set to `false`, i.e. according to RFC 6455.
|
/// By default this option is set to `false`, i.e. according to RFC 6455.
|
||||||
pub accept_unmasked_frames: bool,
|
pub accept_unmasked_frames: bool,
|
||||||
/// Optional configuration for Per-Message Compression Extension.
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
pub compression: Option<crate::extensions::DeflateConfig>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for WebSocketConfig {
|
impl Default for WebSocketConfig {
|
||||||
|
@ -69,64 +65,6 @@ impl Default for WebSocketConfig {
|
||||||
max_message_size: Some(64 << 20),
|
max_message_size: Some(64 << 20),
|
||||||
max_frame_size: Some(16 << 20),
|
max_frame_size: Some(16 << 20),
|
||||||
accept_unmasked_frames: false,
|
accept_unmasked_frames: false,
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
compression: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl WebSocketConfig {
|
|
||||||
// Generate extension negotiation offers for configured extensions.
|
|
||||||
// Only `permessage-deflate` is supported at the moment.
|
|
||||||
pub(crate) fn generate_offers(&self) -> Option<headers::SecWebsocketExtensions> {
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
{
|
|
||||||
let mut offers = Vec::new();
|
|
||||||
if let Some(compression) = self.compression.map(|c| c.generate_offer()) {
|
|
||||||
offers.push(compression);
|
|
||||||
}
|
|
||||||
if offers.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(headers::SecWebsocketExtensions::new(offers))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#[cfg(not(feature = "deflate"))]
|
|
||||||
{
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This can be used with `WebSocket::from_raw_socket_with_extensions` for integration.
|
|
||||||
/// Returns negotiation response based on offers and `Extensions` to manage extensions.
|
|
||||||
pub fn accept_offers(
|
|
||||||
&self,
|
|
||||||
_offers: &headers::SecWebsocketExtensions,
|
|
||||||
) -> Option<(headers::SecWebsocketExtensions, Extensions)> {
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
{
|
|
||||||
// To support more extensions, store extension context in `Extensions` and
|
|
||||||
// concatenate negotiation responses from each extension.
|
|
||||||
let mut agreed_extensions = Vec::new();
|
|
||||||
let mut extensions = Extensions::default();
|
|
||||||
|
|
||||||
if let Some(compression) = &self.compression {
|
|
||||||
if let Some((agreed, compression)) = compression.accept_offer(_offers) {
|
|
||||||
agreed_extensions.push(agreed);
|
|
||||||
extensions.compression = Some(compression);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if agreed_extensions.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some((headers::SecWebsocketExtensions::new(agreed_extensions), extensions))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(not(feature = "deflate"))]
|
|
||||||
{
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -153,18 +91,6 @@ impl<Stream> WebSocket<Stream> {
|
||||||
WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
|
WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert a raw socket into a WebSocket without performing a handshake.
|
|
||||||
pub fn from_raw_socket_with_extensions(
|
|
||||||
stream: Stream,
|
|
||||||
role: Role,
|
|
||||||
config: Option<WebSocketConfig>,
|
|
||||||
extensions: Option<Extensions>,
|
|
||||||
) -> Self {
|
|
||||||
let mut context = WebSocketContext::new(role, config);
|
|
||||||
context.extensions = extensions;
|
|
||||||
WebSocket { socket: stream, context }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert a raw socket into a WebSocket without performing a handshake.
|
/// Convert a raw socket into a WebSocket without performing a handshake.
|
||||||
///
|
///
|
||||||
/// Call this function if you're using Tungstenite as a part of a web framework
|
/// Call this function if you're using Tungstenite as a part of a web framework
|
||||||
|
@ -182,21 +108,6 @@ impl<Stream> WebSocket<Stream> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn from_partially_read_with_extensions(
|
|
||||||
stream: Stream,
|
|
||||||
part: Vec<u8>,
|
|
||||||
role: Role,
|
|
||||||
config: Option<WebSocketConfig>,
|
|
||||||
extensions: Option<Extensions>,
|
|
||||||
) -> Self {
|
|
||||||
WebSocket {
|
|
||||||
socket: stream,
|
|
||||||
context: WebSocketContext::from_partially_read_with_extensions(
|
|
||||||
part, role, config, extensions,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a shared reference to the inner stream.
|
/// Returns a shared reference to the inner stream.
|
||||||
pub fn get_ref(&self) -> &Stream {
|
pub fn get_ref(&self) -> &Stream {
|
||||||
&self.socket
|
&self.socket
|
||||||
|
@ -330,8 +241,6 @@ pub struct WebSocketContext {
|
||||||
pong: Option<Frame>,
|
pong: Option<Frame>,
|
||||||
/// The configuration for the websocket session.
|
/// The configuration for the websocket session.
|
||||||
config: WebSocketConfig,
|
config: WebSocketConfig,
|
||||||
// Container for extensions.
|
|
||||||
pub(crate) extensions: Option<Extensions>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WebSocketContext {
|
impl WebSocketContext {
|
||||||
|
@ -345,7 +254,6 @@ impl WebSocketContext {
|
||||||
send_queue: VecDeque::new(),
|
send_queue: VecDeque::new(),
|
||||||
pong: None,
|
pong: None,
|
||||||
config: config.unwrap_or_default(),
|
config: config.unwrap_or_default(),
|
||||||
extensions: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -357,19 +265,6 @@ impl WebSocketContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn from_partially_read_with_extensions(
|
|
||||||
part: Vec<u8>,
|
|
||||||
role: Role,
|
|
||||||
config: Option<WebSocketConfig>,
|
|
||||||
extensions: Option<Extensions>,
|
|
||||||
) -> Self {
|
|
||||||
WebSocketContext {
|
|
||||||
frame: FrameCodec::from_partially_read(part),
|
|
||||||
extensions,
|
|
||||||
..WebSocketContext::new(role, config)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Change the configuration.
|
/// Change the configuration.
|
||||||
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
|
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
|
||||||
set_func(&mut self.config)
|
set_func(&mut self.config)
|
||||||
|
@ -453,8 +348,8 @@ impl WebSocketContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
let frame = match message {
|
let frame = match message {
|
||||||
Message::Text(data) => self.prepare_data_frame(data.into(), OpData::Text)?,
|
Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true),
|
||||||
Message::Binary(data) => self.prepare_data_frame(data, OpData::Binary)?,
|
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
|
||||||
Message::Ping(data) => Frame::ping(data),
|
Message::Ping(data) => Frame::ping(data),
|
||||||
Message::Pong(data) => {
|
Message::Pong(data) => {
|
||||||
self.pong = Some(Frame::pong(data));
|
self.pong = Some(Frame::pong(data));
|
||||||
|
@ -468,17 +363,6 @@ impl WebSocketContext {
|
||||||
self.write_pending(stream)
|
self.write_pending(stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prepare_data_frame(&mut self, data: Vec<u8>, opdata: OpData) -> Result<Frame> {
|
|
||||||
debug_assert!(matches!(opdata, OpData::Text | OpData::Binary), "Invalid data frame kind");
|
|
||||||
let opcode = OpCode::Data(opdata);
|
|
||||||
let is_final = true;
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
if let Some(pmce) = self.extensions.as_mut().and_then(|e| e.compression.as_mut()) {
|
|
||||||
return Ok(Frame::compressed_message(pmce.compress(&data)?, opcode, is_final));
|
|
||||||
}
|
|
||||||
Ok(Frame::message(data, opcode, is_final))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Flush the pending send queue.
|
/// Flush the pending send queue.
|
||||||
pub fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()>
|
pub fn write_pending<Stream>(&mut self, stream: &mut Stream) -> Result<()>
|
||||||
where
|
where
|
||||||
|
@ -555,14 +439,12 @@ impl WebSocketContext {
|
||||||
// the negotiated extensions defines the meaning of such a nonzero
|
// the negotiated extensions defines the meaning of such a nonzero
|
||||||
// value, the receiving endpoint MUST _Fail the WebSocket
|
// value, the receiving endpoint MUST _Fail the WebSocket
|
||||||
// Connection_.
|
// Connection_.
|
||||||
let is_compressed = {
|
{
|
||||||
let hdr = frame.header();
|
let hdr = frame.header();
|
||||||
if (hdr.rsv1 && !self.has_compression()) || hdr.rsv2 || hdr.rsv3 {
|
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
|
||||||
return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
|
return Err(Error::Protocol(ProtocolError::NonZeroReservedBits));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
hdr.rsv1
|
|
||||||
};
|
|
||||||
|
|
||||||
match self.role {
|
match self.role {
|
||||||
Role::Server => {
|
Role::Server => {
|
||||||
|
@ -597,10 +479,6 @@ impl WebSocketContext {
|
||||||
_ if frame.payload().len() > 125 => {
|
_ if frame.payload().len() > 125 => {
|
||||||
Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
|
Err(Error::Protocol(ProtocolError::ControlFrameTooBig))
|
||||||
}
|
}
|
||||||
// Control frames must not have compress bit.
|
|
||||||
_ if is_compressed => {
|
|
||||||
Err(Error::Protocol(ProtocolError::CompressedControlFrame))
|
|
||||||
}
|
|
||||||
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
|
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
|
||||||
OpCtl::Reserved(i) => {
|
OpCtl::Reserved(i) => {
|
||||||
Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i)))
|
Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i)))
|
||||||
|
@ -621,34 +499,39 @@ impl WebSocketContext {
|
||||||
let fin = frame.header().is_final;
|
let fin = frame.header().is_final;
|
||||||
match data {
|
match data {
|
||||||
OpData::Continue => {
|
OpData::Continue => {
|
||||||
if self.incomplete.is_some() && is_compressed {
|
if let Some(ref mut msg) = self.incomplete {
|
||||||
|
msg.extend(frame.into_data(), self.config.max_message_size)?;
|
||||||
|
} else {
|
||||||
return Err(Error::Protocol(
|
return Err(Error::Protocol(
|
||||||
ProtocolError::CompressedContinueFrame,
|
ProtocolError::UnexpectedContinueFrame,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
if fin {
|
||||||
let msg = self
|
Ok(Some(self.incomplete.take().unwrap().complete()?))
|
||||||
.incomplete
|
} else {
|
||||||
.take()
|
Ok(None)
|
||||||
.ok_or(Error::Protocol(ProtocolError::UnexpectedContinueFrame))?;
|
}
|
||||||
self.extend_incomplete(msg, frame.into_data(), fin)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c if self.incomplete.is_some() => {
|
c if self.incomplete.is_some() => {
|
||||||
Err(Error::Protocol(ProtocolError::ExpectedFragment(c)))
|
Err(Error::Protocol(ProtocolError::ExpectedFragment(c)))
|
||||||
}
|
}
|
||||||
|
|
||||||
OpData::Text | OpData::Binary => {
|
OpData::Text | OpData::Binary => {
|
||||||
let message_type = match data {
|
let msg = {
|
||||||
OpData::Text => IncompleteMessageType::Text,
|
let message_type = match data {
|
||||||
OpData::Binary => IncompleteMessageType::Binary,
|
OpData::Text => IncompleteMessageType::Text,
|
||||||
_ => panic!("Bug: message is not text nor binary"),
|
OpData::Binary => IncompleteMessageType::Binary,
|
||||||
|
_ => panic!("Bug: message is not text nor binary"),
|
||||||
|
};
|
||||||
|
let mut m = IncompleteMessage::new(message_type);
|
||||||
|
m.extend(frame.into_data(), self.config.max_message_size)?;
|
||||||
|
m
|
||||||
};
|
};
|
||||||
#[cfg(feature = "deflate")]
|
if fin {
|
||||||
let msg = IncompleteMessage::new(message_type, is_compressed);
|
Ok(Some(msg.complete()?))
|
||||||
#[cfg(not(feature = "deflate"))]
|
} else {
|
||||||
let msg = IncompleteMessage::new(message_type);
|
self.incomplete = Some(msg);
|
||||||
self.extend_incomplete(msg, frame.into_data(), fin)
|
Ok(None)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
OpData::Reserved(i) => {
|
OpData::Reserved(i) => {
|
||||||
Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i)))
|
Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i)))
|
||||||
|
@ -667,32 +550,6 @@ impl WebSocketContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extend_incomplete(
|
|
||||||
&mut self,
|
|
||||||
mut msg: IncompleteMessage,
|
|
||||||
data: Vec<u8>,
|
|
||||||
is_final: bool,
|
|
||||||
) -> Result<Option<Message>> {
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
let data = if msg.compressed() {
|
|
||||||
// `msg.compressed()` is only true when compression is enabled so it's safe to unwrap
|
|
||||||
self.extensions
|
|
||||||
.as_mut()
|
|
||||||
.and_then(|x| x.compression.as_mut())
|
|
||||||
.unwrap()
|
|
||||||
.decompress(data, is_final)?
|
|
||||||
} else {
|
|
||||||
data
|
|
||||||
};
|
|
||||||
msg.extend(data, self.config.max_message_size)?;
|
|
||||||
if is_final {
|
|
||||||
Ok(Some(msg.complete()?))
|
|
||||||
} else {
|
|
||||||
self.incomplete = Some(msg);
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Received a close frame. Tells if we need to return a close frame to the user.
|
/// Received a close frame. Tells if we need to return a close frame to the user.
|
||||||
#[allow(clippy::option_option)]
|
#[allow(clippy::option_option)]
|
||||||
fn do_close<'t>(&mut self, close: Option<CloseFrame<'t>>) -> Option<Option<CloseFrame<'t>>> {
|
fn do_close<'t>(&mut self, close: Option<CloseFrame<'t>>) -> Option<Option<CloseFrame<'t>>> {
|
||||||
|
@ -748,17 +605,6 @@ impl WebSocketContext {
|
||||||
trace!("Sending frame: {:?}", frame);
|
trace!("Sending frame: {:?}", frame);
|
||||||
self.frame.write_frame(stream, frame).check_connection_reset(self.state)
|
self.frame.write_frame(stream, frame).check_connection_reset(self.state)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn has_compression(&self) -> bool {
|
|
||||||
#[cfg(feature = "deflate")]
|
|
||||||
{
|
|
||||||
self.extensions.as_ref().and_then(|c| c.compression.as_ref()).is_some()
|
|
||||||
}
|
|
||||||
#[cfg(not(feature = "deflate"))]
|
|
||||||
{
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The current connection state.
|
/// The current connection state.
|
||||||
|
|
Loading…
Reference in New Issue