Merge pull request #379 from snapview/CVE-2023-43669
Quick-and-dirty fix for CVE-2023-43669
This commit is contained in:
commit
219075edaa
|
@ -9,7 +9,7 @@ readme = "README.md"
|
|||
homepage = "https://github.com/snapview/tungstenite-rs"
|
||||
documentation = "https://docs.rs/tungstenite/0.20.0"
|
||||
repository = "https://github.com/snapview/tungstenite-rs"
|
||||
version = "0.20.0"
|
||||
version = "0.20.1"
|
||||
edition = "2018"
|
||||
rust-version = "1.51"
|
||||
include = ["benches/**/*", "src/**/*", "examples/**/*", "LICENSE-*", "README.md", "CHANGELOG.md"]
|
||||
|
|
|
@ -59,6 +59,9 @@ pub enum Error {
|
|||
/// UTF coding error.
|
||||
#[error("UTF-8 encoding error")]
|
||||
Utf8,
|
||||
/// Attack attempt detected.
|
||||
#[error("Attack attempt detected")]
|
||||
AttackAttempt,
|
||||
/// Invalid URL.
|
||||
#[error("URL error: {0}")]
|
||||
Url(#[from] UrlError),
|
||||
|
|
|
@ -20,7 +20,7 @@ pub struct HandshakeMachine<Stream> {
|
|||
impl<Stream> HandshakeMachine<Stream> {
|
||||
/// Start reading data from the peer.
|
||||
pub fn start_read(stream: Stream) -> Self {
|
||||
HandshakeMachine { stream, state: HandshakeState::Reading(ReadBuffer::new()) }
|
||||
Self { stream, state: HandshakeState::Reading(ReadBuffer::new(), AttackCheck::new()) }
|
||||
}
|
||||
/// Start writing data to the peer.
|
||||
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
|
||||
|
@ -41,25 +41,31 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
|
|||
pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> {
|
||||
trace!("Doing handshake round.");
|
||||
match self.state {
|
||||
HandshakeState::Reading(mut buf) => {
|
||||
HandshakeState::Reading(mut buf, mut attack_check) => {
|
||||
let read = buf.read_from(&mut self.stream).no_block()?;
|
||||
match read {
|
||||
Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)),
|
||||
Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
|
||||
buf.advance(size);
|
||||
RoundResult::StageFinished(StageResult::DoneReading {
|
||||
result: obj,
|
||||
stream: self.stream,
|
||||
tail: buf.into_vec(),
|
||||
Some(count) => {
|
||||
attack_check.check_incoming_packet_size(count)?;
|
||||
// TODO: this is slow for big headers with too many small packets.
|
||||
// The parser has to be reworked in order to work on streams instead
|
||||
// of buffers.
|
||||
Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
|
||||
buf.advance(size);
|
||||
RoundResult::StageFinished(StageResult::DoneReading {
|
||||
result: obj,
|
||||
stream: self.stream,
|
||||
tail: buf.into_vec(),
|
||||
})
|
||||
} else {
|
||||
RoundResult::Incomplete(HandshakeMachine {
|
||||
state: HandshakeState::Reading(buf, attack_check),
|
||||
..self
|
||||
})
|
||||
})
|
||||
} else {
|
||||
RoundResult::Incomplete(HandshakeMachine {
|
||||
state: HandshakeState::Reading(buf),
|
||||
..self
|
||||
})
|
||||
}),
|
||||
}
|
||||
None => Ok(RoundResult::WouldBlock(HandshakeMachine {
|
||||
state: HandshakeState::Reading(buf),
|
||||
state: HandshakeState::Reading(buf, attack_check),
|
||||
..self
|
||||
})),
|
||||
}
|
||||
|
@ -119,7 +125,54 @@ pub trait TryParse: Sized {
|
|||
#[derive(Debug)]
|
||||
enum HandshakeState {
|
||||
/// Reading data from the peer.
|
||||
Reading(ReadBuffer),
|
||||
Reading(ReadBuffer, AttackCheck),
|
||||
/// Sending data to the peer.
|
||||
Writing(Cursor<Vec<u8>>),
|
||||
}
|
||||
|
||||
/// Attack mitigation. Contains counters needed to prevent DoS attacks
|
||||
/// and reject valid but useless headers.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct AttackCheck {
|
||||
/// Number of HTTP header successful reads (TCP packets).
|
||||
number_of_packets: usize,
|
||||
/// Total number of bytes in HTTP header.
|
||||
number_of_bytes: usize,
|
||||
}
|
||||
|
||||
impl AttackCheck {
|
||||
/// Initialize attack checking for incoming buffer.
|
||||
fn new() -> Self {
|
||||
Self { number_of_packets: 0, number_of_bytes: 0 }
|
||||
}
|
||||
|
||||
/// Check the size of an incoming packet. To be called immediately after `read()`
|
||||
/// passing its returned bytes count as `size`.
|
||||
fn check_incoming_packet_size(&mut self, size: usize) -> Result<()> {
|
||||
self.number_of_packets += 1;
|
||||
self.number_of_bytes += size;
|
||||
|
||||
// TODO: these values are hardcoded. Instead of making them configurable,
|
||||
// rework the way HTTP header is parsed to remove this check at all.
|
||||
const MAX_BYTES: usize = 65536;
|
||||
const MAX_PACKETS: usize = 512;
|
||||
const MIN_PACKET_SIZE: usize = 128;
|
||||
const MIN_PACKET_CHECK_THRESHOLD: usize = 64;
|
||||
|
||||
if self.number_of_bytes > MAX_BYTES {
|
||||
return Err(Error::AttackAttempt);
|
||||
}
|
||||
|
||||
if self.number_of_packets > MAX_PACKETS {
|
||||
return Err(Error::AttackAttempt);
|
||||
}
|
||||
|
||||
if self.number_of_packets > MIN_PACKET_CHECK_THRESHOLD {
|
||||
if self.number_of_packets * MIN_PACKET_SIZE > self.number_of_bytes {
|
||||
return Err(Error::AttackAttempt);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue