Edition 2018, formatting, clippy fixes

This commit is contained in:
Artem Vorotnikov 2019-08-26 20:00:41 +03:00
parent b40256eedd
commit cbf80ecc76
No known key found for this signature in database
GPG Key ID: E0148C3F2FBB7A20
25 changed files with 576 additions and 551 deletions

View File

@ -10,6 +10,7 @@ homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.9.1"
repository = "https://github.com/snapview/tungstenite-rs"
version = "0.9.1"
edition = "2018"
[features]
default = ["tls"]

View File

@ -1,18 +1,12 @@
#[macro_use] extern crate log;
extern crate env_logger;
extern crate tungstenite;
extern crate url;
use log::*;
use url::Url;
use tungstenite::{connect, Error, Result, Message};
use tungstenite::{connect, Error, Message, Result};
const AGENT: &'static str = "Tungstenite";
fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect(
Url::parse("ws://localhost:9001/getCaseCount").unwrap(),
)?;
let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?;
let msg = socket.read_message()?;
socket.close(None)?;
Ok(msg.into_text()?.parse::<u32>().unwrap())
@ -20,7 +14,11 @@ fn get_case_count() -> Result<u32> {
fn update_reports() -> Result<()> {
let (mut socket, _) = connect(
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(),
Url::parse(&format!(
"ws://localhost:9001/updateReports?agent={}",
AGENT
))
.unwrap(),
)?;
socket.close(None)?;
Ok(())
@ -28,19 +26,18 @@ fn update_reports() -> Result<()> {
fn run_test(case: u32) -> Result<()> {
info!("Running test case {}", case);
let case_url = Url::parse(
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)
).unwrap();
let case_url = Url::parse(&format!(
"ws://localhost:9001/runCase?case={}&agent={}",
case, AGENT
))
.unwrap();
let (mut socket, _) = connect(case_url)?;
loop {
match socket.read_message()? {
msg @ Message::Text(_) |
msg @ Message::Binary(_) => {
msg @ Message::Text(_) | msg @ Message::Binary(_) => {
socket.write_message(msg)?;
}
Message::Ping(_) |
Message::Pong(_) |
Message::Close(_) => {}
Message::Ping(_) | Message::Pong(_) | Message::Close(_) => {}
}
}
}
@ -53,12 +50,13 @@ fn main() {
for case in 1..(total + 1) {
if let Err(e) = run_test(case) {
match e {
Error::Protocol(_) => { }
err => { warn!("test: {}", err); }
Error::Protocol(_) => {}
err => {
warn!("test: {}", err);
}
}
}
}
update_reports().unwrap();
}

View File

@ -1,12 +1,9 @@
#[macro_use] extern crate log;
extern crate env_logger;
extern crate tungstenite;
use std::net::{TcpListener, TcpStream};
use std::thread::spawn;
use tungstenite::{accept, HandshakeError, Error, Result, Message};
use log::*;
use tungstenite::handshake::HandshakeRole;
use tungstenite::{accept, Error, HandshakeError, Message, Result};
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
match err {
@ -19,13 +16,10 @@ fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).map_err(must_not_block)?;
loop {
match socket.read_message()? {
msg @ Message::Text(_) |
msg @ Message::Binary(_) => {
msg @ Message::Text(_) | msg @ Message::Binary(_) => {
socket.write_message(msg)?;
}
Message::Ping(_) |
Message::Pong(_) |
Message::Close(_) => {}
Message::Ping(_) | Message::Pong(_) | Message::Close(_) => {}
}
}
}
@ -36,14 +30,12 @@ fn main() {
let server = TcpListener::bind("127.0.0.1:9002").unwrap();
for stream in server.incoming() {
spawn(move || {
match stream {
Ok(stream) => match handle_client(stream) {
Ok(_) => (),
Err(e) => warn!("Error in client: {}", e),
},
Err(e) => warn!("Error accepting stream: {}", e),
}
spawn(move || match stream {
Ok(stream) => match handle_client(stream) {
Ok(_) => (),
Err(e) => warn!("Error in client: {}", e),
},
Err(e) => warn!("Error accepting stream: {}", e),
});
}
}

View File

@ -1,10 +1,8 @@
extern crate tungstenite;
use std::thread::spawn;
use std::net::TcpListener;
use std::thread::spawn;
use tungstenite::accept_hdr;
use tungstenite::handshake::server::{Request, ErrorResponse};
use tungstenite::handshake::server::{ErrorResponse, Request};
use tungstenite::http::StatusCode;
fn main() {

View File

@ -1,15 +1,11 @@
extern crate tungstenite;
extern crate url;
extern crate env_logger;
use tungstenite::{connect, Message};
use url::Url;
use tungstenite::{Message, connect};
fn main() {
env_logger::init();
let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap())
.expect("Can't connect");
let (mut socket, response) =
connect(Url::parse("ws://localhost:3012/socket").unwrap()).expect("Can't connect");
println!("Connected to the server");
println!("Response HTTP code: {}", response.code);
@ -18,11 +14,12 @@ fn main() {
println!("* {}", header);
}
socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();
socket
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
loop {
let msg = socket.read_message().expect("Error reading message");
println!("Received: {}", msg);
}
// socket.close(None);
}

View File

@ -1,8 +1,5 @@
extern crate tungstenite;
extern crate env_logger;
use std::thread::spawn;
use std::net::TcpListener;
use std::thread::spawn;
use tungstenite::accept_hdr;
use tungstenite::handshake::server::Request;
@ -23,7 +20,10 @@ fn main() {
// Let's add an additional header to our response to the client.
let extra_headers = vec![
(String::from("MyCustomHeader"), String::from(":)")),
(String::from("SOME_TUNGSTENITE_HEADER"), String::from("header_value")),
(
String::from("SOME_TUNGSTENITE_HEADER"),
String::from("header_value"),
),
];
Ok(Some(extra_headers))
};

View File

@ -1,4 +1,3 @@
[package]
name = "tungstenite-fuzz"
version = "0.0.1"

View File

@ -1,36 +1,40 @@
//! Methods to connect to an WebSocket as a client.
use std::net::{TcpStream, SocketAddr, ToSocketAddrs};
use std::result::Result as StdResult;
use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult;
use log::*;
use url::Url;
use handshake::client::Response;
use protocol::WebSocketConfig;
use crate::handshake::client::Response;
use crate::protocol::WebSocketConfig;
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
mod encryption {
use std::net::TcpStream;
use native_tls::{TlsConnector, HandshakeError as TlsHandshakeError};
pub use native_tls::TlsStream;
use native_tls::{HandshakeError as TlsHandshakeError, TlsConnector};
use std::net::TcpStream;
pub use stream::Stream as StreamSwitcher;
pub use crate::stream::Stream as StreamSwitcher;
/// TCP stream switcher (plain/TLS).
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>;
use stream::Mode;
use error::Result;
use crate::error::Result;
use crate::stream::Mode;
pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
match mode {
Mode::Plain => Ok(StreamSwitcher::Plain(stream)),
Mode::Tls => {
let connector = TlsConnector::builder().build()?;
connector.connect(domain, stream)
connector
.connect(domain, stream)
.map_err(|e| match e {
TlsHandshakeError::Failure(f) => f.into(),
TlsHandshakeError::WouldBlock(_) => panic!("Bug: TLS handshake not blocked"),
TlsHandshakeError::WouldBlock(_) => {
panic!("Bug: TLS handshake not blocked")
}
})
.map(StreamSwitcher::Tls)
}
@ -38,12 +42,12 @@ mod encryption {
}
}
#[cfg(not(feature="tls"))]
#[cfg(not(feature = "tls"))]
mod encryption {
use std::net::TcpStream;
use stream::Mode;
use error::{Error, Result};
use stream::Mode;
/// TLS support is nod compiled in, this is just standard `TcpStream`.
pub type AutoStream = TcpStream;
@ -56,15 +60,14 @@ mod encryption {
}
}
pub use self::encryption::AutoStream;
use self::encryption::wrap_stream;
pub use self::encryption::AutoStream;
use protocol::WebSocket;
use handshake::HandshakeError;
use handshake::client::{ClientHandshake, Request};
use stream::{NoDelay, Mode};
use error::{Error, Result};
use crate::error::{Error, Result};
use crate::handshake::client::{ClientHandshake, Request};
use crate::handshake::HandshakeError;
use crate::protocol::WebSocket;
use crate::stream::{Mode, NoDelay};
/// Connect to the given WebSocket in blocking mode.
///
@ -83,13 +86,17 @@ use error::{Error, Result};
/// `connect` since it's the only function that uses native_tls.
pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
request: Req,
config: Option<WebSocketConfig>
config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> {
let request: Request = request.into();
let mode = url_mode(&request.url)?;
let host = request.url.host()
let host = request
.url
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = request.url.port_or_known_default()
let port = request
.url
.port_or_known_default()
.ok_or_else(|| Error::Url("No port number in the URL".into()))?;
let addrs;
let addr;
@ -109,11 +116,10 @@ pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
};
let mut stream = connect_to_some(addrs, &request.url, mode)?;
NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config)
.map_err(|e| match e {
HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
})
client_with_config(request, stream, config).map_err(|e| match e {
HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
})
}
/// Connect to the given WebSocket in blocking mode.
@ -128,19 +134,21 @@ pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req)
-> Result<(WebSocket<AutoStream>, Response)>
{
pub fn connect<'t, Req: Into<Request<'t>>>(
request: Req,
) -> Result<(WebSocket<AutoStream>, Response)> {
connect_with_config(request, None)
}
fn connect_to_some(addrs: &[SocketAddr], url: &Url, mode: Mode) -> Result<AutoStream> {
let domain = url.host_str().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let domain = url
.host_str()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs {
debug!("Trying to contact {} at {}...", url, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) {
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
return Ok(stream)
return Ok(stream);
}
}
}
@ -155,7 +163,7 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
match url.scheme() {
"ws" => Ok(Mode::Plain),
"wss" => Ok(Mode::Tls),
_ => Err(Error::Url("URL scheme not supported".into()))
_ => Err(Error::Url("URL scheme not supported".into())),
}
}
@ -182,8 +190,10 @@ where
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(request: Req, stream: Stream)
-> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
pub fn client<'t, Stream, Req>(
request: Req,
stream: Stream,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,

View File

@ -11,9 +11,9 @@ use std::string;
use httparse;
use protocol::Message;
use crate::protocol::Message;
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
pub mod tls {
//! TLS error wrapper module, feature-gated.
pub use native_tls::Error;
@ -41,7 +41,7 @@ pub enum Error {
AlreadyClosed,
/// Input-output error
Io(io::Error),
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
/// TLS error
Tls(tls::Error),
/// Buffer capacity exhausted
@ -64,7 +64,7 @@ impl fmt::Display for Error {
Error::ConnectionClosed => write!(f, "Connection closed normally"),
Error::AlreadyClosed => write!(f, "Trying to work with closed connection"),
Error::Io(ref err) => write!(f, "IO error: {}", err),
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
Error::Tls(ref err) => write!(f, "TLS error: {}", err),
Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg),
Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg),
@ -82,7 +82,7 @@ impl ErrorTrait for Error {
Error::ConnectionClosed => "A close handshake is performed",
Error::AlreadyClosed => "Trying to read or write after getting close notification",
Error::Io(ref err) => err.description(),
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
Error::Tls(ref err) => err.description(),
Error::Capacity(ref msg) => msg.borrow(),
Error::Protocol(ref msg) => msg.borrow(),
@ -112,7 +112,7 @@ impl From<string::FromUtf8Error> for Error {
}
}
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
impl From<tls::Error> for Error {
fn from(err: tls::Error) -> Self {
Error::Tls(err)

View File

@ -4,17 +4,15 @@ use std::borrow::Cow;
use std::io::{Read, Write};
use std::marker::PhantomData;
use base64;
use httparse::Status;
use httparse;
use rand;
use log::*;
use url::Url;
use error::{Error, Result};
use protocol::{WebSocket, WebSocketConfig, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::headers::{FromHttparse, Headers, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Client request.
#[derive(Debug)]
@ -80,26 +78,32 @@ impl<S: Read + Write> ClientHandshake<S> {
pub fn start(
stream: S,
request: Request,
config: Option<WebSocketConfig>
config: Option<WebSocketConfig>,
) -> MidHandshake<Self> {
let key = generate_key();
let machine = {
let mut req = Vec::new();
write!(req, "\
GET {path} HTTP/1.1\r\n\
Host: {host}\r\n\
Connection: upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n",
host = request.get_host(), path = request.get_path(), key = key).unwrap();
write!(
req,
"\
GET {path} HTTP/1.1\r\n\
Host: {host}\r\n\
Connection: upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n",
host = request.get_host(),
path = request.get_path(),
key = key
)
.unwrap();
if let Some(eh) = request.extra_headers {
for (k, v) in eh {
write!(req, "{}: {}\r\n", k, v).unwrap();
writeln!(req, "{}: {}\r", k, v).unwrap();
}
}
write!(req, "\r\n").unwrap();
writeln!(req, "\r").unwrap();
HandshakeMachine::start_write(stream, req)
};
@ -113,7 +117,10 @@ impl<S: Read + Write> ClientHandshake<S> {
};
trace!("Client handshake initiated.");
MidHandshake { role: client, machine }
MidHandshake {
role: client,
machine,
}
}
}
@ -121,22 +128,23 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
type IncomingData = Response;
type InternalStream = S;
type FinalResult = (WebSocket<S>, Response);
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>
{
fn stage_finished(
&mut self,
finish: StageResult<Self::IncomingData, Self::InternalStream>,
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
Ok(match finish {
StageResult::DoneWriting(stream) => {
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
}
StageResult::DoneReading { stream, result, tail, } => {
StageResult::DoneReading {
stream,
result,
tail,
} => {
self.verify_data.verify_response(&result)?;
debug!("Client handshake done.");
let websocket = WebSocket::from_partially_read(
stream,
tail,
Role::Client,
self.config.clone(),
);
let websocket =
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
ProcessingResult::Done((websocket, result))
}
})
@ -161,22 +169,37 @@ impl VerifyData {
// header field contains a value that is not an ASCII case-
// insensitive match for the value "websocket", the client MUST
// _Fail the WebSocket Connection_. (RFC 6455)
if !response.headers.header_is_ignore_case("Upgrade", "websocket") {
return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into()));
if !response
.headers
.header_is_ignore_case("Upgrade", "websocket")
{
return Err(Error::Protocol(
"No \"Upgrade: websocket\" in server reply".into(),
));
}
// 3. If the response lacks a |Connection| header field or the
// |Connection| header field doesn't contain a token that is an
// ASCII case-insensitive match for the value "Upgrade", the client
// MUST _Fail the WebSocket Connection_. (RFC 6455)
if !response.headers.header_is_ignore_case("Connection", "Upgrade") {
return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into()));
if !response
.headers
.header_is_ignore_case("Connection", "Upgrade")
{
return Err(Error::Protocol(
"No \"Connection: upgrade\" in server reply".into(),
));
}
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or
// the |Sec-WebSocket-Accept| contains a value other than the
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
// Connection_. (RFC 6455)
if !response.headers.header_is("Sec-WebSocket-Accept", &self.accept_key) {
return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into()));
if !response
.headers
.header_is("Sec-WebSocket-Accept", &self.accept_key)
{
return Err(Error::Protocol(
"Key mismatch in Sec-WebSocket-Accept".into(),
));
}
// 5. If the response includes a |Sec-WebSocket-Extensions| header
// field and this header field indicates the use of an extension
@ -219,7 +242,9 @@ impl TryParse for Response {
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
}
Ok(Response {
code: raw.code.expect("Bug: no HTTP response code"),
@ -238,8 +263,8 @@ fn generate_key() -> String {
#[cfg(test)]
mod tests {
use super::{Response, generate_key};
use super::super::machine::TryParse;
use super::{generate_key, Response};
#[test]
fn random_keys() {
@ -262,6 +287,9 @@ mod tests {
const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap();
assert_eq!(resp.code, 200);
assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..]));
assert_eq!(
resp.headers.find_first("Content-Type"),
Some(&b"text/html"[..])
);
}
}

View File

@ -1,13 +1,13 @@
//! HTTP Request and response header handling.
use std::str::from_utf8;
use std::slice;
use std::str::from_utf8;
use httparse;
use httparse::Status;
use error::Result;
use super::machine::TryParse;
use crate::error::Result;
/// Limit for the number of header lines.
pub const MAX_HEADERS: usize = 124;
@ -19,7 +19,6 @@ pub struct Headers {
}
impl Headers {
/// Get first header with the given name, if any.
pub fn find_first(&self, name: &str) -> Option<&[u8]> {
self.find(name).next()
@ -29,7 +28,7 @@ impl Headers {
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> {
HeadersIter {
name,
iter: self.data.iter()
iter: self.data.iter(),
}
}
@ -42,7 +41,8 @@ impl Headers {
/// Check if the given header has the given value (case-insensitive).
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool {
self.find_first(name).ok_or(())
self.find_first(name)
.ok_or(())
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ()))
.map(|val| val.eq_ignore_ascii_case(value))
.unwrap_or(false)
@ -52,7 +52,6 @@ impl Headers {
pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> {
self.data.iter()
}
}
/// The iterator over headers.
@ -67,14 +66,13 @@ impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> {
fn next(&mut self) -> Option<Self::Item> {
while let Some(&(ref name, ref value)) = self.iter.next() {
if name.eq_ignore_ascii_case(self.name) {
return Some(value)
return Some(value);
}
}
None
}
}
/// Trait to convert raw objects into HTTP parseables.
pub trait FromHttparse<T>: Sized {
/// Convert raw object into parsed HTTP headers.
@ -94,9 +92,10 @@ impl TryParse for Headers {
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
Ok(Headers {
data: raw.iter()
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
.collect(),
data: raw
.iter()
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
.collect(),
})
}
}
@ -104,13 +103,12 @@ impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
#[cfg(test)]
mod tests {
use super::Headers;
use super::super::machine::TryParse;
use super::Headers;
#[test]
fn headers() {
const DATA: &'static [u8] =
b"Host: foo.com\r\n\
const DATA: &'static [u8] = b"Host: foo.com\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
\r\n";
@ -126,8 +124,7 @@ mod tests {
#[test]
fn headers_iter() {
const DATA: &'static [u8] =
b"Host: foo.com\r\n\
const DATA: &'static [u8] = b"Host: foo.com\r\n\
Sec-WebSocket-Extensions: permessage-deflate\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\
@ -142,12 +139,10 @@ mod tests {
#[test]
fn headers_incomplete() {
const DATA: &'static [u8] =
b"Host: foo.com\r\n\
const DATA: &'static [u8] = b"Host: foo.com\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n";
let hdr = Headers::try_parse(DATA).unwrap();
assert!(hdr.is_none());
}
}

View File

@ -1,9 +1,10 @@
use std::io::{Cursor, Read, Write};
use bytes::Buf;
use log::*;
use std::io::{Cursor, Read, Write};
use crate::error::{Error, Result};
use crate::util::NonBlockingResult;
use input_buffer::{InputBuffer, MIN_READ};
use error::{Error, Result};
use util::NonBlockingResult;
/// A generic handshake state machine.
#[derive(Debug)]
@ -43,16 +44,16 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
trace!("Doing handshake round.");
match self.state {
HandshakeState::Reading(mut buf) => {
let read = buf.prepare_reserve(MIN_READ)
let read = buf
.prepare_reserve(MIN_READ)
.with_limit(usize::max_value()) // TODO limit size
.map_err(|_| Error::Capacity("Header too long".into()))?
.read_from(&mut self.stream).no_block()?;
.read_from(&mut self.stream)
.no_block()?;
match read {
Some(0) => {
Err(Error::Protocol("Handshake not finished".into()))
}
Some(_) => {
Ok(if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
Some(0) => Err(Error::Protocol("Handshake not finished".into())),
Some(_) => Ok(
if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {
result: obj,
@ -64,14 +65,12 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
state: HandshakeState::Reading(buf),
..self
})
})
}
None => {
Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
}))
}
},
),
None => Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
})),
}
}
HandshakeState::Writing(mut buf) => {
@ -113,7 +112,11 @@ pub enum RoundResult<Obj, Stream> {
#[derive(Debug)]
pub enum StageResult<Obj, Stream> {
/// Reading round finished.
DoneReading { result: Obj, stream: Stream, tail: Vec<u8> },
DoneReading {
result: Obj,
stream: Stream,
tail: Vec<u8>,
},
/// Writing round finished.
DoneWriting(Stream),
}

View File

@ -1,7 +1,7 @@
//! WebSocket handshake control.
pub mod headers;
pub mod client;
pub mod headers;
pub mod server;
mod machine;
@ -11,10 +11,10 @@ use std::fmt;
use std::io::{Read, Write};
use base64;
use sha1::{Sha1, Digest};
use sha1::{Digest, Sha1};
use error::Error;
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
use crate::error::Error;
/// A WebSocket handshake.
#[derive(Debug)]
@ -30,15 +30,16 @@ impl<Role: HandshakeRole> MidHandshake<Role> {
loop {
mach = match mach.single_round()? {
RoundResult::WouldBlock(m) => {
return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
return Err(HandshakeError::Interrupted(MidHandshake {
machine: m,
..self
}))
}
RoundResult::Incomplete(m) => m,
RoundResult::StageFinished(s) => {
match self.role.stage_finished(s)? {
ProcessingResult::Continue(m) => m,
ProcessingResult::Done(result) => return Ok(result),
}
}
RoundResult::StageFinished(s) => match self.role.stage_finished(s)? {
ProcessingResult::Continue(m) => m,
ProcessingResult::Done(result) => return Ok(result),
},
}
}
}
@ -94,8 +95,10 @@ pub trait HandshakeRole {
#[doc(hidden)]
type FinalResult;
#[doc(hidden)]
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
fn stage_finished(
&mut self,
finish: StageResult<Self::IncomingData, Self::InternalStream>,
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
}
/// Stage processing result.
@ -124,8 +127,9 @@ mod tests {
#[test]
fn key_conversion() {
// example from RFC 6455
assert_eq!(convert_key(b"dGhlIHNhbXBsZSBub25jZQ==").unwrap(),
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
assert_eq!(
convert_key(b"dGhlIHNhbXBsZSBub25jZQ==").unwrap(),
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
);
}
}

View File

@ -5,15 +5,15 @@ use std::io::{Read, Write};
use std::marker::PhantomData;
use std::result::Result as StdResult;
use httparse;
use httparse::Status;
use http::StatusCode;
use httparse::Status;
use log::*;
use error::{Error, Result};
use protocol::{WebSocket, WebSocketConfig, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::headers::{FromHttparse, Headers, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::protocol::{Role, WebSocket, WebSocketConfig};
/// Request from the client.
#[derive(Debug)]
@ -27,14 +27,16 @@ pub struct Request {
impl Request {
/// Reply to the response.
pub fn reply(&self, extra_headers: Option<Vec<(String, String)>>) -> Result<Vec<u8>> {
let key = self.headers.find_first("Sec-WebSocket-Key")
let key = self
.headers
.find_first("Sec-WebSocket-Key")
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
let mut reply = format!(
"\
HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Accept: {}\r\n",
HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Accept: {}\r\n",
convert_key(key)?
);
add_headers(&mut reply, extra_headers);
@ -45,13 +47,12 @@ impl Request {
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<ExtraHeaders>) {
if let Some(eh) = extra_headers {
for (k, v) in eh {
write!(reply, "{}: {}\r\n", k, v).unwrap();
writeln!(reply, "{}: {}\r", k, v).unwrap();
}
}
write!(reply, "\r\n").unwrap();
writeln!(reply, "\r").unwrap();
}
impl TryParse for Request {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
@ -69,11 +70,13 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
return Err(Error::Protocol("Method is not GET".into()));
}
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
return Err(Error::Protocol(
"HTTP version should be 1.1 or higher".into(),
));
}
Ok(Request {
path: raw.path.expect("Bug: no path in header").into(),
headers: Headers::from_httparse(raw.headers)?
headers: Headers::from_httparse(raw.headers)?,
})
}
}
@ -115,7 +118,10 @@ pub trait Callback: Sized {
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>;
}
impl<F> Callback for F where F: FnOnce(&Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
impl<F> Callback for F
where
F: FnOnce(&Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>,
{
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
self(request)
}
@ -160,7 +166,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
callback: Some(callback),
config,
error_code: None,
_marker: PhantomData
_marker: PhantomData,
},
}
}
@ -171,13 +177,18 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
type InternalStream = S;
type FinalResult = WebSocket<S>;
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>
{
fn stage_finished(
&mut self,
finish: StageResult<Self::IncomingData, Self::InternalStream>,
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
Ok(match finish {
StageResult::DoneReading { stream, result, tail } => {
StageResult::DoneReading {
stream,
result,
tail,
} => {
if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into()))
return Err(Error::Protocol("Junk after client request".into()));
}
let callback_result = if let Some(callback) = self.callback.take() {
@ -192,8 +203,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
}
Err(ErrorResponse { error_code, headers, body }) => {
self.error_code= Some(error_code.as_u16());
Err(ErrorResponse {
error_code,
headers,
body,
}) => {
self.error_code = Some(error_code.as_u16());
let mut response = format!(
"HTTP/1.1 {} {}\r\n",
error_code.as_str(),
@ -214,11 +229,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
return Err(Error::Http(err));
} else {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(
stream,
Role::Server,
self.config.clone(),
);
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
ProcessingResult::Done(websocket)
}
}
@ -228,9 +239,9 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
#[cfg(test)]
mod tests {
use super::Request;
use super::super::machine::TryParse;
use super::super::client::Response;
use super::super::machine::TryParse;
use super::Request;
#[test]
fn request_parsing() {
@ -253,13 +264,19 @@ mod tests {
let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
let _ = req.reply(None).unwrap();
let extra_headers = Some(vec![(String::from("MyCustomHeader"),
String::from("MyCustomValue")),
(String::from("MyVersion"),
String::from("LOL"))]);
let extra_headers = Some(vec![
(
String::from("MyCustomHeader"),
String::from("MyCustomValue"),
),
(String::from("MyVersion"), String::from("LOL")),
]);
let reply = req.reply(extra_headers).unwrap();
let (_, req) = Response::try_parse(&reply).unwrap().unwrap();
assert_eq!(req.headers.find_first("MyCustomHeader"), Some(b"MyCustomValue".as_ref()));
assert_eq!(
req.headers.find_first("MyCustomHeader"),
Some(b"MyCustomValue".as_ref())
);
assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref()));
}
}

View File

@ -3,39 +3,29 @@
missing_docs,
missing_copy_implementations,
missing_debug_implementations,
trivial_casts, trivial_numeric_casts,
trivial_casts,
trivial_numeric_casts,
unstable_features,
unused_must_use,
unused_mut,
unused_imports,
unused_import_braces)]
unused_import_braces
)]
#[macro_use] extern crate log;
extern crate base64;
extern crate byteorder;
extern crate bytes;
extern crate httparse;
extern crate input_buffer;
extern crate rand;
extern crate sha1;
extern crate url;
extern crate utf8;
#[cfg(feature="tls")] extern crate native_tls;
pub use http;
pub extern crate http;
pub mod error;
pub mod protocol;
pub mod client;
pub mod server;
pub mod error;
pub mod handshake;
pub mod protocol;
pub mod server;
pub mod stream;
pub mod util;
pub use client::{connect, client};
pub use server::{accept, accept_hdr};
pub use error::{Error, Result};
pub use protocol::{WebSocket, Message};
pub use handshake::HandshakeError;
pub use handshake::client::ClientHandshake;
pub use handshake::server::ServerHandshake;
pub use crate::client::{client, connect};
pub use crate::error::{Error, Result};
pub use crate::handshake::client::ClientHandshake;
pub use crate::handshake::server::ServerHandshake;
pub use crate::handshake::HandshakeError;
pub use crate::protocol::{Message, WebSocket};
pub use crate::server::{accept, accept_hdr};

View File

@ -1,7 +1,7 @@
//! Various codes defined in RFC 6455.
use std::convert::{From, Into};
use std::fmt;
use std::convert::{Into, From};
/// WebSocket message opcode as in RFC 6455.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
@ -42,8 +42,8 @@ impl fmt::Display for Data {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Data::Continue => write!(f, "CONTINUE"),
Data::Text => write!(f, "TEXT"),
Data::Binary => write!(f, "BINARY"),
Data::Text => write!(f, "TEXT"),
Data::Binary => write!(f, "BINARY"),
Data::Reserved(x) => write!(f, "RESERVED_DATA_{}", x),
}
}
@ -53,8 +53,8 @@ impl fmt::Display for Control {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Control::Close => write!(f, "CLOSE"),
Control::Ping => write!(f, "PING"),
Control::Pong => write!(f, "PONG"),
Control::Ping => write!(f, "PING"),
Control::Pong => write!(f, "PONG"),
Control::Reserved(x) => write!(f, "RESERVED_CONTROL_{}", x),
}
}
@ -71,18 +71,18 @@ impl fmt::Display for OpCode {
impl Into<u8> for OpCode {
fn into(self) -> u8 {
use self::Data::{Continue, Text, Binary};
use self::Control::{Close, Ping, Pong};
use self::Data::{Binary, Continue, Text};
use self::OpCode::*;
match self {
Data(Continue) => 0,
Data(Text) => 1,
Data(Binary) => 2,
Data(Text) => 1,
Data(Binary) => 2,
Data(self::Data::Reserved(i)) => i,
Control(Close) => 8,
Control(Ping) => 9,
Control(Pong) => 10,
Control(Ping) => 9,
Control(Pong) => 10,
Control(self::Control::Reserved(i)) => i,
}
}
@ -90,19 +90,19 @@ impl Into<u8> for OpCode {
impl From<u8> for OpCode {
fn from(byte: u8) -> OpCode {
use self::Data::{Continue, Text, Binary};
use self::Control::{Close, Ping, Pong};
use self::Data::{Binary, Continue, Text};
use self::OpCode::*;
match byte {
0 => Data(Continue),
1 => Data(Text),
2 => Data(Binary),
i @ 3 ... 7 => Data(self::Data::Reserved(i)),
8 => Control(Close),
9 => Control(Ping),
10 => Control(Pong),
i @ 11 ... 15 => Control(self::Control::Reserved(i)),
_ => panic!("Bug: OpCode out of range"),
0 => Data(Continue),
1 => Data(Text),
2 => Data(Binary),
i @ 3..=7 => Data(self::Data::Reserved(i)),
8 => Control(Close),
9 => Control(Ping),
10 => Control(Pong),
i @ 11..=15 => Control(self::Control::Reserved(i)),
_ => panic!("Bug: OpCode out of range"),
}
}
}
@ -183,13 +183,13 @@ pub enum CloseCode {
impl CloseCode {
/// Check if this CloseCode is allowed.
pub fn is_allowed(&self) -> bool {
match *self {
Bad(_) => false,
pub fn is_allowed(self) -> bool {
match self {
Bad(_) => false,
Reserved(_) => false,
Status => false,
Abnormal => false,
Tls => false,
Status => false,
Abnormal => false,
Tls => false,
_ => true,
}
}
@ -205,24 +205,24 @@ impl fmt::Display for CloseCode {
impl<'t> Into<u16> for &'t CloseCode {
fn into(self) -> u16 {
match *self {
Normal => 1000,
Away => 1001,
Protocol => 1002,
Unsupported => 1003,
Status => 1005,
Abnormal => 1006,
Invalid => 1007,
Policy => 1008,
Size => 1009,
Extension => 1010,
Error => 1011,
Restart => 1012,
Again => 1013,
Tls => 1015,
Reserved(code) => code,
Iana(code) => code,
Library(code) => code,
Bad(code) => code,
Normal => 1000,
Away => 1001,
Protocol => 1002,
Unsupported => 1003,
Status => 1005,
Abnormal => 1006,
Invalid => 1007,
Policy => 1008,
Size => 1009,
Extension => 1010,
Error => 1011,
Restart => 1012,
Again => 1013,
Tls => 1015,
Reserved(code) => code,
Iana(code) => code,
Library(code) => code,
Bad(code) => code,
}
}
}
@ -250,11 +250,11 @@ impl From<u16> for CloseCode {
1012 => Restart,
1013 => Again,
1015 => Tls,
1...999 => Bad(code),
1000...2999 => Reserved(code),
3000...3999 => Iana(code),
4000...4999 => Library(code),
_ => Bad(code)
1..=999 => Bad(code),
1016..=2999 => Reserved(code),
3000..=3999 => Iana(code),
4000..=4999 => Library(code),
_ => Bad(code),
}
}
}

View File

@ -1,14 +1,15 @@
use std::fmt;
use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt};
use log::*;
use std::borrow::Cow;
use std::io::{Cursor, Read, Write, ErrorKind};
use std::default::Default;
use std::string::{String, FromUtf8Error};
use std::fmt;
use std::io::{Cursor, ErrorKind, Read, Write};
use std::result::Result as StdResult;
use byteorder::{ByteOrder, ReadBytesExt, WriteBytesExt, NetworkEndian};
use std::string::{FromUtf8Error, String};
use error::{Error, Result};
use super::coding::{OpCode, Control, Data, CloseCode};
use super::mask::{generate_mask, apply_mask};
use super::coding::{CloseCode, Control, Data, OpCode};
use super::mask::{apply_mask, generate_mask};
use crate::error::{Error, Result};
/// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)]
@ -77,15 +78,13 @@ impl FrameHeader {
cursor.set_position(initial);
ret
}
ret => ret
ret => ret,
}
}
/// Get the size of the header formatted with given payload length.
pub fn len(&self, length: u64) -> usize {
2
+ LengthFormat::for_length(length).extra_bytes()
+ if self.mask.is_some() { 4 } else { 0 }
2 + LengthFormat::for_length(length).extra_bytes() + if self.mask.is_some() { 4 } else { 0 }
}
/// Format a header for given payload size.
@ -93,19 +92,15 @@ impl FrameHeader {
let code: u8 = self.opcode.into();
let one = {
code
| if self.is_final { 0x80 } else { 0 }
| if self.rsv1 { 0x40 } else { 0 }
| if self.rsv2 { 0x20 } else { 0 }
| if self.rsv3 { 0x10 } else { 0 }
code | if self.is_final { 0x80 } else { 0 }
| if self.rsv1 { 0x40 } else { 0 }
| if self.rsv2 { 0x20 } else { 0 }
| if self.rsv3 { 0x10 } else { 0 }
};
let lenfmt = LengthFormat::for_length(length);
let two = {
lenfmt.length_byte()
| if self.mask.is_some() { 0x80 } else { 0 }
};
let two = { lenfmt.length_byte() | if self.mask.is_some() { 0x80 } else { 0 } };
output.write_all(&[one, two])?;
match lenfmt {
@ -137,7 +132,7 @@ impl FrameHeader {
let (first, second) = {
let mut head = [0u8; 2];
if cursor.read(&mut head)? != 2 {
return Ok(None)
return Ok(None);
}
trace!("Parsed headers {:?}", head);
(head[0], head[1])
@ -169,17 +164,17 @@ impl FrameHeader {
Err(err) => {
return Err(err.into());
}
Ok(read) => read
Ok(read) => read,
}
} else {
length_byte as u64
u64::from(length_byte)
}
};
let mask = if masked {
let mut mask_bytes = [0u8; 4];
if cursor.read(&mut mask_bytes)? != 4 {
return Ok(None)
return Ok(None);
} else {
Some(mask_bytes)
}
@ -190,9 +185,11 @@ impl FrameHeader {
// Disallow bad opcode
match opcode {
OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into()))
return Err(Error::Protocol(
format!("Encountered invalid opcode: {}", first & 0x0F).into(),
))
}
_ => ()
_ => (),
}
let hdr = FrameHeader {
@ -216,7 +213,6 @@ pub struct Frame {
}
impl Frame {
/// Get the length of the frame.
/// This is the length of the header + the length of the payload.
#[inline]
@ -225,6 +221,12 @@ impl Frame {
self.header.len(length as u64) + length
}
/// Check if the frame is empty.
#[inline]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Get a reference to the frame's header.
#[inline]
pub fn header(&self) -> &FrameHeader {
@ -285,7 +287,7 @@ impl Frame {
String::from_utf8(self.payload)
}
/// Consume the frame into a closing frame.
/// Consume the frame into a closing frame.
#[inline]
pub(crate) fn into_close(self) -> Result<Option<CloseFrame<'static>>> {
match self.payload.len() {
@ -296,7 +298,10 @@ impl Frame {
let code = NetworkEndian::read_u16(&data[0..2]).into();
data.drain(0..2);
let text = String::from_utf8(data)?;
Ok(Some(CloseFrame { code, reason: text.into() }))
Ok(Some(CloseFrame {
code,
reason: text.into(),
}))
}
}
}
@ -304,16 +309,19 @@ impl Frame {
/// Create a new data frame.
#[inline]
pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
debug_assert!(match opcode {
OpCode::Data(_) => true,
_ => false,
}, "Invalid opcode for data frame.");
debug_assert!(
match opcode {
OpCode::Data(_) => true,
_ => false,
},
"Invalid opcode for data frame."
);
Frame {
header: FrameHeader {
is_final,
opcode,
.. FrameHeader::default()
..FrameHeader::default()
},
payload: data,
}
@ -325,7 +333,7 @@ impl Frame {
Frame {
header: FrameHeader {
opcode: OpCode::Control(Control::Pong),
.. FrameHeader::default()
..FrameHeader::default()
},
payload: data,
}
@ -337,7 +345,7 @@ impl Frame {
Frame {
header: FrameHeader {
opcode: OpCode::Control(Control::Ping),
.. FrameHeader::default()
..FrameHeader::default()
},
payload: data,
}
@ -363,10 +371,7 @@ impl Frame {
/// Create a frame from given header and data.
pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self {
Frame {
header,
payload,
}
Frame { header, payload }
}
/// Write a frame out to a buffer
@ -380,7 +385,8 @@ impl Frame {
impl fmt::Display for Frame {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f,
write!(
f,
"
<FRAME>
final: {}
@ -398,7 +404,11 @@ payload: 0x{}
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
self.len(),
self.payload.len(),
self.payload.iter().map(|byte| format!("{:x}", byte)).collect::<String>())
self.payload
.iter()
.map(|byte| format!("{:x}", byte))
.collect::<String>()
)
}
}
@ -448,7 +458,7 @@ impl LengthFormat {
match byte & 0x7F {
126 => LengthFormat::U16,
127 => LengthFormat::U64,
b => LengthFormat::U8(b)
b => LengthFormat::U8(b),
}
}
}
@ -457,20 +467,22 @@ impl LengthFormat {
mod tests {
use super::*;
use super::super::coding::{OpCode, Data};
use super::super::coding::{Data, OpCode};
use std::io::Cursor;
#[test]
fn parse() {
let mut raw: Cursor<Vec<u8>> = Cursor::new(vec![
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07
]);
let mut raw: Cursor<Vec<u8>> =
Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
assert_eq!(length, 7);
let mut payload = Vec::new();
raw.read_to_end(&mut payload).unwrap();
let frame = Frame::from_payload(header, payload);
assert_eq!(frame.into_data(), vec![ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]);
assert_eq!(
frame.into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
}
#[test]
@ -487,5 +499,4 @@ mod tests {
let view = format!("{}", f);
assert!(view.contains("payload:"));
}
}

View File

@ -1,7 +1,8 @@
use rand;
use std::cmp::min;
#[allow(deprecated)]
use std::mem::uninitialized;
use std::ptr::{copy_nonoverlapping, read_unaligned};
use rand;
/// Generate a random frame mask.
#[inline]
@ -26,11 +27,9 @@ fn apply_mask_fallback(buf: &mut [u8], mask: [u8; 4]) {
/// Faster version of `apply_mask()` which operates on 4-byte blocks.
#[inline]
#[allow(dead_code)]
#[allow(dead_code, clippy::cast_ptr_alignment)]
fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
let mask_u32: u32 = unsafe {
read_unaligned(mask.as_ptr() as *const u32)
};
let mask_u32: u32 = unsafe { read_unaligned(mask.as_ptr() as *const u32) };
let mut ptr = buf.as_mut_ptr();
let mut len = buf.len();
@ -40,7 +39,7 @@ fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
let mask_u32 = if head > 0 {
unsafe {
xor_mem(ptr, mask_u32, head);
ptr = ptr.offset(head as isize);
ptr = ptr.add(head);
}
len -= head;
if cfg!(target_endian = "big") {
@ -67,7 +66,9 @@ fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
// Possible last block.
if len > 0 {
unsafe { xor_mem(ptr, mask_u32, len); }
unsafe {
xor_mem(ptr, mask_u32, len);
}
}
}
@ -75,6 +76,7 @@ fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
// TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so inefficient,
// it could be done better. The compiler does not see that len is limited to 3.
unsafe fn xor_mem(ptr: *mut u8, mask: u32, len: usize) {
#[allow(deprecated)]
let mut b: u32 = uninitialized();
#[allow(trivial_casts)]
copy_nonoverlapping(ptr, &mut b as *mut _ as *mut u8, len);
@ -90,12 +92,10 @@ mod tests {
#[test]
fn test_apply_mask() {
let mask = [
0x6d, 0xb6, 0xb2, 0x80,
];
let mask = [0x6d, 0xb6, 0xb2, 0x80];
let unmasked = vec![
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82,
0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9, 0x12, 0x03,
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9,
0x12, 0x03,
];
// Check masking with proper alignment.
@ -120,6 +120,4 @@ mod tests {
assert_eq!(masked, masked_fast);
}
}
}

View File

@ -2,16 +2,17 @@
pub mod coding;
#[allow(clippy::module_inception)]
mod frame;
mod mask;
pub use self::frame::{Frame, FrameHeader};
pub use self::frame::CloseFrame;
pub use self::frame::{Frame, FrameHeader};
use std::io::{Read, Write};
use crate::error::{Error, Result};
use input_buffer::{InputBuffer, MIN_READ};
use error::{Error, Result};
use log::*;
use std::io::{Read, Write};
/// A reader and writer for WebSocket frames.
#[derive(Debug)]
@ -56,7 +57,8 @@ impl<Stream> FrameSocket<Stream> {
}
impl<Stream> FrameSocket<Stream>
where Stream: Read
where
Stream: Read,
{
/// Read a frame from stream.
pub fn read_frame(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> {
@ -65,7 +67,8 @@ impl<Stream> FrameSocket<Stream>
}
impl<Stream> FrameSocket<Stream>
where Stream: Write
where
Stream: Write,
{
/// Write a frame to stream.
///
@ -138,8 +141,8 @@ impl FrameCodec {
// is not too big (fits into `usize`).
if length > max_size as u64 {
return Err(Error::Capacity(
format!("Message length too big: {} > {}", length, max_size).into()
))
format!("Message length too big: {} > {}", length, max_size).into(),
));
}
let input_size = cursor.get_ref().len() as u64 - cursor.position();
@ -149,19 +152,21 @@ impl FrameCodec {
if length > 0 {
cursor.take(length).read_to_end(&mut payload)?;
}
break payload
break payload;
}
}
}
// Not enough data in buffer.
let size = self.in_buffer.prepare_reserve(MIN_READ)
let size = self
.in_buffer
.prepare_reserve(MIN_READ)
.with_limit(usize::max_value())
.map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))?
.read_from(stream)?;
if size == 0 {
trace!("no frame received");
return Ok(None)
return Ok(None);
}
};
@ -173,17 +178,15 @@ impl FrameCodec {
}
/// Write a frame to the provided stream.
pub(super) fn write_frame<Stream>(
&mut self,
stream: &mut Stream,
frame: Frame,
) -> Result<()>
pub(super) fn write_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()>
where
Stream: Write,
{
trace!("writing frame {}", frame);
self.out_buffer.reserve(frame.len());
frame.format(&mut self.out_buffer).expect("Bug: can't write to vector");
frame
.format(&mut self.out_buffer)
.expect("Bug: can't write to vector");
self.write_pending(stream)
}
@ -211,16 +214,19 @@ mod tests {
#[test]
fn read_frames() {
let raw = Cursor::new(vec![
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x82, 0x03, 0x03, 0x02, 0x01,
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x03, 0x03, 0x02, 0x01,
0x99,
]);
let mut sock = FrameSocket::new(raw);
assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x03, 0x02, 0x01]);
assert_eq!(
sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
assert_eq!(
sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x03, 0x02, 0x01]
);
assert!(sock.read_frame(None).unwrap().is_none());
let (_, rest) = sock.into_inner();
@ -229,12 +235,12 @@ mod tests {
#[test]
fn from_partially_read() {
let raw = Cursor::new(vec![
0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
]);
let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
assert_eq!(
sock.read_frame(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
}
#[test]
@ -248,17 +254,13 @@ mod tests {
sock.write_frame(frame).unwrap();
let (buf, _) = sock.into_inner();
assert_eq!(buf, vec![
0x89, 0x02, 0x04, 0x05,
0x8a, 0x01, 0x01
]);
assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]);
}
#[test]
fn parse_overflow() {
let raw = Cursor::new(vec![
0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
]);
let mut sock = FrameSocket::new(raw);
let _ = sock.read_frame(None); // should not crash
@ -266,11 +268,10 @@ mod tests {
#[test]
fn size_limit_hit() {
let raw = Cursor::new(vec![
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
]);
let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::new(raw);
assert_eq!(sock.read_frame(Some(5)).unwrap_err().to_string(),
assert_eq!(
sock.read_frame(Some(5)).unwrap_err().to_string(),
"Space limit exceeded: Message length too big: 7 > 5"
);
}

View File

@ -1,17 +1,17 @@
use std::convert::{From, Into, AsRef};
use std::convert::{AsRef, From, Into};
use std::fmt;
use std::result::Result as StdResult;
use std::str;
use error::{Result, Error};
use super::frame::CloseFrame;
use crate::error::{Error, Result};
mod string_collect {
use utf8;
use utf8::DecodeError;
use error::{Error, Result};
use crate::error::{Error, Result};
#[derive(Debug)]
pub struct StringCollector {
@ -28,7 +28,8 @@ mod string_collect {
}
pub fn len(&self) -> usize {
self.data.len()
self.data
.len()
.saturating_add(self.incomplete.map(|i| i.buffer_len as usize).unwrap_or(0))
}
@ -41,7 +42,7 @@ mod string_collect {
if let Ok(text) = result {
self.data.push_str(text);
} else {
return Err(Error::Utf8)
return Err(Error::Utf8);
}
true
} else {
@ -59,7 +60,10 @@ mod string_collect {
self.data.push_str(text);
Ok(())
}
Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => {
Err(DecodeError::Incomplete {
valid_prefix,
incomplete_suffix,
}) => {
self.data.push_str(valid_prefix);
self.incomplete = Some(incomplete_suffix);
Ok(())
@ -82,7 +86,6 @@ mod string_collect {
}
}
}
}
use self::string_collect::StringCollector;
@ -104,11 +107,11 @@ impl IncompleteMessage {
pub fn new(message_type: IncompleteMessageType) -> Self {
IncompleteMessage {
collector: match message_type {
IncompleteMessageType::Binary =>
IncompleteMessageCollector::Binary(Vec::new()),
IncompleteMessageType::Text =>
IncompleteMessageCollector::Text(StringCollector::new()),
}
IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
IncompleteMessageType::Text => {
IncompleteMessageCollector::Text(StringCollector::new())
}
},
}
}
@ -130,8 +133,12 @@ impl IncompleteMessage {
// Be careful about integer overflows here.
if my_size > max_size || portion_size > max_size - my_size {
return Err(Error::Capacity(
format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into()
))
format!(
"Message too big: {} + {} > {}",
my_size, portion_size, max_size
)
.into(),
));
}
match self.collector {
@ -139,18 +146,14 @@ impl IncompleteMessage {
v.extend(tail.as_ref());
Ok(())
}
IncompleteMessageCollector::Text(ref mut t) => {
t.extend(tail)
}
IncompleteMessageCollector::Text(ref mut t) => t.extend(tail),
}
}
/// Convert an incomplete message into a complete one.
pub fn complete(self) -> Result<Message> {
match self.collector {
IncompleteMessageCollector::Binary(v) => {
Ok(Message::Binary(v))
}
IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v)),
IncompleteMessageCollector::Text(t) => {
let text = t.into_string()?;
Ok(Message::Text(text))
@ -185,17 +188,18 @@ pub enum Message {
}
impl Message {
/// Create a new text WebSocket message from a stringable.
pub fn text<S>(string: S) -> Message
where S: Into<String>
where
S: Into<String>,
{
Message::Text(string.into())
}
/// Create a new binary WebSocket message by converting to Vec<u8>.
pub fn binary<B>(bin: B) -> Message
where B: Into<Vec<u8>>
where
B: Into<Vec<u8>>,
{
Message::Binary(bin.into())
}
@ -244,9 +248,9 @@ impl Message {
pub fn len(&self) -> usize {
match *self {
Message::Text(ref string) => string.len(),
Message::Binary(ref data) |
Message::Ping(ref data) |
Message::Pong(ref data) => data.len(),
Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
data.len()
}
Message::Close(ref data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0),
}
}
@ -261,9 +265,7 @@ impl Message {
pub fn into_data(self) -> Vec<u8> {
match self {
Message::Text(string) => string.into_bytes(),
Message::Binary(data) |
Message::Ping(data) |
Message::Pong(data) => data,
Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data,
Message::Close(None) => Vec::new(),
Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
}
@ -273,10 +275,9 @@ impl Message {
pub fn into_text(self) -> Result<String> {
match self {
Message::Text(string) => Ok(string),
Message::Binary(data) |
Message::Ping(data) |
Message::Pong(data) => Ok(try!(
String::from_utf8(data).map_err(|err| err.utf8_error()))),
Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => {
Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?)
}
Message::Close(None) => Ok(String::new()),
Message::Close(Some(frame)) => Ok(frame.reason.into_owned()),
}
@ -287,14 +288,13 @@ impl Message {
pub fn to_text(&self) -> Result<&str> {
match *self {
Message::Text(ref string) => Ok(string),
Message::Binary(ref data) |
Message::Ping(ref data) |
Message::Pong(ref data) => Ok(try!(str::from_utf8(data))),
Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
Ok(str::from_utf8(data)?)
}
Message::Close(None) => Ok(""),
Message::Close(Some(ref frame)) => Ok(&frame.reason),
}
}
}
impl From<String> for Message {
@ -358,7 +358,6 @@ mod tests {
assert!(msg.into_text().is_err());
}
#[test]
fn binary_convert_vec() {
let bin = vec![6u8, 7, 8, 9, 10, 241];

View File

@ -4,18 +4,19 @@ pub mod frame;
mod message;
pub use self::message::Message;
pub use self::frame::CloseFrame;
pub use self::message::Message;
use log::*;
use std::collections::VecDeque;
use std::io::{Read, Write, ErrorKind as IoErrorKind};
use std::io::{ErrorKind as IoErrorKind, Read, Write};
use std::mem::replace;
use error::{Error, Result};
use self::message::{IncompleteMessage, IncompleteMessageType};
use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode};
use self::frame::{Frame, FrameCodec};
use self::frame::coding::{OpCode, Data as OpData, Control as OpCtl, CloseCode};
use util::NonBlockingResult;
use self::message::{IncompleteMessage, IncompleteMessageType};
use crate::error::{Error, Result};
use crate::util::NonBlockingResult;
/// Indicates a Client or Server role of the websocket
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -147,7 +148,6 @@ impl<Stream: Read + Write> WebSocket<Stream> {
}
}
/// A context for managing WebSocket stream.
#[derive(Debug)]
pub struct WebSocketContext {
@ -182,11 +182,7 @@ impl WebSocketContext {
}
/// Create a WebSocket context that manages an post-handshake stream.
pub fn from_partially_read(
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
) -> Self {
pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
WebSocketContext {
frame: FrameCodec::from_partially_read(part),
..WebSocketContext::new(role, config)
@ -217,7 +213,7 @@ impl WebSocketContext {
// Thus if read blocks, just let it return WouldBlock.
if let Some(message) = self.read_message_frame(stream)? {
trace!("Received message {}", message);
return Ok(message)
return Ok(message);
}
}
}
@ -251,20 +247,14 @@ impl WebSocketContext {
}
let frame = match message {
Message::Text(data) => {
Frame::message(data.into(), OpCode::Data(OpData::Text), true)
}
Message::Binary(data) => {
Frame::message(data, OpCode::Data(OpData::Binary), true)
}
Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true),
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
Message::Ping(data) => Frame::ping(data),
Message::Pong(data) => {
self.pong = Some(Frame::pong(data));
return self.write_pending(stream)
}
Message::Close(code) => {
return self.close(stream, code)
return self.write_pending(stream);
}
Message::Close(code) => return self.close(stream, code),
};
self.send_queue.push_back(frame);
@ -342,7 +332,6 @@ impl WebSocketContext {
Stream: Read + Write,
{
if let Some(mut frame) = self.frame.read_frame(stream, self.config.max_frame_size)? {
// MUST be 0 unless an extension is negotiated that defines meanings
// for non-zero values. If a nonzero value is received and none of
// the negotiated extensions defines the meaning of such a nonzero
@ -351,7 +340,7 @@ impl WebSocketContext {
{
let hdr = frame.header();
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
return Err(Error::Protocol("Reserved bits are non-zero".into()))
return Err(Error::Protocol("Reserved bits are non-zero".into()));
}
}
@ -364,19 +353,22 @@ impl WebSocketContext {
} else {
// The server MUST close the connection upon receiving a
// frame that is not masked. (RFC 6455)
return Err(Error::Protocol("Received an unmasked frame from client".into()))
return Err(Error::Protocol(
"Received an unmasked frame from client".into(),
));
}
}
Role::Client => {
if frame.is_masked() {
// A client MUST close a connection if it detects a masked frame. (RFC 6455)
return Err(Error::Protocol("Received a masked frame from server".into()))
return Err(Error::Protocol(
"Received a masked frame from server".into(),
));
}
}
}
match frame.header().opcode {
OpCode::Control(ctl) => {
match ctl {
// All control frames MUST have a payload length of 125 bytes or less
@ -387,12 +379,10 @@ impl WebSocketContext {
_ if frame.payload().len() > 125 => {
Err(Error::Protocol("Control frame too big".into()))
}
OpCtl::Close => {
Ok(self.do_close(frame.into_close()?).map(Message::Close))
}
OpCtl::Reserved(i) => {
Err(Error::Protocol(format!("Unknown control frame type {}", i).into()))
}
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
OpCtl::Reserved(i) => Err(Error::Protocol(
format!("Unknown control frame type {}", i).into(),
)),
OpCtl::Ping | OpCtl::Pong if !self.state.is_active() => {
// No ping processing while closing.
Ok(None)
@ -402,9 +392,7 @@ impl WebSocketContext {
self.pong = Some(Frame::pong(data.clone()));
Ok(Some(Message::Ping(data)))
}
OpCtl::Pong => {
Ok(Some(Message::Pong(frame.into_data())))
}
OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data()))),
}
}
@ -420,7 +408,9 @@ impl WebSocketContext {
if let Some(ref mut msg) = self.incomplete {
msg.extend(frame.into_data(), self.config.max_message_size)?;
} else {
return Err(Error::Protocol("Continue frame but nothing to continue".into()))
return Err(Error::Protocol(
"Continue frame but nothing to continue".into(),
));
}
if fin {
Ok(Some(self.incomplete.take().unwrap().complete()?))
@ -428,11 +418,9 @@ impl WebSocketContext {
Ok(None)
}
}
c if self.incomplete.is_some() => {
Err(Error::Protocol(
format!("Received {} while waiting for more fragments", c).into()
))
}
c if self.incomplete.is_some() => Err(Error::Protocol(
format!("Received {} while waiting for more fragments", c).into(),
)),
OpData::Text | OpData::Binary => {
let msg = {
let message_type = match data {
@ -451,28 +439,27 @@ impl WebSocketContext {
Ok(None)
}
}
OpData::Reserved(i) => {
Err(Error::Protocol(format!("Unknown data frame type {}", i).into()))
}
OpData::Reserved(i) => Err(Error::Protocol(
format!("Unknown data frame type {}", i).into(),
)),
}
}
} // match opcode
} else {
// Connection closed by peer
match replace(&mut self.state, WebSocketState::Terminated) {
WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => {
Err(Error::ConnectionClosed)
}
_ => {
Err(Error::Protocol("Connection reset without closing handshake".into()))
}
_ => Err(Error::Protocol(
"Connection reset without closing handshake".into(),
)),
}
}
}
/// Received a close frame. Tells if we need to return a close frame to the user.
#[allow(clippy::option_option)]
fn do_close<'t>(&mut self, close: Option<CloseFrame<'t>>) -> Option<Option<CloseFrame<'t>>> {
debug!("Received close frame: {:?}", close);
match self.state {
@ -488,7 +475,7 @@ impl WebSocketContext {
} else {
Frame::close(Some(CloseFrame {
code: CloseCode::Protocol,
reason: "Protocol violation".into()
reason: "Protocol violation".into(),
}))
}
} else {
@ -518,8 +505,7 @@ impl WebSocketContext {
Stream: Read + Write,
{
match self.role {
Role::Server => {
}
Role::Server => {}
Role::Client => {
// 5. If the data is being sent by the client, the frame(s) MUST be
// masked as defined in Section 5.3. (RFC 6455)
@ -535,7 +521,9 @@ impl WebSocketContext {
match self.state {
WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged
if err.kind() == IoErrorKind::ConnectionReset =>
Error::ConnectionClosed,
{
Error::ConnectionClosed
}
_ => Error::Io(err),
}
}),
@ -544,7 +532,6 @@ impl WebSocketContext {
}
}
/// The current connection state.
#[derive(Debug)]
enum WebSocketState {
@ -580,7 +567,7 @@ impl WebSocketState {
#[cfg(test)]
mod tests {
use super::{WebSocket, Role, Message, WebSocketConfig};
use super::{Message, Role, WebSocket, WebSocketConfig};
use std::io;
use std::io::Cursor;
@ -602,57 +589,53 @@ mod tests {
}
}
#[test]
fn receive_messages() {
let incoming = Cursor::new(vec![
0x89, 0x02, 0x01, 0x02,
0x8a, 0x01, 0x03,
0x01, 0x07,
0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20,
0x80, 0x06,
0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21,
0x82, 0x03,
0x01, 0x02, 0x03,
0x89, 0x02, 0x01, 0x02, 0x8a, 0x01, 0x03, 0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f,
0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x82, 0x03, 0x01, 0x02,
0x03,
]);
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into()));
assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
assert_eq!(
socket.read_message().unwrap(),
Message::Text("Hello, World!".into())
);
assert_eq!(
socket.read_message().unwrap(),
Message::Binary(vec![0x01, 0x02, 0x03])
);
}
#[test]
fn size_limiting_text_fragmented() {
let incoming = Cursor::new(vec![
0x01, 0x07,
0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20,
0x80, 0x06,
0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21,
0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72,
0x6c, 0x64, 0x21,
]);
let limit = WebSocketConfig {
max_message_size: Some(10),
.. WebSocketConfig::default()
..WebSocketConfig::default()
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(socket.read_message().unwrap_err().to_string(),
assert_eq!(
socket.read_message().unwrap_err().to_string(),
"Space limit exceeded: Message too big: 7 + 6 > 10"
);
}
#[test]
fn size_limiting_binary() {
let incoming = Cursor::new(vec![
0x82, 0x03,
0x01, 0x02, 0x03,
]);
let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]);
let limit = WebSocketConfig {
max_message_size: Some(2),
.. WebSocketConfig::default()
..WebSocketConfig::default()
};
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
assert_eq!(socket.read_message().unwrap_err().to_string(),
assert_eq!(
socket.read_message().unwrap_err().to_string(),
"Space limit exceeded: Message too big: 0 + 3 > 2"
);
}

View File

@ -1,11 +1,11 @@
//! Methods to accept an incoming WebSocket connection on a server.
pub use handshake::server::ServerHandshake;
pub use crate::handshake::server::ServerHandshake;
use handshake::HandshakeError;
use handshake::server::{Callback, NoCallback};
use crate::handshake::server::{Callback, NoCallback};
use crate::handshake::HandshakeError;
use protocol::{WebSocket, WebSocketConfig};
use crate::protocol::{WebSocket, WebSocketConfig};
use std::io::{Read, Write};
@ -18,9 +18,10 @@ use std::io::{Read, Write};
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others.
pub fn accept_with_config<S: Read + Write>(stream: S, config: Option<WebSocketConfig>)
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>>
{
pub fn accept_with_config<S: Read + Write>(
stream: S,
config: Option<WebSocketConfig>,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> {
accept_hdr_with_config(stream, NoCallback, config)
}
@ -30,9 +31,9 @@ pub fn accept_with_config<S: Read + Write>(stream: S, config: Option<WebSocketCo
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
/// for the stream here. Any `Read + Write` streams are supported, including
/// those from `Mio` and others.
pub fn accept<S: Read + Write>(stream: S)
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>>
{
pub fn accept<S: Read + Write>(
stream: S,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> {
accept_with_config(stream, None)
}
@ -47,7 +48,7 @@ pub fn accept<S: Read + Write>(stream: S)
pub fn accept_hdr_with_config<S: Read + Write, C: Callback>(
stream: S,
callback: C,
config: Option<WebSocketConfig>
config: Option<WebSocketConfig>,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> {
ServerHandshake::start(stream, callback, config).handshake()
}
@ -57,8 +58,9 @@ pub fn accept_hdr_with_config<S: Read + Write, C: Callback>(
/// This function does the same as `accept()` but accepts an extra callback
/// for header processing. The callback receives headers of the incoming
/// requests and is able to add extra headers to the reply.
pub fn accept_hdr<S: Read + Write, C: Callback>(stream: S, callback: C)
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>>
{
pub fn accept_hdr<S: Read + Write, C: Callback>(
stream: S,
callback: C,
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> {
accept_hdr_with_config(stream, callback, None)
}

View File

@ -4,11 +4,11 @@
//! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard
//! `Read + Write` traits.
use std::io::{Read, Write, Result as IoResult};
use std::io::{Read, Result as IoResult, Write};
use std::net::TcpStream;
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
use native_tls::TlsStream;
/// Stream mode, either plain TCP or TLS.
@ -32,7 +32,7 @@ impl NoDelay for TcpStream {
}
}
#[cfg(feature="tls")]
#[cfg(feature = "tls")]
impl<S: Read + Write + NoDelay> NoDelay for TlsStream<S> {
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
self.get_mut().set_nodelay(nodelay)

View File

@ -3,7 +3,7 @@
use std::io::{Error as IoError, ErrorKind as IoErrorKind};
use std::result::Result as StdResult;
use error::Error;
use crate::error::Error;
/// Non-blocking IO handling.
pub trait NonBlockingError: Sized {
@ -40,7 +40,8 @@ pub trait NonBlockingResult {
}
impl<T, E> NonBlockingResult for StdResult<T, E>
where E : NonBlockingError
where
E: NonBlockingError,
{
type Result = StdResult<Option<T>, E>;
fn no_block(self) -> Self::Result {
@ -49,7 +50,7 @@ impl<T, E> NonBlockingResult for StdResult<T, E>
Err(e) => match e.into_non_blocking() {
Some(e) => Err(e),
None => Ok(None),
}
},
}
}
}

View File

@ -1,13 +1,9 @@
//! Verifies that the server returns a `ConnectionClosed` error when the connection
//! is closedd from the server's point of view and drop the underlying tcp socket.
extern crate env_logger;
extern crate tungstenite;
extern crate url;
use std::net::TcpListener;
use std::process::exit;
use std::thread::{spawn, sleep};
use std::thread::{sleep, spawn};
use std::time::Duration;
use tungstenite::{accept, connect, Error, Message};
@ -28,14 +24,16 @@ fn test_close() {
let client_thread = spawn(move || {
let (mut client, _) = connect(Url::parse("ws://localhost:3012/socket").unwrap()).unwrap();
client.write_message(Message::Text("Hello WebSocket".into())).unwrap();
client
.write_message(Message::Text("Hello WebSocket".into()))
.unwrap();
let message = client.read_message().unwrap(); // receive close from server
assert!(message.is_close());
let err = client.read_message().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => { },
Error::ConnectionClosed => {}
_ => panic!("unexpected error"),
}
});
@ -52,7 +50,7 @@ fn test_close() {
let err = client_handler.read_message().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => { },
Error::ConnectionClosed => {}
_ => panic!("unexpected error"),
}