Compare commits

...

42 Commits

Author SHA1 Message Date
Conor ac3ce4d5a2
Merge pull request #1 from Its-Just-Nans/master
Change tests to merge https://github.com/snapview/tungstenite-rs/pull/363
2024-04-26 10:50:51 +10:00
n4n5 adbc70a6b4
fix test for merging 2024-03-08 15:37:30 +01:00
n4n5 b44721662d
Merge remote-tracking branch 'newone/master' 2024-03-08 15:35:34 +01:00
Félix Lescaudey de Maneville 0fa41973b4
Add builder for additional header values (#400)
* ADd builder for additional header values

* Update client.rs

* fix: docs

* feat: add test

* fix: typo

* add

---------

Co-authored-by: n4n5 <56606507+Its-Just-Nans@users.noreply.github.com>
Co-authored-by: n4n5 <its.just.n4n5@gmail.com>
2024-02-12 20:56:15 +01:00
Alex Butler 2ee05d1080 Update 0.21.0 changelog 2023-12-12 13:10:10 +01:00
Daniel Abramov 3df40fd0f1 Update CHANGELOG.md
Fixes #398, #389.
2023-12-11 12:47:40 +01:00
Sebastian Dröge e9604ac35b Update MSRV to 1.60 and check on the CI
The code does not actually compile anymore with 1.51.

Also only run `cargo check` with 1.60 as various optional features,
tests, benchmarks actually require newer Rust versions.
2023-12-08 11:59:01 +01:00
Alexey Galakhov 85463b264e Release version 0.21.0 2023-12-07 02:03:01 +01:00
Constantin Nickel bcd7f85e65 Update `rustls` to 0.22 2023-12-05 22:24:29 +01:00
Alex Butler 9f0af2a2e3 Test that no additional flushes are called after pong flush success 2023-12-05 22:23:31 +01:00
Alex Butler 2d5b3e18de Fix auto pong responses not flushing after block
Retry pong flushes on read.
Add read_usage_auto_pong_flush scenario test
2023-12-05 22:23:31 +01:00
Alex Butler a54623ccfe Remove proposed version from changelog 2023-12-02 20:30:48 +01:00
Alex Butler 866ce20dbe Update webpki-roots to 0.26 2023-12-02 20:30:48 +01:00
Alex Butler 0f6e6517e6 Fix FrameHeader::format write & other lints 2023-12-02 00:01:17 +01:00
Alexey Galakhov fc17f7341d
Merge pull request #373 from psychon/reduce-byteorder
Reduce use of byteorder crate
2023-11-17 19:38:27 +01:00
Daniel Abramov a43bb499df
Merge pull request #386 from snapview/dependabot/cargo/http-1.0
Update http requirement from 0.2 to 1.0
2023-11-17 18:54:04 +01:00
Alexey Galakhov 08cdd76dd6
Merge pull request #387 from nickelc/deps/socket2
Replace deprecated `net2` dev dependency with `socket2`
2023-11-17 02:12:31 +01:00
Constantin Nickel 8ca9b2314c Replace deprecated `net2` dev dependency with `socket2` 2023-11-16 11:33:58 +01:00
dependabot[bot] 2867907f15
Update http requirement from 0.2 to 1.0
Updates the requirements on [http](https://github.com/hyperium/http) to permit the latest version.
- [Release notes](https://github.com/hyperium/http/releases)
- [Changelog](https://github.com/hyperium/http/blob/master/CHANGELOG.md)
- [Commits](https://github.com/hyperium/http/compare/v0.2.0...v1.0.0)

---
updated-dependencies:
- dependency-name: http
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-11-16 08:04:57 +00:00
Alexey Galakhov a6cbed7bff
Merge pull request #385 from nickelc/deps/webpki-roots
Update `webpki-roots` to 0.25
2023-11-15 22:30:51 +01:00
Constantin Nickel 6c61d54ad2 Update `webpki-roots` to 0.25 2023-11-15 19:56:28 +01:00
Daniel Abramov 272d83c430 doc: clarify the meaning of config values 2023-10-29 14:09:32 +01:00
Daniel Abramov 3d9fd1e5cb
Merge pull request #383 from atouchet/cbar
Update Autobahn Test Suite page
2023-10-22 22:49:07 +02:00
Alex Touchet 614200a2fa
Update Autobahn Test Suite page 2023-10-16 19:36:40 -07:00
Daniel Abramov 8b3ecd3cc0 Update changelog 2023-09-23 14:10:53 +02:00
Daniel Abramov 722850d473 Fix incorrect metadata in Cargo.toml 2023-09-23 14:10:04 +02:00
Alexey Galakhov 219075edaa
Merge pull request #379 from snapview/CVE-2023-43669
Quick-and-dirty fix for CVE-2023-43669
2023-09-23 02:33:24 +02:00
Alexey Galakhov f0f1a06a50 Bump crate version 2023-09-23 02:16:16 +02:00
Alexey Galakhov 2e5029284b Add checking for header sanity
Co-authored-by: Daniel Abramov <inetcrack2@gmail.com>
2023-09-23 02:16:09 +02:00
Alexey Galakhov f916b332a9 Add `AttackAttempt` error variant 2023-09-22 17:48:31 +02:00
Uli Schlachter e4224ed85a Reduce use of byteorder crate
The byteorder dependency is only used in protocol::frame::frame. I
thought this dependency could easily be removed and set out to replace
the use of byteorder with equivalent std methods.

NetworkEndian is an alias for BigEndian. Converting a number like u32 to
bytes can be done via the std lib via .to_be_bytes(). The opposite
direction is from_by_bytes(). These simple things thus to not need
byteorder.

There is one place in the code where byteorder actually helps, thus this
dependency is not actually fully removed. ByteOrder::read_uint() allows
to read 1 to 8 bytes of data and returns the result as u64. Doing this
with the standard library basically requires re-implementing byteorder.
Thus, I did not do that.

Signed-off-by: Uli Schlachter <psychon@znc.in>
2023-08-18 10:35:33 +02:00
Daniel Abramov 53914c1180 Include examples so that `cargo publish` works 2023-07-22 15:36:37 +01:00
Daniel Abramov 5323559891 Bump version 2023-07-22 15:19:04 +01:00
Daniel Abramov 6e63b17b63 Update changelog 2023-07-22 15:18:14 +01:00
Daniel Abramov f2ed7aa826
Merge pull request #365 from snapview/dependabot/cargo/webpki-roots-0.24
Update webpki-roots requirement from 0.23 to 0.24
2023-07-22 16:10:14 +02:00
Daniel Abramov 8d8f0da204
Merge pull request #362 from alexheretic/config-asserts
Add assert panics for `WebSocketConfig`
2023-07-22 16:08:30 +02:00
Daniel Abramov dac07ea68b
Merge pull request #361 from alexheretic/docs++
Clarify `WebSocketConfig` docs
2023-07-22 16:07:39 +02:00
dependabot[bot] 40cd43c4f9
Update webpki-roots requirement from 0.23 to 0.24
Updates the requirements on [webpki-roots](https://github.com/rustls/webpki-roots) to permit the latest version.
- [Commits](https://github.com/rustls/webpki-roots/compare/v/0.23.1...v/0.24.0)

---
updated-dependencies:
- dependency-name: webpki-roots
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-07-07 08:51:20 +00:00
Alex Butler 8f73cf03ab update changelog 2023-06-17 23:50:18 +01:00
Alex Butler 9567cc73f3 Add panics docs 2023-06-17 23:46:04 +01:00
Alex Butler 7869f11b41 Add assert panics for WebSocketConfig 2023-06-17 23:36:46 +01:00
Alex Butler 2345e28158 Clarify WebSocketConfig docs 2023-06-17 23:30:12 +01:00
16 changed files with 523 additions and 75 deletions

View File

@ -43,6 +43,30 @@ jobs:
- name: Test
run: cargo test --release
test-msvr:
name: Test MSRV
runs-on: ubuntu-latest
strategy:
matrix:
rust:
- 1.60.0
steps:
- name: Checkout sources
uses: actions/checkout@v3
- name: Install toolchain
uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.rust }}
- name: Install dependencies
run: sudo apt-get install libssl-dev
- name: Check
run: cargo check
autobahn:
name: Autobahn tests
runs-on: ubuntu-latest

View File

@ -1,4 +1,15 @@
# Unreleased (0.20.0)
# 0.21.0
- Fix read-predominant auto pong responses not flushing when hitting WouldBlock errors.
- Improve `FrameHeader::format` write correctness.
- Update `rustls` to `0.22`.
- Update `webpki-roots` to `0.26`.
- Update `rustls-native-certs` to `0.7`.
- Update `http` to `1.0.0`.
# 0.20.1
- Fixes [CVE-2023-43669](https://github.com/snapview/tungstenite-rs/pull/379).
# 0.20.0
- Remove many implicit flushing behaviours. In general reading and writing messages will no
longer flush until calling `flush`. An exception is automatic responses (e.g. pongs)
which will continue to be written and flushed when reading and writing.
@ -14,6 +25,7 @@
Note: `WriteBufferFull` returns the message that could not be written as a `Message::Frame`.
- Add ability to buffer multiple writes before writing to the underlying stream, controlled by
`WebSocketConfig::write_buffer_size` (default 128 KiB). Improves batch message write performance.
- Panic on receiving invalid `WebSocketConfig`.
# 0.19.0

View File

@ -7,12 +7,12 @@ authors = ["Alexey Galakhov", "Daniel Abramov"]
license = "MIT OR Apache-2.0"
readme = "README.md"
homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.19.0"
documentation = "https://docs.rs/tungstenite/0.21.0"
repository = "https://github.com/snapview/tungstenite-rs"
version = "0.19.0"
version = "0.21.0"
edition = "2018"
rust-version = "1.51"
include = ["benches/**/*", "src/**/*", "LICENSE-*", "README.md", "CHANGELOG.md"]
rust-version = "1.60"
include = ["benches/**/*", "src/**/*", "examples/**/*", "LICENSE-*", "README.md", "CHANGELOG.md"]
[package.metadata.docs.rs]
all-features = true
@ -24,13 +24,13 @@ native-tls = ["native-tls-crate"]
native-tls-vendored = ["native-tls", "native-tls-crate/vendored"]
rustls-tls-native-roots = ["__rustls-tls", "rustls-native-certs"]
rustls-tls-webpki-roots = ["__rustls-tls", "webpki-roots"]
__rustls-tls = ["rustls"]
__rustls-tls = ["rustls", "rustls-pki-types"]
[dependencies]
data-encoding = { version = "2", optional = true }
byteorder = "1.3.2"
bytes = "1.0"
http = { version = "0.2", optional = true }
http = { version = "1.0", optional = true }
httparse = { version = "1.3.4", optional = true }
log = "0.4.8"
rand = "0.8.0"
@ -46,22 +46,26 @@ version = "0.2.3"
[dependencies.rustls]
optional = true
version = "0.21.0"
version = "0.22.0"
[dependencies.rustls-pki-types]
optional = true
version = "1.0"
[dependencies.rustls-native-certs]
optional = true
version = "0.6.0"
version = "0.7.0"
[dependencies.webpki-roots]
optional = true
version = "0.23"
version = "0.26"
[dev-dependencies]
criterion = "0.5.0"
env_logger = "0.10.0"
input_buffer = "0.5.0"
net2 = "0.2.37"
rand = "0.8.4"
socket2 = "0.5.5"
[[bench]]
name = "buffer"

View File

@ -77,7 +77,7 @@ There is no support for permessage-deflate at the moment, but the PRs are welcom
Testing
-------
Tungstenite is thoroughly tested and passes the [Autobahn Test Suite](https://crossbar.io/autobahn/) for
Tungstenite is thoroughly tested and passes the [Autobahn Test Suite](https://github.com/crossbario/autobahn-testsuite) for
WebSockets. It is also covered by internal unit tests as well as possible.
Contributing

View File

@ -1,12 +1,13 @@
//! Methods to connect to a WebSocket as a client.
use std::{
convert::TryFrom,
io::{Read, Write},
net::{SocketAddr, TcpStream, ToSocketAddrs},
result::Result as StdResult,
};
use http::{request::Parts, Uri};
use http::{request::Parts, HeaderName, Uri};
use log::*;
use url::Url;
@ -265,3 +266,73 @@ impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
Request::from_httparse(self)
}
}
/// Builder for a custom [`IntoClientRequest`] with options to add
/// custom additional headers and sub protocols.
///
/// # Example
///
/// ```rust no_run
/// # use crate::*;
/// use http::Uri;
/// use tungstenite::{connect, ClientRequestBuilder};
///
/// let uri: Uri = "ws://localhost:3012/socket".parse().unwrap();
/// let token = "my_jwt_token";
/// let builder = ClientRequestBuilder::new(uri)
/// .with_header("Authorization", format!("Bearer {token}"))
/// .with_sub_protocol("my_sub_protocol");
/// let socket = connect(builder).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct ClientRequestBuilder {
uri: Uri,
/// Additional [`Request`] handshake headers
additional_headers: Vec<(String, String)>,
/// Handsake subprotocols
subprotocols: Vec<String>,
}
impl ClientRequestBuilder {
/// Initializes an empty request builder
#[must_use]
pub const fn new(uri: Uri) -> Self {
Self { uri, additional_headers: Vec::new(), subprotocols: Vec::new() }
}
/// Adds (`key`, `value`) as an additional header to the handshake request
pub fn with_header<K, V>(mut self, key: K, value: V) -> Self
where
K: Into<String>,
V: Into<String>,
{
self.additional_headers.push((key.into(), value.into()));
self
}
/// Adds `protocol` to the handshake request subprotocols (`Sec-WebSocket-Protocol`)
pub fn with_sub_protocol<P>(mut self, protocol: P) -> Self
where
P: Into<String>,
{
self.subprotocols.push(protocol.into());
self
}
}
impl IntoClientRequest for ClientRequestBuilder {
fn into_client_request(self) -> Result<Request> {
let mut request = self.uri.into_client_request()?;
let headers = request.headers_mut();
for (k, v) in self.additional_headers {
let key = HeaderName::try_from(k)?;
let value = v.parse()?;
headers.append(key, value);
}
if !self.subprotocols.is_empty() {
let protocols = self.subprotocols.join(", ").parse()?;
headers.append("Sec-WebSocket-Protocol", protocols);
}
Ok(request)
}
}

View File

@ -59,6 +59,9 @@ pub enum Error {
/// UTF coding error.
#[error("UTF-8 encoding error")]
Utf8,
/// Attack attempt detected.
#[error("Attack attempt detected")]
AttackAttempt,
/// Invalid URL.
#[error("URL error: {0}")]
Url(#[from] UrlError),

View File

@ -20,7 +20,7 @@ pub struct HandshakeMachine<Stream> {
impl<Stream> HandshakeMachine<Stream> {
/// Start reading data from the peer.
pub fn start_read(stream: Stream) -> Self {
HandshakeMachine { stream, state: HandshakeState::Reading(ReadBuffer::new()) }
Self { stream, state: HandshakeState::Reading(ReadBuffer::new(), AttackCheck::new()) }
}
/// Start writing data to the peer.
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
@ -41,25 +41,31 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> {
trace!("Doing handshake round.");
match self.state {
HandshakeState::Reading(mut buf) => {
HandshakeState::Reading(mut buf, mut attack_check) => {
let read = buf.read_from(&mut self.stream).no_block()?;
match read {
Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)),
Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {
result: obj,
stream: self.stream,
tail: buf.into_vec(),
Some(count) => {
attack_check.check_incoming_packet_size(count)?;
// TODO: this is slow for big headers with too many small packets.
// The parser has to be reworked in order to work on streams instead
// of buffers.
Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {
result: obj,
stream: self.stream,
tail: buf.into_vec(),
})
} else {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Reading(buf, attack_check),
..self
})
})
} else {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
})
}),
}
None => Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf),
state: HandshakeState::Reading(buf, attack_check),
..self
})),
}
@ -119,7 +125,54 @@ pub trait TryParse: Sized {
#[derive(Debug)]
enum HandshakeState {
/// Reading data from the peer.
Reading(ReadBuffer),
Reading(ReadBuffer, AttackCheck),
/// Sending data to the peer.
Writing(Cursor<Vec<u8>>),
}
/// Attack mitigation. Contains counters needed to prevent DoS attacks
/// and reject valid but useless headers.
#[derive(Debug)]
pub(crate) struct AttackCheck {
/// Number of HTTP header successful reads (TCP packets).
number_of_packets: usize,
/// Total number of bytes in HTTP header.
number_of_bytes: usize,
}
impl AttackCheck {
/// Initialize attack checking for incoming buffer.
fn new() -> Self {
Self { number_of_packets: 0, number_of_bytes: 0 }
}
/// Check the size of an incoming packet. To be called immediately after `read()`
/// passing its returned bytes count as `size`.
fn check_incoming_packet_size(&mut self, size: usize) -> Result<()> {
self.number_of_packets += 1;
self.number_of_bytes += size;
// TODO: these values are hardcoded. Instead of making them configurable,
// rework the way HTTP header is parsed to remove this check at all.
const MAX_BYTES: usize = 65536;
const MAX_PACKETS: usize = 512;
const MIN_PACKET_SIZE: usize = 128;
const MIN_PACKET_CHECK_THRESHOLD: usize = 64;
if self.number_of_bytes > MAX_BYTES {
return Err(Error::AttackAttempt);
}
if self.number_of_packets > MAX_PACKETS {
return Err(Error::AttackAttempt);
}
if self.number_of_packets > MIN_PACKET_CHECK_THRESHOLD
&& self.number_of_packets * MIN_PACKET_SIZE > self.number_of_bytes
{
return Err(Error::AttackAttempt);
}
Ok(())
}
}

View File

@ -39,7 +39,7 @@ pub use crate::{
#[cfg(feature = "handshake")]
pub use crate::{
client::{client, connect},
client::{client, connect, ClientRequestBuilder},
handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError},
server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config},
};

View File

@ -1,4 +1,4 @@
use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt};
use byteorder::{NetworkEndian, ReadBytesExt};
use log::*;
use std::{
borrow::Cow,
@ -108,8 +108,12 @@ impl FrameHeader {
output.write_all(&[one, two])?;
match lenfmt {
LengthFormat::U8(_) => (),
LengthFormat::U16 => output.write_u16::<NetworkEndian>(length as u16)?,
LengthFormat::U64 => output.write_u64::<NetworkEndian>(length)?,
LengthFormat::U16 => {
output.write_all(&(length as u16).to_be_bytes())?;
}
LengthFormat::U64 => {
output.write_all(&length.to_be_bytes())?;
}
}
if let Some(ref mask) = self.mask {
@ -295,7 +299,7 @@ impl Frame {
1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
_ => {
let mut data = self.payload;
let code = NetworkEndian::read_u16(&data[0..2]).into();
let code = u16::from_be_bytes([data[0], data[1]]).into();
data.drain(0..2);
let text = String::from_utf8(data)?;
Ok(Some(CloseFrame { code, reason: text.into() }))
@ -340,7 +344,7 @@ impl Frame {
pub fn close(msg: Option<CloseFrame>) -> Frame {
let payload = if let Some(CloseFrame { code, reason }) = msg {
let mut p = Vec::with_capacity(reason.as_bytes().len() + 2);
p.write_u16::<NetworkEndian>(code.into()).unwrap(); // can't fail
p.extend(u16::from(code).to_be_bytes());
p.extend_from_slice(reason.as_bytes());
p
} else {
@ -366,6 +370,8 @@ impl Frame {
impl fmt::Display for Frame {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use std::fmt::Write;
write!(
f,
"
@ -385,7 +391,10 @@ payload: 0x{}
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(),
self.payload.len(),
self.payload.iter().map(|byte| format!("{:02x}", byte)).collect::<String>()
self.payload.iter().fold(String::new(), |mut output, byte| {
_ = write!(output, "{byte:02x}");
output
})
)
}
}

View File

@ -48,7 +48,7 @@ mod tests {
#[test]
fn test_apply_mask() {
let mask = [0x6d, 0xb6, 0xb2, 0x80];
let unmasked = vec![
let unmasked = [
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9,
0x12, 0x03,
];

View File

@ -13,13 +13,10 @@ use self::{
},
message::{IncompleteMessage, IncompleteMessageType},
};
use crate::{
error::{Error, ProtocolError, Result},
util::NonBlockingResult,
};
use crate::error::{Error, ProtocolError, Result};
use log::*;
use std::{
io::{ErrorKind as IoErrorKind, Read, Write},
io::{self, Read, Write},
mem::replace,
};
@ -42,19 +39,27 @@ pub struct WebSocketConfig {
/// to the underlying stream.
/// The default value is 128 KiB.
///
/// If set to `0` each message will be eagerly written to the underlying stream.
/// It is often more optimal to allow them to buffer a little, hence the default value.
///
/// Note: [`flush`](WebSocket::flush) will always fully write the buffer regardless.
pub write_buffer_size: usize,
/// The max size of the write buffer in bytes. Setting this can provide backpressure
/// in the case the write buffer is filling up due to write errors.
/// The default value is unlimited.
///
/// Note: Should always be set higher than [`write_buffer_size`](Self::write_buffer_size).
/// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size)
/// when writes to the underlying stream are failing. So the **write buffer can not
/// fill up if you are not observing write errors even if not flushing**.
///
/// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size)
/// and probably a little more depending on error handling strategy.
pub max_write_buffer_size: usize,
/// The maximum size of a message. `None` means no size limit. The default value is 64 MiB
/// The maximum size of an incoming message. `None` means no size limit. The default value is 64 MiB
/// which should be reasonably big for all normal use-cases but small enough to prevent
/// memory eating by a malicious user.
pub max_message_size: Option<usize>,
/// The maximum size of a single message frame. `None` means no size limit. The limit is for
/// The maximum size of a single incoming message frame. `None` means no size limit. The limit is for
/// frame payload NOT including the frame header. The default value is 16 MiB which should
/// be reasonably big for all normal use-cases but small enough to prevent memory eating
/// by a malicious user.
@ -81,6 +86,17 @@ impl Default for WebSocketConfig {
}
}
impl WebSocketConfig {
/// Panic if values are invalid.
pub(crate) fn assert_valid(&self) {
assert!(
self.max_write_buffer_size > self.write_buffer_size,
"WebSocketConfig::max_write_buffer_size must be greater than write_buffer_size, \
see WebSocketConfig docs`"
);
}
}
/// WebSocket input-output stream.
///
/// This is THE structure you want to create to be able to speak the WebSocket protocol.
@ -101,6 +117,9 @@ impl<Stream> WebSocket<Stream> {
/// Call this function if you're using Tungstenite as a part of a web framework
/// or together with an existing one. If you need an initial handshake, use
/// `connect()` or `accept()` functions of the crate to construct a websocket.
///
/// # Panics
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
pub fn from_raw_socket(stream: Stream, role: Role, config: Option<WebSocketConfig>) -> Self {
WebSocket { socket: stream, context: WebSocketContext::new(role, config) }
}
@ -110,6 +129,9 @@ impl<Stream> WebSocket<Stream> {
/// Call this function if you're using Tungstenite as a part of a web framework
/// or together with an existing one. If you need an initial handshake, use
/// `connect()` or `accept()` functions of the crate to construct a websocket.
///
/// # Panics
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
pub fn from_partially_read(
stream: Stream,
part: Vec<u8>,
@ -132,6 +154,9 @@ impl<Stream> WebSocket<Stream> {
}
/// Change the configuration.
///
/// # Panics
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
self.context.set_config(set_func)
}
@ -285,22 +310,32 @@ pub struct WebSocketContext {
incomplete: Option<IncompleteMessage>,
/// Send in addition to regular messages E.g. "pong" or "close".
additional_send: Option<Frame>,
/// True indicates there is an additional message (like a pong)
/// that failed to flush previously and we should try again.
unflushed_additional: bool,
/// The configuration for the websocket session.
config: WebSocketConfig,
}
impl WebSocketContext {
/// Create a WebSocket context that manages a post-handshake stream.
///
/// # Panics
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
pub fn new(role: Role, config: Option<WebSocketConfig>) -> Self {
Self::_new(role, FrameCodec::new(), config.unwrap_or_default())
}
/// Create a WebSocket context that manages an post-handshake stream.
///
/// # Panics
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
Self::_new(role, FrameCodec::from_partially_read(part), config.unwrap_or_default())
}
fn _new(role: Role, mut frame: FrameCodec, config: WebSocketConfig) -> Self {
config.assert_valid();
frame.set_max_out_buffer_len(config.max_write_buffer_size);
frame.set_out_buffer_write_len(config.write_buffer_size);
Self {
@ -309,13 +344,18 @@ impl WebSocketContext {
state: WebSocketState::Active,
incomplete: None,
additional_send: None,
unflushed_additional: false,
config,
}
}
/// Change the configuration.
///
/// # Panics
/// Panics if config is invalid e.g. `max_write_buffer_size <= write_buffer_size`.
pub fn set_config(&mut self, set_func: impl FnOnce(&mut WebSocketConfig)) {
set_func(&mut self.config);
self.config.assert_valid();
self.frame.set_max_out_buffer_len(self.config.max_write_buffer_size);
self.frame.set_out_buffer_write_len(self.config.write_buffer_size);
}
@ -352,10 +392,16 @@ impl WebSocketContext {
self.state.check_not_terminated()?;
loop {
if self.additional_send.is_some() {
if self.additional_send.is_some() || self.unflushed_additional {
// Since we may get ping or close, we need to reply to the messages even during read.
// Thus we flush but ignore its blocking.
self.flush(stream).no_block()?;
match self.flush(stream) {
Ok(_) => {}
Err(Error::Io(err)) if err.kind() == io::ErrorKind::WouldBlock => {
// If blocked continue reading, but try again later
self.unflushed_additional = true;
}
Err(err) => return Err(err),
}
} else if self.role == Role::Server && !self.state.can_read() {
self.state = WebSocketState::Terminated;
return Err(Error::ConnectionClosed);
@ -423,7 +469,9 @@ impl WebSocketContext {
{
self._write(stream, None)?;
self.frame.write_out_buffer(stream)?;
Ok(stream.flush()?)
stream.flush()?;
self.unflushed_additional = false;
Ok(())
}
/// Writes any data in the out_buffer, `additional_send` and given `data`.
@ -456,7 +504,7 @@ impl WebSocketContext {
Ok(_) => true,
}
} else {
false
self.unflushed_additional
};
// If we're closing and there is nothing to send anymore, we should close the connection.
@ -735,7 +783,7 @@ impl<T> CheckConnectionReset for Result<T> {
fn check_connection_reset(self, state: WebSocketState) -> Self {
match self {
Err(Error::Io(io_error)) => Err({
if !state.can_read() && io_error.kind() == IoErrorKind::ConnectionReset {
if !state.can_read() && io_error.kind() == io::ErrorKind::ConnectionReset {
Error::ConnectionClosed
} else {
Error::Io(io_error)

View File

@ -70,7 +70,8 @@ mod encryption {
#[cfg(feature = "__rustls-tls")]
pub mod rustls {
use rustls::{ClientConfig, ClientConnection, RootCertStore, ServerName, StreamOwned};
use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
use rustls_pki_types::ServerName;
use std::{
convert::TryFrom,
@ -105,36 +106,26 @@ mod encryption {
#[cfg(feature = "rustls-tls-native-roots")]
{
let native_certs = rustls_native_certs::load_native_certs()?;
let der_certs: Vec<Vec<u8>> =
native_certs.into_iter().map(|cert| cert.0).collect();
let total_number = der_certs.len();
let total_number = native_certs.len();
let (number_added, number_ignored) =
root_store.add_parsable_certificates(&der_certs);
root_store.add_parsable_certificates(native_certs);
log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
}
#[cfg(feature = "rustls-tls-webpki-roots")]
{
root_store.add_server_trust_anchors(
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
})
);
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
}
Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth(),
)
}
};
let domain =
ServerName::try_from(domain).map_err(|_| TlsError::InvalidDnsName)?;
let domain = ServerName::try_from(domain)
.map_err(|_| TlsError::InvalidDnsName)?
.to_owned();
let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?;
let stream = StreamOwned::new(client, socket);

130
tests/auto_pong_flush.rs Normal file
View File

@ -0,0 +1,130 @@
use std::{
io::{self, Cursor, Read, Write},
mem,
};
use tungstenite::{
protocol::frame::{
coding::{Control, OpCode},
Frame, FrameHeader,
},
Message, WebSocket,
};
const NUMBER_OF_FLUSHES_TO_GET_IT_TO_WORK: usize = 3;
/// `Read`/`Write` mock.
/// * Reads a single ping, then returns `WouldBlock` forever after.
/// * Writes work fine.
/// * Flush `WouldBlock` twice then works on the 3rd attempt.
#[derive(Debug, Default)]
struct MockWrite {
/// Data written, but not flushed.
written_data: Vec<u8>,
/// The latest successfully flushed data.
flushed_data: Vec<u8>,
write_calls: usize,
flush_calls: usize,
read_calls: usize,
}
impl Read for MockWrite {
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
self.read_calls += 1;
if self.read_calls == 1 {
let ping = Frame::ping(vec![]);
let len = ping.len();
ping.format(&mut buf).expect("format failed");
Ok(len)
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, "nothing else to read"))
}
}
}
impl Write for MockWrite {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.write_calls += 1;
self.written_data.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.flush_calls += 1;
if self.flush_calls % NUMBER_OF_FLUSHES_TO_GET_IT_TO_WORK == 0 {
mem::swap(&mut self.written_data, &mut self.flushed_data);
self.written_data.clear();
eprintln!("flush success");
Ok(())
} else {
eprintln!("flush would block");
Err(io::Error::new(io::ErrorKind::WouldBlock, "try again"))
}
}
}
/// Test for auto pong write & flushing behaviour.
///
/// In read-only/read-predominant usage auto pong responses should be written and flushed
/// even if WouldBlock errors are encountered.
#[test]
fn read_usage_auto_pong_flush() {
let mut ws =
WebSocket::from_raw_socket(MockWrite::default(), tungstenite::protocol::Role::Client, None);
// Receiving a ping should auto scheduled a pong on next read or write (but not written yet).
let msg = ws.read().unwrap();
assert!(matches!(msg, Message::Ping(_)), "Unexpected msg {:?}", msg);
assert_eq!(ws.get_ref().read_calls, 1);
assert!(ws.get_ref().written_data.is_empty(), "Unexpected {:?}", ws.get_ref());
assert!(ws.get_ref().flushed_data.is_empty(), "Unexpected {:?}", ws.get_ref());
// Next read fails as there is nothing else to read.
// This read call should have tried to write & flush a pong response, with the flush WouldBlock-ing
let next = ws.read().unwrap_err();
assert!(
matches!(next, tungstenite::Error::Io(ref err) if err.kind() == io::ErrorKind::WouldBlock),
"Unexpected read err {:?}",
next
);
assert_eq!(ws.get_ref().read_calls, 2);
assert!(!ws.get_ref().written_data.is_empty(), "Should have written a pong frame");
assert_eq!(ws.get_ref().write_calls, 1);
let pong_header =
FrameHeader::parse(&mut Cursor::new(&ws.get_ref().written_data)).unwrap().unwrap().0;
assert_eq!(pong_header.opcode, OpCode::Control(Control::Pong));
let written_data = ws.get_ref().written_data.clone();
assert_eq!(ws.get_ref().flush_calls, 1);
assert!(ws.get_ref().flushed_data.is_empty(), "Unexpected {:?}", ws.get_ref());
// Next read fails as before.
// This read call should try to flush the pong again, which again WouldBlock
let next = ws.read().unwrap_err();
assert!(
matches!(next, tungstenite::Error::Io(ref err) if err.kind() == io::ErrorKind::WouldBlock),
"Unexpected read err {:?}",
next
);
assert_eq!(ws.get_ref().read_calls, 3);
assert_eq!(ws.get_ref().write_calls, 1);
assert_eq!(ws.get_ref().flush_calls, 2);
assert!(ws.get_ref().flushed_data.is_empty(), "Unexpected {:?}", ws.get_ref());
// Next read fails as before.
// This read call should try to flush the pong again, 3rd flush attempt is the charm
let next = ws.read().unwrap_err();
assert!(
matches!(next, tungstenite::Error::Io(ref err) if err.kind() == io::ErrorKind::WouldBlock),
"Unexpected read err {:?}",
next
);
assert_eq!(ws.get_ref().read_calls, 4);
assert_eq!(ws.get_ref().write_calls, 1);
assert_eq!(ws.get_ref().flush_calls, 3);
assert!(ws.get_ref().flushed_data == written_data, "Unexpected {:?}", ws.get_ref());
// On following read calls no additional writes or flushes are necessary
ws.read().unwrap_err();
assert_eq!(ws.get_ref().read_calls, 5);
assert_eq!(ws.get_ref().write_calls, 1);
assert_eq!(ws.get_ref().flush_calls, 3);
}

96
tests/client_headers.rs Normal file
View File

@ -0,0 +1,96 @@
#![cfg(feature = "handshake")]
use http::Uri;
use std::{
net::TcpListener,
process::exit,
thread::{sleep, spawn},
time::Duration,
};
use tungstenite::{
accept_hdr, connect,
handshake::server::{Request, Response},
ClientRequestBuilder, Error, Message,
};
/// Test for write buffering and flushing behaviour.
#[test]
fn test_headers() {
env_logger::init();
let uri: Uri = "ws://127.0.0.1:3013/socket".parse().unwrap();
let token = "my_jwt_token";
let full_token = format!("Bearer {token}");
let sub_protocol = "my_sub_protocol";
let builder = ClientRequestBuilder::new(uri)
.with_header("Authorization", full_token.to_owned())
.with_sub_protocol(sub_protocol.to_owned());
spawn(|| {
sleep(Duration::from_secs(5));
println!("Unit test executed too long, perhaps stuck on WOULDBLOCK...");
exit(1);
});
let server = TcpListener::bind("127.0.0.1:3013").unwrap();
let client_thread = spawn(move || {
let (mut client, _) = connect(builder).unwrap();
client.send(Message::Text("Hello WebSocket".into())).unwrap();
let message = client.read().unwrap(); // receive close from server
assert!(message.is_close());
let err = client.read().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
});
let callback = |req: &Request, mut response: Response| {
println!("Received a new ws handshake");
println!("The request's path is: {}", req.uri().path());
println!("The request's headers are:");
let authorization_header: String = "authorization".to_ascii_lowercase();
let web_socket_proto: String = "sec-websocket-protocol".to_ascii_lowercase();
for (ref header, value) in req.headers() {
println!("* {}: {}", header, value.to_str().unwrap());
if header.to_string() == authorization_header {
println!("Matching authorization header");
assert_eq!(header.to_string(), authorization_header);
assert_eq!(value.to_str().unwrap(), full_token);
} else if header.to_string() == web_socket_proto {
println!("Matching sec-websocket-protocol header");
assert_eq!(header.to_string(), web_socket_proto);
assert_eq!(value.to_str().unwrap(), sub_protocol);
// the server needs to respond with the same sub-protocol
response
.headers_mut()
.append("sec-websocket-protocol", sub_protocol.parse().unwrap());
}
}
Ok(response)
};
let client_handler = server.incoming().next().unwrap();
let mut client_handler = accept_hdr(client_handler.unwrap(), callback).unwrap();
client_handler.close(None).unwrap(); // send close to client
// This read should succeed even though we already initiated a close
let message = client_handler.read().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket");
assert!(client_handler.read().unwrap().is_close()); // receive acknowledgement
let err = client_handler.read().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
drop(client_handler);
client_thread.join().unwrap();
}

View File

@ -10,7 +10,7 @@ use std::{
time::Duration,
};
use net2::TcpStreamExt;
use socket2::Socket;
use tungstenite::{accept, connect, stream::MaybeTlsStream, Error, Message, WebSocket};
use url::Url;
@ -19,7 +19,7 @@ type Sock = WebSocket<MaybeTlsStream<TcpStream>>;
fn do_test<CT, ST>(port: u16, client_task: CT, server_task: ST)
where
CT: FnOnce(Sock) + Send + 'static,
ST: FnOnce(WebSocket<TcpStream>),
ST: FnOnce(WebSocket<Socket>),
{
env_logger::try_init().ok();
@ -40,7 +40,7 @@ where
});
let client_handler = server.incoming().next().unwrap();
let client_handler = accept(client_handler.unwrap()).unwrap();
let client_handler = accept(Socket::from(client_handler.unwrap())).unwrap();
server_task(client_handler);

View File

@ -1,5 +1,7 @@
#![cfg(feature = "handshake")]
use std::net::TcpListener;
use std::thread::spawn;
use std::thread::{sleep, spawn};
use std::time::Duration;
use tungstenite::error::{Error, ProtocolError, SubProtocolError};
use tungstenite::handshake::client::generate_key;
use tungstenite::handshake::server::{Request, Response};
@ -35,7 +37,6 @@ fn server_thread(port: u16, server_subprotocols: Option<Vec<String>>) {
spawn(move || {
let server = TcpListener::bind(("127.0.0.1", port))
.expect("Can't listen, is this port already in use?");
let client_handler = server.incoming().next().unwrap();
let callback = |_request: &Request, mut response: Response| {
if let Some(subprotocols) = server_subprotocols {
@ -45,13 +46,16 @@ fn server_thread(port: u16, server_subprotocols: Option<Vec<String>>) {
Ok(response)
};
let _client_handler = accept_hdr(client_handler.unwrap(), callback).unwrap();
let client_handler = server.incoming().next().unwrap();
let mut client_handler = accept_hdr(client_handler.unwrap(), callback).unwrap();
client_handler.close(None).unwrap();
});
}
#[test]
fn test_server_send_no_subprotocol() {
server_thread(3012, None);
sleep(Duration::from_secs(1));
let err =
connect(create_http_request("ws://127.0.0.1:3012", Some(vec!["my-sub-protocol".into()])))
@ -68,6 +72,7 @@ fn test_server_send_no_subprotocol() {
#[test]
fn test_server_sent_subprotocol_none_requested() {
server_thread(3013, Some(vec!["my-sub-protocol".to_string()]));
sleep(Duration::from_secs(1));
let err = connect(create_http_request("ws://127.0.0.1:3013", None)).unwrap_err();
@ -82,6 +87,7 @@ fn test_server_sent_subprotocol_none_requested() {
#[test]
fn test_invalid_subprotocol() {
server_thread(3014, Some(vec!["invalid-sub-protocol".to_string()]));
sleep(Duration::from_secs(1));
let err = connect(create_http_request(
"ws://127.0.0.1:3014",
@ -100,7 +106,7 @@ fn test_invalid_subprotocol() {
#[test]
fn test_request_multiple_subprotocols() {
server_thread(3015, Some(vec!["my-sub-protocol".to_string()]));
sleep(Duration::from_secs(1));
let (_, response) = connect(create_http_request(
"ws://127.0.0.1:3015",
Some(vec![
@ -120,6 +126,7 @@ fn test_request_multiple_subprotocols() {
#[test]
fn test_request_single_subprotocol() {
server_thread(3016, Some(vec!["my-sub-protocol".to_string()]));
sleep(Duration::from_secs(1));
let (_, response) = connect(create_http_request(
"ws://127.0.0.1:3016",