Compare commits

...

42 Commits

Author SHA1 Message Date
Conor ac3ce4d5a2
Merge pull request #1 from Its-Just-Nans/master
Change tests to merge https://github.com/snapview/tungstenite-rs/pull/363
2024-04-26 10:50:51 +10:00
n4n5 adbc70a6b4
fix test for merging 2024-03-08 15:37:30 +01:00
n4n5 b44721662d
Merge remote-tracking branch 'newone/master' 2024-03-08 15:35:34 +01:00
Félix Lescaudey de Maneville 0fa41973b4
Add builder for additional header values (#400)
* ADd builder for additional header values

* Update client.rs

* fix: docs

* feat: add test

* fix: typo

* add

---------

Co-authored-by: n4n5 <56606507+Its-Just-Nans@users.noreply.github.com>
Co-authored-by: n4n5 <its.just.n4n5@gmail.com>
2024-02-12 20:56:15 +01:00
Alex Butler 2ee05d1080 Update 0.21.0 changelog 2023-12-12 13:10:10 +01:00
Daniel Abramov 3df40fd0f1 Update CHANGELOG.md
Fixes #398, #389.
2023-12-11 12:47:40 +01:00
Sebastian Dröge e9604ac35b Update MSRV to 1.60 and check on the CI
The code does not actually compile anymore with 1.51.

Also only run `cargo check` with 1.60 as various optional features,
tests, benchmarks actually require newer Rust versions.
2023-12-08 11:59:01 +01:00
Alexey Galakhov 85463b264e Release version 0.21.0 2023-12-07 02:03:01 +01:00
Constantin Nickel bcd7f85e65 Update `rustls` to 0.22 2023-12-05 22:24:29 +01:00
Alex Butler 9f0af2a2e3 Test that no additional flushes are called after pong flush success 2023-12-05 22:23:31 +01:00
Alex Butler 2d5b3e18de Fix auto pong responses not flushing after block
Retry pong flushes on read.
Add read_usage_auto_pong_flush scenario test
2023-12-05 22:23:31 +01:00
Alex Butler a54623ccfe Remove proposed version from changelog 2023-12-02 20:30:48 +01:00
Alex Butler 866ce20dbe Update webpki-roots to 0.26 2023-12-02 20:30:48 +01:00
Alex Butler 0f6e6517e6 Fix FrameHeader::format write & other lints 2023-12-02 00:01:17 +01:00
Alexey Galakhov fc17f7341d
Merge pull request #373 from psychon/reduce-byteorder
Reduce use of byteorder crate
2023-11-17 19:38:27 +01:00
Daniel Abramov a43bb499df
Merge pull request #386 from snapview/dependabot/cargo/http-1.0
Update http requirement from 0.2 to 1.0
2023-11-17 18:54:04 +01:00
Alexey Galakhov 08cdd76dd6
Merge pull request #387 from nickelc/deps/socket2
Replace deprecated `net2` dev dependency with `socket2`
2023-11-17 02:12:31 +01:00
Constantin Nickel 8ca9b2314c Replace deprecated `net2` dev dependency with `socket2` 2023-11-16 11:33:58 +01:00
dependabot[bot] 2867907f15
Update http requirement from 0.2 to 1.0
Updates the requirements on [http](https://github.com/hyperium/http) to permit the latest version.
- [Release notes](https://github.com/hyperium/http/releases)
- [Changelog](https://github.com/hyperium/http/blob/master/CHANGELOG.md)
- [Commits](https://github.com/hyperium/http/compare/v0.2.0...v1.0.0)

---
updated-dependencies:
- dependency-name: http
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-11-16 08:04:57 +00:00
Alexey Galakhov a6cbed7bff
Merge pull request #385 from nickelc/deps/webpki-roots
Update `webpki-roots` to 0.25
2023-11-15 22:30:51 +01:00
Constantin Nickel 6c61d54ad2 Update `webpki-roots` to 0.25 2023-11-15 19:56:28 +01:00
Daniel Abramov 272d83c430 doc: clarify the meaning of config values 2023-10-29 14:09:32 +01:00
Daniel Abramov 3d9fd1e5cb
Merge pull request #383 from atouchet/cbar
Update Autobahn Test Suite page
2023-10-22 22:49:07 +02:00
Alex Touchet 614200a2fa
Update Autobahn Test Suite page 2023-10-16 19:36:40 -07:00
Daniel Abramov 8b3ecd3cc0 Update changelog 2023-09-23 14:10:53 +02:00
Daniel Abramov 722850d473 Fix incorrect metadata in Cargo.toml 2023-09-23 14:10:04 +02:00
Alexey Galakhov 219075edaa
Merge pull request #379 from snapview/CVE-2023-43669
Quick-and-dirty fix for CVE-2023-43669
2023-09-23 02:33:24 +02:00
Alexey Galakhov f0f1a06a50 Bump crate version 2023-09-23 02:16:16 +02:00
Alexey Galakhov 2e5029284b Add checking for header sanity
Co-authored-by: Daniel Abramov <inetcrack2@gmail.com>
2023-09-23 02:16:09 +02:00
Alexey Galakhov f916b332a9 Add `AttackAttempt` error variant 2023-09-22 17:48:31 +02:00
Uli Schlachter e4224ed85a Reduce use of byteorder crate
The byteorder dependency is only used in protocol::frame::frame. I
thought this dependency could easily be removed and set out to replace
the use of byteorder with equivalent std methods.

NetworkEndian is an alias for BigEndian. Converting a number like u32 to
bytes can be done via the std lib via .to_be_bytes(). The opposite
direction is from_by_bytes(). These simple things thus to not need
byteorder.

There is one place in the code where byteorder actually helps, thus this
dependency is not actually fully removed. ByteOrder::read_uint() allows
to read 1 to 8 bytes of data and returns the result as u64. Doing this
with the standard library basically requires re-implementing byteorder.
Thus, I did not do that.

Signed-off-by: Uli Schlachter <psychon@znc.in>
2023-08-18 10:35:33 +02:00
Daniel Abramov 53914c1180 Include examples so that `cargo publish` works 2023-07-22 15:36:37 +01:00
Daniel Abramov 5323559891 Bump version 2023-07-22 15:19:04 +01:00
Daniel Abramov 6e63b17b63 Update changelog 2023-07-22 15:18:14 +01:00
Daniel Abramov f2ed7aa826
Merge pull request #365 from snapview/dependabot/cargo/webpki-roots-0.24
Update webpki-roots requirement from 0.23 to 0.24
2023-07-22 16:10:14 +02:00
Daniel Abramov 8d8f0da204
Merge pull request #362 from alexheretic/config-asserts
Add assert panics for `WebSocketConfig`
2023-07-22 16:08:30 +02:00
Daniel Abramov dac07ea68b
Merge pull request #361 from alexheretic/docs++
Clarify `WebSocketConfig` docs
2023-07-22 16:07:39 +02:00
dependabot[bot] 40cd43c4f9
Update webpki-roots requirement from 0.23 to 0.24
Updates the requirements on [webpki-roots](https://github.com/rustls/webpki-roots) to permit the latest version.
- [Commits](https://github.com/rustls/webpki-roots/compare/v/0.23.1...v/0.24.0)

---
updated-dependencies:
- dependency-name: webpki-roots
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-07-07 08:51:20 +00:00
Alex Butler 8f73cf03ab update changelog 2023-06-17 23:50:18 +01:00
Alex Butler 9567cc73f3 Add panics docs 2023-06-17 23:46:04 +01:00
Alex Butler 7869f11b41 Add assert panics for WebSocketConfig 2023-06-17 23:36:46 +01:00
Alex Butler 2345e28158 Clarify WebSocketConfig docs 2023-06-17 23:30:12 +01:00
16 changed files with 523 additions and 75 deletions

View File

@ -43,6 +43,30 @@ jobs:
- name: Test - name: Test
run: cargo test --release run: cargo test --release
test-msvr:
name: Test MSRV
runs-on: ubuntu-latest
strategy:
matrix:
rust:
- 1.60.0
steps:
- name: Checkout sources
uses: actions/checkout@v3
- name: Install toolchain
uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.rust }}
- name: Install dependencies
run: sudo apt-get install libssl-dev
- name: Check
run: cargo check
autobahn: autobahn:
name: Autobahn tests name: Autobahn tests
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@ -1,4 +1,15 @@
# Unreleased (0.20.0) # 0.21.0
- Fix read-predominant auto pong responses not flushing when hitting WouldBlock errors.
- Improve `FrameHeader::format` write correctness.
- Update `rustls` to `0.22`.
- Update `webpki-roots` to `0.26`.
- Update `rustls-native-certs` to `0.7`.
- Update `http` to `1.0.0`.
# 0.20.1
- Fixes [CVE-2023-43669](https://github.com/snapview/tungstenite-rs/pull/379).
# 0.20.0
- Remove many implicit flushing behaviours. In general reading and writing messages will no - Remove many implicit flushing behaviours. In general reading and writing messages will no
longer flush until calling `flush`. An exception is automatic responses (e.g. pongs) longer flush until calling `flush`. An exception is automatic responses (e.g. pongs)
which will continue to be written and flushed when reading and writing. which will continue to be written and flushed when reading and writing.
@ -14,6 +25,7 @@
Note: `WriteBufferFull` returns the message that could not be written as a `Message::Frame`. Note: `WriteBufferFull` returns the message that could not be written as a `Message::Frame`.
- Add ability to buffer multiple writes before writing to the underlying stream, controlled by - Add ability to buffer multiple writes before writing to the underlying stream, controlled by
`WebSocketConfig::write_buffer_size` (default 128 KiB). Improves batch message write performance. `WebSocketConfig::write_buffer_size` (default 128 KiB). Improves batch message write performance.
- Panic on receiving invalid `WebSocketConfig`.
# 0.19.0 # 0.19.0

View File

@ -7,12 +7,12 @@ authors = ["Alexey Galakhov", "Daniel Abramov"]
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
readme = "README.md" readme = "README.md"
homepage = "https://github.com/snapview/tungstenite-rs" homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.19.0" documentation = "https://docs.rs/tungstenite/0.21.0"
repository = "https://github.com/snapview/tungstenite-rs" repository = "https://github.com/snapview/tungstenite-rs"
version = "0.19.0" version = "0.21.0"
edition = "2018" edition = "2018"
rust-version = "1.51" rust-version = "1.60"
include = ["benches/**/*", "src/**/*", "LICENSE-*", "README.md", "CHANGELOG.md"] include = ["benches/**/*", "src/**/*", "examples/**/*", "LICENSE-*", "README.md", "CHANGELOG.md"]
[package.metadata.docs.rs] [package.metadata.docs.rs]
all-features = true all-features = true
@ -24,13 +24,13 @@ native-tls = ["native-tls-crate"]
native-tls-vendored = ["native-tls", "native-tls-crate/vendored"] native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"] rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"] rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
__rustls-tls = ["rustls"] __rustls-tls = ["rustls", "rustls-pki-types"]
[dependencies] [dependencies]
data-encoding = { version = "2", optional = true } data-encoding = { version = "2", optional = true }
byteorder = "1.3.2" byteorder = "1.3.2"
bytes = "1.0" bytes = "1.0"
http = { version = "0.2", optional = true } http = { version = "1.0", optional = true }
httparse = { version = "1.3.4", optional = true } httparse = { version = "1.3.4", optional = true }
log = "0.4.8" log = "0.4.8"
rand = "0.8.0" rand = "0.8.0"
@ -46,22 +46,26 @@ version = "0.2.3"
[dependencies.rustls] [dependencies.rustls]
optional = true optional = true
version = "0.21.0" version = "0.22.0"
[dependencies.rustls-pki-types]
optional = true
version = "1.0"
[dependencies.rustls-native-certs] [dependencies.rustls-native-certs]
optional = true optional = true
version = "0.6.0" version = "0.7.0"
[dependencies.webpki-roots] [dependencies.webpki-roots]
optional = true optional = true
version = "0.23" version = "0.26"
[dev-dependencies] [dev-dependencies]
criterion = "0.5.0" criterion = "0.5.0"
env_logger = "0.10.0" env_logger = "0.10.0"
input_buffer = "0.5.0" input_buffer = "0.5.0"
net2 = "0.2.37"
rand = "0.8.4" rand = "0.8.4"
socket2 = "0.5.5"
[[bench]] [[bench]]
name = "buffer" name = "buffer"

View File

@ -77,7 +77,7 @@ There is no support for permessage-deflate at the moment, but the PRs are welcom
Testing Testing
------- -------
Tungstenite is thoroughly tested and passes the [Autobahn Test Suite](https://crossbar.io/autobahn/) for Tungstenite is thoroughly tested and passes the [Autobahn Test Suite](https://github.com/crossbario/autobahn-testsuite) for
WebSockets. It is also covered by internal unit tests as well as possible. WebSockets. It is also covered by internal unit tests as well as possible.
Contributing Contributing

View File

@ -1,12 +1,13 @@
//! Methods to connect to a WebSocket as a client. //! Methods to connect to a WebSocket as a client.
use std::{ use std::{
convert::TryFrom,
io::{Read, Write}, io::{Read, Write},
net::{SocketAddr, TcpStream, ToSocketAddrs}, net::{SocketAddr, TcpStream, ToSocketAddrs},
result::Result as StdResult, result::Result as StdResult,
}; };
use http::{request::Parts, Uri}; use http::{request::Parts, HeaderName, Uri};
use log::*; use log::*;
use url::Url; use url::Url;
@ -265,3 +266,73 @@ impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
Request::from_httparse(self) Request::from_httparse(self)
} }
} }
/// Builder for a custom [`IntoClientRequest`] with options to add
/// custom additional headers and sub protocols.
///
/// # Example
///
/// ```rust no_run
/// # use crate::*;
/// use http::Uri;
/// use tungstenite::{connect, ClientRequestBuilder};
///
/// let uri: Uri = "ws://localhost:3012/socket".parse().unwrap();
/// let token = "my_jwt_token";
/// let builder = ClientRequestBuilder::new(uri)
/// .with_header("Authorization", format!("Bearer {token}"))
/// .with_sub_protocol("my_sub_protocol");
/// let socket = connect(builder).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct ClientRequestBuilder {
uri: Uri,
/// Additional [`Request`] handshake headers
additional_headers: Vec<(String, String)>,
/// Handsake subprotocols
subprotocols: Vec<String>,
}
impl ClientRequestBuilder {
/// Initializes an empty request builder
#[must_use]
pub const fn new(uri: Uri) -> Self {
Self { uri, additional_headers: Vec::new(), subprotocols: Vec::new() }
}
/// Adds (`key`, `value`) as an additional header to the handshake request
pub fn with_header<K, V>(mut self, key: K, value: V) -> Self
where
K: Into<String>,
V: Into<String>,
{
self.additional_headers.push((key.into(), value.into()));
self
}
/// Adds `protocol` to the handshake request subprotocols (`Sec-WebSocket-Protocol`)
pub fn with_sub_protocol<P>(mut self, protocol: P) -> Self
where
P: Into<String>,
{
self.subprotocols.push(protocol.into());
self
}
}
impl IntoClientRequest for ClientRequestBuilder {
fn into_client_request(self) -> Result<Request> {
let mut request = self.uri.into_client_request()?;
let headers = request.headers_mut();
for (k, v) in self.additional_headers {
let key = HeaderName::try_from(k)?;
let value = v.parse()?;
headers.append(key, value);
}
if !self.subprotocols.is_empty() {
let protocols = self.subprotocols.join(", ").parse()?;
headers.append("Sec-WebSocket-Protocol", protocols);
}
Ok(request)
}
}

View File

@ -59,6 +59,9 @@ pub enum Error {
/// UTF coding error. /// UTF coding error.
#[error("UTF-8 encoding error")] #[error("UTF-8 encoding error")]
Utf8, Utf8,
/// Attack attempt detected.
#[error("Attack attempt detected")]
AttackAttempt,
/// Invalid URL. /// Invalid URL.
#[error("URL error: {0}")] #[error("URL error: {0}")]
Url(#[from] UrlError), Url(#[from] UrlError),

View File

@ -20,7 +20,7 @@ pub struct HandshakeMachine<Stream> {
impl<Stream> HandshakeMachine<Stream> { impl<Stream> HandshakeMachine<Stream> {
/// Start reading data from the peer. /// Start reading data from the peer.
pub fn start_read(stream: Stream) -> Self { pub fn start_read(stream: Stream) -> Self {
HandshakeMachine { stream, state: HandshakeState::Reading(ReadBuffer::new()) } Self { stream, state: HandshakeState::Reading(ReadBuffer::new(), AttackCheck::new()) }
} }
/// Start writing data to the peer. /// Start writing data to the peer.
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self { pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
@ -41,25 +41,31 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> { pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> {
trace!("Doing handshake round."); trace!("Doing handshake round.");
match self.state { match self.state {
HandshakeState::Reading(mut buf) => { HandshakeState::Reading(mut buf, mut attack_check) => {
let read = buf.read_from(&mut self.stream).no_block()?; let read = buf.read_from(&mut self.stream).no_block()?;
match read { match read {
Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)), Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)),
Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? { Some(count) => {
buf.advance(size); attack_check.check_incoming_packet_size(count)?;
RoundResult::StageFinished(StageResult::DoneReading { // TODO: this is slow for big headers with too many small packets.
result: obj, // The parser has to be reworked in order to work on streams instead
stream: self.stream, // of buffers.
tail: buf.into_vec(), Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {
result: obj,
stream: self.stream,
tail: buf.into_vec(),
})
} else {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Reading(buf, attack_check),
..self
})
}) })
} else { }
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
})
}),
None => Ok(RoundResult::WouldBlock(HandshakeMachine { None => Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf), state: HandshakeState::Reading(buf, attack_check),
..self ..self
})), })),
} }
@ -119,7 +125,54 @@ pub trait TryParse: Sized {
#[derive(Debug)] #[derive(Debug)]
enum HandshakeState { enum HandshakeState {
/// Reading data from the peer. /// Reading data from the peer.
Reading(ReadBuffer), Reading(ReadBuffer, AttackCheck),
/// Sending data to the peer. /// Sending data to the peer.
Writing(Cursor<Vec<u8>>), Writing(Cursor<Vec<u8>>),
} }
/// Attack mitigation. Contains counters needed to prevent DoS attacks
/// and reject valid but useless headers.
#[derive(Debug)]
pub(crate) struct AttackCheck {
/// Number of HTTP header successful reads (TCP packets).
number_of_packets: usize,
/// Total number of bytes in HTTP header.
number_of_bytes: usize,
}
impl AttackCheck {
/// Initialize attack checking for incoming buffer.
fn new() -> Self {
Self { number_of_packets: 0, number_of_bytes: 0 }
}
/// Check the size of an incoming packet. To be called immediately after `read()`
/// passing its returned bytes count as `size`.
fn check_incoming_packet_size(&mut self, size: usize) -> Result<()> {
self.number_of_packets += 1;
self.number_of_bytes += size;
// TODO: these values are hardcoded. Instead of making them configurable,
// rework the way HTTP header is parsed to remove this check at all.
const MAX_BYTES: usize = 65536;
const MAX_PACKETS: usize = 512;
const MIN_PACKET_SIZE: usize = 128;
const MIN_PACKET_CHECK_THRESHOLD: usize = 64;
if self.number_of_bytes > MAX_BYTES {
return Err(Error::AttackAttempt);
}
if self.number_of_packets > MAX_PACKETS {
return Err(Error::AttackAttempt);
}
if self.number_of_packets > MIN_PACKET_CHECK_THRESHOLD
&& self.number_of_packets * MIN_PACKET_SIZE > self.number_of_bytes
{
return Err(Error::AttackAttempt);
}
Ok(())
}
}

View File

@ -39,7 +39,7 @@ pub use crate::{
#[cfg(feature = "handshake")] #[cfg(feature = "handshake")]
pub use crate::{ pub use crate::{
client::{client, connect}, client::{client, connect, ClientRequestBuilder},
handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError}, handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError},
server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config}, server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config},
}; };

View File

@ -1,4 +1,4 @@
use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt}; use byteorder::{NetworkEndian, ReadBytesExt};
use log::*; use log::*;
use std::{ use std::{
borrow::Cow, borrow::Cow,
@ -108,8 +108,12 @@ impl FrameHeader {
output.write_all(&[one, two])?; output.write_all(&[one, two])?;
match lenfmt { match lenfmt {
LengthFormat::U8(_) => (), LengthFormat::U8(_) => (),
LengthFormat::U16 => output.write_u16::<NetworkEndian>(length as u16)?, LengthFormat::U16 => {
LengthFormat::U64 => output.write_u64::<NetworkEndian>(length)?, output.write_all(&(length as u16).to_be_bytes())?;
}
LengthFormat::U64 => {
output.write_all(&length.to_be_bytes())?;
}
} }
if let Some(ref mask) = self.mask { if let Some(ref mask) = self.mask {
@ -295,7 +299,7 @@ impl Frame {
1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)), 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
_ => { _ => {
let mut data = self.payload; let mut data = self.payload;
let code = NetworkEndian::read_u16(&data[0..2]).into(); let code = u16::from_be_bytes([data[0], data[1]]).into();
data.drain(0..2); data.drain(0..2);
let text = String::from_utf8(data)?; let text = String::from_utf8(data)?;
Ok(Some(CloseFrame { code, reason: text.into() })) Ok(Some(CloseFrame { code, reason: text.into() }))
@ -340,7 +344,7 @@ impl Frame {
pub fn close(msg: Option<CloseFrame>) -> Frame { pub fn close(msg: Option<CloseFrame>) -> Frame {
let payload = if let Some(CloseFrame { code, reason }) = msg { let payload = if let Some(CloseFrame { code, reason }) = msg {
let mut p = Vec::with_capacity(reason.as_bytes().len() + 2); let mut p = Vec::with_capacity(reason.as_bytes().len() + 2);
p.write_u16::<NetworkEndian>(code.into()).unwrap(); // can't fail p.extend(u16::from(code).to_be_bytes());
p.extend_from_slice(reason.as_bytes()); p.extend_from_slice(reason.as_bytes());
p p
} else { } else {
@ -366,6 +370,8 @@ impl Frame {
impl fmt::Display for Frame { impl fmt::Display for Frame {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use std::fmt::Write;
write!( write!(
f, f,
" "
@ -385,7 +391,10 @@ payload: 0x{}
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(), self.len(),
self.payload.len(), self.payload.len(),
self.payload.iter().map(|byte| format!("{:02x}", byte)).collect::<String>() self.payload.iter().fold(String::new(), |mut output, byte| {
_ = write!(output, "{byte:02x}");
output
})
) )
} }
} }

View File

@ -48,7 +48,7 @@ mod tests {
#[test] #[test]
fn test_apply_mask() { fn test_apply_mask() {
let mask = [0x6d, 0xb6, 0xb2, 0x80]; let mask = [0x6d, 0xb6, 0xb2, 0x80];
let unmasked = vec![ let unmasked = [
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9, 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9,
0x12, 0x03, 0x12, 0x03,
]; ];

View File

@ -13,13 +13,10 @@ use self::{
}, },
message::{IncompleteMessage, IncompleteMessageType}, message::{IncompleteMessage, IncompleteMessageType},
}; };
use crate::{ use crate::error::{Error, ProtocolError, Result};
error::{Error, ProtocolError, Result},
util::NonBlockingResult,
};
use log::*; use log::*;
use std::{ use std::{
io::{ErrorKind as IoErrorKind, Read, Write}, io::{self, Read, Write},
mem::replace, mem::replace,
}; };
@ -42,19 +39,27 @@ pub struct WebSocketConfig {
/// to the underlying stream. /// to the underlying stream.
/// The default value is 128 KiB. /// The default value is 128 KiB.
/// ///
/// If set to `0` each message will be eagerly written to the underlying stream.
/// It is often more optimal to allow them to buffer a little, hence the default value.
///
/// Note: [`flush`](WebSocket::flush) will always fully write the buffer regardless. /// Note: [`flush`](WebSocket::flush) will always fully write the buffer regardless.
pub write_buffer_size: usize, pub write_buffer_size: usize,
/// The max size of the write buffer in bytes. Setting this can provide backpressure /// The max size of the write buffer in bytes. Setting this can provide backpressure
/// in the case the write buffer is filling up due to write errors. /// in the case the write buffer is filling up due to write errors.
/// The default value is unlimited. /// The default value is unlimited.
/// ///
/// Note: Should always be set higher than [`write_buffer_size`](Self::write_buffer_size). /// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size)
/// when writes to the underlying stream are failing. So the **write buffer can not
/// fill up if you are not observing write errors even if not flushing**.
///
/// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size)
/// and probably a little more depending on error handling strategy.
pub max_write_buffer_size: usize, pub max_write_buffer_size: usize,
/// The maximum size of a message. `None` means no size limit. The default value is 64 MiB /// The maximum size of an incoming message. `None` means no size limit. The default value is 64 MiB
/// which should be reasonably big for all normal use-cases but small enough to prevent /// which should be reasonably big for all normal use-cases but small enough to prevent
/// memory eating by a malicious user. /// memory eating by a malicious user.
pub max_message_size: Option<usize>, pub max_message_size: Option<usize>,
/// The maximum size of a single message frame. `None` means no size limit. The limit is for /// The maximum size of a single incoming message frame. `None` means no size limit. The limit is for
/// frame payload NOT including the frame header. The default value is 16 MiB which should /// frame payload NOT including the frame header. The default value is 16 MiB which should
/// be reasonably big for all normal use-cases but small enough to prevent memory eating /// be reasonably big for all normal use-cases but small enough to prevent memory eating
/// by a malicious user. /// by a malicious user.
@ -81,6 +86,17 @@ impl Default for WebSocketConfig {
} }
} }
impl WebSocketConfig {
/// Panic if values are invalid.
pub(crate) fn assert_valid(&self) {
assert!(
self.max_write_buffer_size > self.write_buffer_size,
"WebSocketConfig::max_write_buffer_size must be greater than write_buffer_size, \
see WebSocketConfig docs`"
);
}
}
/// WebSocket input-output stream. /// WebSocket input-output stream.
/// ///
/// This is THE structure you want to create to be able to speak the WebSocket protocol. /// This is THE structure you want to create to be able to speak the WebSocket protocol.
@ -101,6 +117,9 @@ impl<Stream> WebSocket<Stream> {
/// Call this function if you're using Tungstenite as a part of a web framework /// Call this function if you're using Tungstenite as a part of a web framework
/// or together with an existing one. If you need an initial handshake, use /// or together with an existing one. If you need an initial handshake, use
/// `connect()` or `accept()` functions of the crate to construct a websocket. /// `connect()` or `accept()` functions of the crate to construct a websocket.
///
/// # Panics
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self { pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self {
WebSocket { socket: stream, context: WebSocketContext::new(role, config) } WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
} }
@ -110,6 +129,9 @@ impl<Stream> WebSocket<Stream> {
/// Call this function if you're using Tungstenite as a part of a web framework /// Call this function if you're using Tungstenite as a part of a web framework
/// or together with an existing one. If you need an initial handshake, use /// or together with an existing one. If you need an initial handshake, use
/// `connect()` or `accept()` functions of the crate to construct a websocket. /// `connect()` or `accept()` functions of the crate to construct a websocket.
///
/// # Panics
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
pub fn from_partially_read( pub fn from_partially_read(
stream: Stream, stream: Stream,
part: Vec<u8>, part: Vec<u8>,
@ -132,6 +154,9 @@ impl<Stream> WebSocket<Stream> {
} }
/// Change the configuration. /// Change the configuration.
///
/// # Panics
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
self.context.set_config(set_func) self.context.set_config(set_func)
} }
@ -285,22 +310,32 @@ pub struct WebSocketContext {
incomplete: Option<IncompleteMessage>, incomplete: Option<IncompleteMessage>,
/// Send in addition to regular messages E.g. "pong" or "close". /// Send in addition to regular messages E.g. "pong" or "close".
additional_send: Option<Frame>, additional_send: Option<Frame>,
/// True indicates there is an additional message (like a pong)
/// that failed to flush previously and we should try again.
unflushed_additional: bool,
/// The configuration for the websocket session. /// The configuration for the websocket session.
config: WebSocketConfig, config: WebSocketConfig,
} }
impl WebSocketContext { impl WebSocketContext {
/// Create a WebSocket context that manages a post-handshake stream. /// Create a WebSocket context that manages a post-handshake stream.
///
/// # Panics
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self { pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self {
Self::_new(role, FrameCodec::new(), config.unwrap_or_default()) Self::_new(role, FrameCodec::new(), config.unwrap_or_default())
} }
/// Create a WebSocket context that manages an post-handshake stream. /// Create a WebSocket context that manages an post-handshake stream.
///
/// # Panics
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self { pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
Self::_new(role, FrameCodec::from_partially_read(part), config.unwrap_or_default()) Self::_new(role, FrameCodec::from_partially_read(part), config.unwrap_or_default())
} }
fn _new(role: Role, mut frame: FrameCodec, config: WebSocketConfig) -> Self { fn _new(role: Role, mut frame: FrameCodec, config: WebSocketConfig) -> Self {
config.assert_valid();
frame.set_max_out_buffer_len(config.max_write_buffer_size); frame.set_max_out_buffer_len(config.max_write_buffer_size);
frame.set_out_buffer_write_len(config.write_buffer_size); frame.set_out_buffer_write_len(config.write_buffer_size);
Self { Self {
@ -309,13 +344,18 @@ impl WebSocketContext {
state: WebSocketState::Active, state: WebSocketState::Active,
incomplete: None, incomplete: None,
additional_send: None, additional_send: None,
unflushed_additional: false,
config, config,
} }
} }
/// Change the configuration. /// Change the configuration.
///
/// # Panics
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) { pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
set_func(&mut self.config); set_func(&mut self.config);
self.config.assert_valid();
self.frame.set_max_out_buffer_len(self.config.max_write_buffer_size); self.frame.set_max_out_buffer_len(self.config.max_write_buffer_size);
self.frame.set_out_buffer_write_len(self.config.write_buffer_size); self.frame.set_out_buffer_write_len(self.config.write_buffer_size);
} }
@ -352,10 +392,16 @@ impl WebSocketContext {
self.state.check_not_terminated()?; self.state.check_not_terminated()?;
loop { loop {
if self.additional_send.is_some() { if self.additional_send.is_some() || self.unflushed_additional {
// Since we may get ping or close, we need to reply to the messages even during read. // Since we may get ping or close, we need to reply to the messages even during read.
// Thus we flush but ignore its blocking. match self.flush(stream) {
self.flush(stream).no_block()?; Ok(_) => {}
Err(Error::Io(err)) if err.kind() == io::ErrorKind::WouldBlock => {
// If blocked continue reading, but try again later
self.unflushed_additional = true;
}
Err(err) => return Err(err),
}
} else if self.role == Role::Server && !self.state.can_read() { } else if self.role == Role::Server && !self.state.can_read() {
self.state = WebSocketState::Terminated; self.state = WebSocketState::Terminated;
return Err(Error::ConnectionClosed); return Err(Error::ConnectionClosed);
@ -423,7 +469,9 @@ impl WebSocketContext {
{ {
self._write(stream, None)?; self._write(stream, None)?;
self.frame.write_out_buffer(stream)?; self.frame.write_out_buffer(stream)?;
Ok(stream.flush()?) stream.flush()?;
self.unflushed_additional = false;
Ok(())
} }
/// Writes any data in the out_buffer, `additional_send` and given `data`. /// Writes any data in the out_buffer, `additional_send` and given `data`.
@ -456,7 +504,7 @@ impl WebSocketContext {
Ok(_) => true, Ok(_) => true,
} }
} else { } else {
false self.unflushed_additional
}; };
// If we're closing and there is nothing to send anymore, we should close the connection. // If we're closing and there is nothing to send anymore, we should close the connection.
@ -735,7 +783,7 @@ impl<T> CheckConnectionReset for Result<T> {
fn check_connection_reset(self, state: WebSocketState) -> Self { fn check_connection_reset(self, state: WebSocketState) -> Self {
match self { match self {
Err(Error::Io(io_error)) => Err({ Err(Error::Io(io_error)) => Err({
if !state.can_read() && io_error.kind() == IoErrorKind::ConnectionReset { if !state.can_read() && io_error.kind() == io::ErrorKind::ConnectionReset {
Error::ConnectionClosed Error::ConnectionClosed
} else { } else {
Error::Io(io_error) Error::Io(io_error)

View File

@ -70,7 +70,8 @@ mod encryption {
#[cfg(feature = "__rustls-tls")] #[cfg(feature = "__rustls-tls")]
pub mod rustls { pub mod rustls {
use rustls::{ClientConfig, ClientConnection, RootCertStore, ServerName, StreamOwned}; use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
use rustls_pki_types::ServerName;
use std::{ use std::{
convert::TryFrom, convert::TryFrom,
@ -105,36 +106,26 @@ mod encryption {
#[cfg(feature = "rustls-tls-native-roots")] #[cfg(feature = "rustls-tls-native-roots")]
{ {
let native_certs = rustls_native_certs::load_native_certs()?; let native_certs = rustls_native_certs::load_native_certs()?;
let der_certs: Vec<Vec<u8>> = let total_number = native_certs.len();
native_certs.into_iter().map(|cert| cert.0).collect();
let total_number = der_certs.len();
let (number_added, number_ignored) = let (number_added, number_ignored) =
root_store.add_parsable_certificates(&der_certs); root_store.add_parsable_certificates(native_certs);
log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})"); log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
} }
#[cfg(feature = "rustls-tls-webpki-roots")] #[cfg(feature = "rustls-tls-webpki-roots")]
{ {
root_store.add_server_trust_anchors( root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
})
);
} }
Arc::new( Arc::new(
ClientConfig::builder() ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store) .with_root_certificates(root_store)
.with_no_client_auth(), .with_no_client_auth(),
) )
} }
}; };
let domain = let domain = ServerName::try_from(domain)
ServerName::try_from(domain).map_err(|_| TlsError::InvalidDnsName)?; .map_err(|_| TlsError::InvalidDnsName)?
.to_owned();
let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?; let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?;
let stream = StreamOwned::new(client, socket); let stream = StreamOwned::new(client, socket);

130
tests/auto_pong_flush.rs Normal file
View File

@ -0,0 +1,130 @@
use std::{
io::{self, Cursor, Read, Write},
mem,
};
use tungstenite::{
protocol::frame::{
coding::{Control, OpCode},
Frame, FrameHeader,
},
Message, WebSocket,
};
const NUMBER_OF_FLUSHES_TO_GET_IT_TO_WORK: usize = 3;
/// `Read`/`Write` mock.
/// * Reads a single ping, then returns `WouldBlock` forever after.
/// * Writes work fine.
/// * Flush `WouldBlock` twice then works on the 3rd attempt.
#[derive(Debug, Default)]
struct MockWrite {
/// Data written, but not flushed.
written_data: Vec<u8>,
/// The latest successfully flushed data.
flushed_data: Vec<u8>,
write_calls: usize,
flush_calls: usize,
read_calls: usize,
}
impl Read for MockWrite {
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
self.read_calls += 1;
if self.read_calls == 1 {
let ping = Frame::ping(vec![]);
let len = ping.len();
ping.format(&mut buf).expect("format failed");
Ok(len)
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, "nothing else to read"))
}
}
}
impl Write for MockWrite {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.write_calls += 1;
self.written_data.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.flush_calls += 1;
if self.flush_calls % NUMBER_OF_FLUSHES_TO_GET_IT_TO_WORK == 0 {
mem::swap(&mut self.written_data, &mut self.flushed_data);
self.written_data.clear();
eprintln!("flush success");
Ok(())
} else {
eprintln!("flush would block");
Err(io::Error::new(io::ErrorKind::WouldBlock, "try again"))
}
}
}
/// Test for auto pong write & flushing behaviour.
///
/// In read-only/read-predominant usage auto pong responses should be written and flushed
/// even if WouldBlock errors are encountered.
#[test]
fn read_usage_auto_pong_flush() {
let mut ws =
WebSocket::from_raw_socket(MockWrite::default(), tungstenite::protocol::Role::Client, None);
// Receiving a ping should auto scheduled a pong on next read or write (but not written yet).
let msg = ws.read().unwrap();
assert!(matches!(msg, Message::Ping(_)), "Unexpected msg {:?}", msg);
assert_eq!(ws.get_ref().read_calls, 1);
assert!(ws.get_ref().written_data.is_empty(), "Unexpected {:?}", ws.get_ref());
assert!(ws.get_ref().flushed_data.is_empty(), "Unexpected {:?}", ws.get_ref());
// Next read fails as there is nothing else to read.
// This read call should have tried to write & flush a pong response, with the flush WouldBlock-ing
let next = ws.read().unwrap_err();
assert!(
matches!(next, tungstenite::Error::Io(ref err) if err.kind() == io::ErrorKind::WouldBlock),
"Unexpected read err {:?}",
next
);
assert_eq!(ws.get_ref().read_calls, 2);
assert!(!ws.get_ref().written_data.is_empty(), "Should have written a pong frame");
assert_eq!(ws.get_ref().write_calls, 1);
let pong_header =
FrameHeader::parse(&mut Cursor::new(&ws.get_ref().written_data)).unwrap().unwrap().0;
assert_eq!(pong_header.opcode, OpCode::Control(Control::Pong));
let written_data = ws.get_ref().written_data.clone();
assert_eq!(ws.get_ref().flush_calls, 1);
assert!(ws.get_ref().flushed_data.is_empty(), "Unexpected {:?}", ws.get_ref());
// Next read fails as before.
// This read call should try to flush the pong again, which again WouldBlock
let next = ws.read().unwrap_err();
assert!(
matches!(next, tungstenite::Error::Io(ref err) if err.kind() == io::ErrorKind::WouldBlock),
"Unexpected read err {:?}",
next
);
assert_eq!(ws.get_ref().read_calls, 3);
assert_eq!(ws.get_ref().write_calls, 1);
assert_eq!(ws.get_ref().flush_calls, 2);
assert!(ws.get_ref().flushed_data.is_empty(), "Unexpected {:?}", ws.get_ref());
// Next read fails as before.
// This read call should try to flush the pong again, 3rd flush attempt is the charm
let next = ws.read().unwrap_err();
assert!(
matches!(next, tungstenite::Error::Io(ref err) if err.kind() == io::ErrorKind::WouldBlock),
"Unexpected read err {:?}",
next
);
assert_eq!(ws.get_ref().read_calls, 4);
assert_eq!(ws.get_ref().write_calls, 1);
assert_eq!(ws.get_ref().flush_calls, 3);
assert!(ws.get_ref().flushed_data == written_data, "Unexpected {:?}", ws.get_ref());
// On following read calls no additional writes or flushes are necessary
ws.read().unwrap_err();
assert_eq!(ws.get_ref().read_calls, 5);
assert_eq!(ws.get_ref().write_calls, 1);
assert_eq!(ws.get_ref().flush_calls, 3);
}

96
tests/client_headers.rs Normal file
View File

@ -0,0 +1,96 @@
#![cfg(feature = "handshake")]
use http::Uri;
use std::{
net::TcpListener,
process::exit,
thread::{sleep, spawn},
time::Duration,
};
use tungstenite::{
accept_hdr, connect,
handshake::server::{Request, Response},
ClientRequestBuilder, Error, Message,
};
/// Test for write buffering and flushing behaviour.
#[test]
fn test_headers() {
env_logger::init();
let uri: Uri = "ws://127.0.0.1:3013/socket".parse().unwrap();
let token = "my_jwt_token";
let full_token = format!("Bearer {token}");
let sub_protocol = "my_sub_protocol";
let builder = ClientRequestBuilder::new(uri)
.with_header("Authorization", full_token.to_owned())
.with_sub_protocol(sub_protocol.to_owned());
spawn(|| {
sleep(Duration::from_secs(5));
println!("Unit test executed too long, perhaps stuck on WOULDBLOCK...");
exit(1);
});
let server = TcpListener::bind("127.0.0.1:3013").unwrap();
let client_thread = spawn(move || {
let (mut client, _) = connect(builder).unwrap();
client.send(Message::Text("Hello WebSocket".into())).unwrap();
let message = client.read().unwrap(); // receive close from server
assert!(message.is_close());
let err = client.read().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
});
let callback = |req: &Request, mut response: Response| {
println!("Received a new ws handshake");
println!("The request's path is: {}", req.uri().path());
println!("The request's headers are:");
let authorization_header: String = "authorization".to_ascii_lowercase();
let web_socket_proto: String = "sec-websocket-protocol".to_ascii_lowercase();
for (ref header, value) in req.headers() {
println!("* {}: {}", header, value.to_str().unwrap());
if header.to_string() == authorization_header {
println!("Matching authorization header");
assert_eq!(header.to_string(), authorization_header);
assert_eq!(value.to_str().unwrap(), full_token);
} else if header.to_string() == web_socket_proto {
println!("Matching sec-websocket-protocol header");
assert_eq!(header.to_string(), web_socket_proto);
assert_eq!(value.to_str().unwrap(), sub_protocol);
// the server needs to respond with the same sub-protocol
response
.headers_mut()
.append("sec-websocket-protocol", sub_protocol.parse().unwrap());
}
}
Ok(response)
};
let client_handler = server.incoming().next().unwrap();
let mut client_handler = accept_hdr(client_handler.unwrap(), callback).unwrap();
client_handler.close(None).unwrap(); // send close to client
// This read should succeed even though we already initiated a close
let message = client_handler.read().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket");
assert!(client_handler.read().unwrap().is_close()); // receive acknowledgement
let err = client_handler.read().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
drop(client_handler);
client_thread.join().unwrap();
}

View File

@ -10,7 +10,7 @@ use std::{
time::Duration, time::Duration,
}; };
use net2::TcpStreamExt; use socket2::Socket;
use tungstenite::{accept, connect, stream::MaybeTlsStream, Error, Message, WebSocket}; use tungstenite::{accept, connect, stream::MaybeTlsStream, Error, Message, WebSocket};
use url::Url; use url::Url;
@ -19,7 +19,7 @@ type Sock = WebSocket<MaybeTlsStream<TcpStream>>;
fn do_test<CT, ST>(port: u16, client_task: CT, server_task: ST) fn do_test<CT, ST>(port: u16, client_task: CT, server_task: ST)
where where
CT: FnOnce(Sock) + Send + 'static, CT: FnOnce(Sock) + Send + 'static,
ST: FnOnce(WebSocket<TcpStream>), ST: FnOnce(WebSocket<Socket>),
{ {
env_logger::try_init().ok(); env_logger::try_init().ok();
@ -40,7 +40,7 @@ where
}); });
let client_handler = server.incoming().next().unwrap(); let client_handler = server.incoming().next().unwrap();
let client_handler = accept(client_handler.unwrap()).unwrap(); let client_handler = accept(Socket::from(client_handler.unwrap())).unwrap();
server_task(client_handler); server_task(client_handler);

View File

@ -1,5 +1,7 @@
#![cfg(feature = "handshake")]
use std::net::TcpListener; use std::net::TcpListener;
use std::thread::spawn; use std::thread::{sleep, spawn};
use std::time::Duration;
use tungstenite::error::{Error, ProtocolError, SubProtocolError}; use tungstenite::error::{Error, ProtocolError, SubProtocolError};
use tungstenite::handshake::client::generate_key; use tungstenite::handshake::client::generate_key;
use tungstenite::handshake::server::{Request, Response}; use tungstenite::handshake::server::{Request, Response};
@ -35,7 +37,6 @@ fn server_thread(port: u16, server_subprotocols: Option<Vec<String>>) {
spawn(move || { spawn(move || {
let server = TcpListener::bind(("127.0.0.1", port)) let server = TcpListener::bind(("127.0.0.1", port))
.expect("Can't listen, is this port already in use?"); .expect("Can't listen, is this port already in use?");
let client_handler = server.incoming().next().unwrap();
let callback = |_request: &Request, mut response: Response| { let callback = |_request: &Request, mut response: Response| {
if let Some(subprotocols) = server_subprotocols { if let Some(subprotocols) = server_subprotocols {
@ -45,13 +46,16 @@ fn server_thread(port: u16, server_subprotocols: Option<Vec<String>>) {
Ok(response) Ok(response)
}; };
let _client_handler = accept_hdr(client_handler.unwrap(), callback).unwrap(); let client_handler = server.incoming().next().unwrap();
let mut client_handler = accept_hdr(client_handler.unwrap(), callback).unwrap();
client_handler.close(None).unwrap();
}); });
} }
#[test] #[test]
fn test_server_send_no_subprotocol() { fn test_server_send_no_subprotocol() {
server_thread(3012, None); server_thread(3012, None);
sleep(Duration::from_secs(1));
let err = let err =
connect(create_http_request("ws://127.0.0.1:3012", Some(vec!["my-sub-protocol".into()]))) connect(create_http_request("ws://127.0.0.1:3012", Some(vec!["my-sub-protocol".into()])))
@ -68,6 +72,7 @@ fn test_server_send_no_subprotocol() {
#[test] #[test]
fn test_server_sent_subprotocol_none_requested() { fn test_server_sent_subprotocol_none_requested() {
server_thread(3013, Some(vec!["my-sub-protocol".to_string()])); server_thread(3013, Some(vec!["my-sub-protocol".to_string()]));
sleep(Duration::from_secs(1));
let err = connect(create_http_request("ws://127.0.0.1:3013", None)).unwrap_err(); let err = connect(create_http_request("ws://127.0.0.1:3013", None)).unwrap_err();
@ -82,6 +87,7 @@ fn test_server_sent_subprotocol_none_requested() {
#[test] #[test]
fn test_invalid_subprotocol() { fn test_invalid_subprotocol() {
server_thread(3014, Some(vec!["invalid-sub-protocol".to_string()])); server_thread(3014, Some(vec!["invalid-sub-protocol".to_string()]));
sleep(Duration::from_secs(1));
let err = connect(create_http_request( let err = connect(create_http_request(
"ws://127.0.0.1:3014", "ws://127.0.0.1:3014",
@ -100,7 +106,7 @@ fn test_invalid_subprotocol() {
#[test] #[test]
fn test_request_multiple_subprotocols() { fn test_request_multiple_subprotocols() {
server_thread(3015, Some(vec!["my-sub-protocol".to_string()])); server_thread(3015, Some(vec!["my-sub-protocol".to_string()]));
sleep(Duration::from_secs(1));
let (_, response) = connect(create_http_request( let (_, response) = connect(create_http_request(
"ws://127.0.0.1:3015", "ws://127.0.0.1:3015",
Some(vec![ Some(vec![
@ -120,6 +126,7 @@ fn test_request_multiple_subprotocols() {
#[test] #[test]
fn test_request_single_subprotocol() { fn test_request_single_subprotocol() {
server_thread(3016, Some(vec!["my-sub-protocol".to_string()])); server_thread(3016, Some(vec!["my-sub-protocol".to_string()]));
sleep(Duration::from_secs(1));
let (_, response) = connect(create_http_request( let (_, response) = connect(create_http_request(
"ws://127.0.0.1:3016", "ws://127.0.0.1:3016",