Edition 2018, formatting, clippy fixes
This commit is contained in:
parent
b40256eedd
commit
cbf80ecc76
|
@ -10,6 +10,7 @@ homepage = "https://github.com/snapview/tungstenite-rs"
|
|||
documentation = "https://docs.rs/tungstenite/0.9.1"
|
||||
repository = "https://github.com/snapview/tungstenite-rs"
|
||||
version = "0.9.1"
|
||||
edition = "2018"
|
||||
|
||||
[features]
|
||||
default = ["tls"]
|
||||
|
|
|
@ -1,18 +1,12 @@
|
|||
#[macro_use] extern crate log;
|
||||
extern crate env_logger;
|
||||
extern crate tungstenite;
|
||||
extern crate url;
|
||||
|
||||
use log::*;
|
||||
use url::Url;
|
||||
|
||||
use tungstenite::{connect, Error, Result, Message};
|
||||
use tungstenite::{connect, Error, Message, Result};
|
||||
|
||||
const AGENT: &'static str = "Tungstenite";
|
||||
|
||||
fn get_case_count() -> Result<u32> {
|
||||
let (mut socket, _) = connect(
|
||||
Url::parse("ws://localhost:9001/getCaseCount").unwrap(),
|
||||
)?;
|
||||
let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?;
|
||||
let msg = socket.read_message()?;
|
||||
socket.close(None)?;
|
||||
Ok(msg.into_text()?.parse::<u32>().unwrap())
|
||||
|
@ -20,7 +14,11 @@ fn get_case_count() -> Result<u32> {
|
|||
|
||||
fn update_reports() -> Result<()> {
|
||||
let (mut socket, _) = connect(
|
||||
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap(),
|
||||
Url::parse(&format!(
|
||||
"ws://localhost:9001/updateReports?agent={}",
|
||||
AGENT
|
||||
))
|
||||
.unwrap(),
|
||||
)?;
|
||||
socket.close(None)?;
|
||||
Ok(())
|
||||
|
@ -28,19 +26,18 @@ fn update_reports() -> Result<()> {
|
|||
|
||||
fn run_test(case: u32) -> Result<()> {
|
||||
info!("Running test case {}", case);
|
||||
let case_url = Url::parse(
|
||||
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)
|
||||
).unwrap();
|
||||
let case_url = Url::parse(&format!(
|
||||
"ws://localhost:9001/runCase?case={}&agent={}",
|
||||
case, AGENT
|
||||
))
|
||||
.unwrap();
|
||||
let (mut socket, _) = connect(case_url)?;
|
||||
loop {
|
||||
match socket.read_message()? {
|
||||
msg @ Message::Text(_) |
|
||||
msg @ Message::Binary(_) => {
|
||||
msg @ Message::Text(_) | msg @ Message::Binary(_) => {
|
||||
socket.write_message(msg)?;
|
||||
}
|
||||
Message::Ping(_) |
|
||||
Message::Pong(_) |
|
||||
Message::Close(_) => {}
|
||||
Message::Ping(_) | Message::Pong(_) | Message::Close(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -53,12 +50,13 @@ fn main() {
|
|||
for case in 1..(total + 1) {
|
||||
if let Err(e) = run_test(case) {
|
||||
match e {
|
||||
Error::Protocol(_) => { }
|
||||
err => { warn!("test: {}", err); }
|
||||
Error::Protocol(_) => {}
|
||||
err => {
|
||||
warn!("test: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
update_reports().unwrap();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
#[macro_use] extern crate log;
|
||||
extern crate env_logger;
|
||||
extern crate tungstenite;
|
||||
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::thread::spawn;
|
||||
|
||||
use tungstenite::{accept, HandshakeError, Error, Result, Message};
|
||||
use log::*;
|
||||
use tungstenite::handshake::HandshakeRole;
|
||||
use tungstenite::{accept, Error, HandshakeError, Message, Result};
|
||||
|
||||
fn must_not_block<Role: HandshakeRole>(err: HandshakeError<Role>) -> Error {
|
||||
match err {
|
||||
|
@ -19,13 +16,10 @@ fn handle_client(stream: TcpStream) -> Result<()> {
|
|||
let mut socket = accept(stream).map_err(must_not_block)?;
|
||||
loop {
|
||||
match socket.read_message()? {
|
||||
msg @ Message::Text(_) |
|
||||
msg @ Message::Binary(_) => {
|
||||
msg @ Message::Text(_) | msg @ Message::Binary(_) => {
|
||||
socket.write_message(msg)?;
|
||||
}
|
||||
Message::Ping(_) |
|
||||
Message::Pong(_) |
|
||||
Message::Close(_) => {}
|
||||
Message::Ping(_) | Message::Pong(_) | Message::Close(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -36,14 +30,12 @@ fn main() {
|
|||
let server = TcpListener::bind("127.0.0.1:9002").unwrap();
|
||||
|
||||
for stream in server.incoming() {
|
||||
spawn(move || {
|
||||
match stream {
|
||||
Ok(stream) => match handle_client(stream) {
|
||||
Ok(_) => (),
|
||||
Err(e) => warn!("Error in client: {}", e),
|
||||
},
|
||||
Err(e) => warn!("Error accepting stream: {}", e),
|
||||
}
|
||||
spawn(move || match stream {
|
||||
Ok(stream) => match handle_client(stream) {
|
||||
Ok(_) => (),
|
||||
Err(e) => warn!("Error in client: {}", e),
|
||||
},
|
||||
Err(e) => warn!("Error accepting stream: {}", e),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
extern crate tungstenite;
|
||||
|
||||
use std::thread::spawn;
|
||||
use std::net::TcpListener;
|
||||
use std::thread::spawn;
|
||||
|
||||
use tungstenite::accept_hdr;
|
||||
use tungstenite::handshake::server::{Request, ErrorResponse};
|
||||
use tungstenite::handshake::server::{ErrorResponse, Request};
|
||||
use tungstenite::http::StatusCode;
|
||||
|
||||
fn main() {
|
||||
|
|
|
@ -1,15 +1,11 @@
|
|||
extern crate tungstenite;
|
||||
extern crate url;
|
||||
extern crate env_logger;
|
||||
|
||||
use tungstenite::{connect, Message};
|
||||
use url::Url;
|
||||
use tungstenite::{Message, connect};
|
||||
|
||||
fn main() {
|
||||
env_logger::init();
|
||||
|
||||
let (mut socket, response) = connect(Url::parse("ws://localhost:3012/socket").unwrap())
|
||||
.expect("Can't connect");
|
||||
let (mut socket, response) =
|
||||
connect(Url::parse("ws://localhost:3012/socket").unwrap()).expect("Can't connect");
|
||||
|
||||
println!("Connected to the server");
|
||||
println!("Response HTTP code: {}", response.code);
|
||||
|
@ -18,11 +14,12 @@ fn main() {
|
|||
println!("* {}", header);
|
||||
}
|
||||
|
||||
socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();
|
||||
socket
|
||||
.write_message(Message::Text("Hello WebSocket".into()))
|
||||
.unwrap();
|
||||
loop {
|
||||
let msg = socket.read_message().expect("Error reading message");
|
||||
println!("Received: {}", msg);
|
||||
}
|
||||
// socket.close(None);
|
||||
|
||||
}
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
extern crate tungstenite;
|
||||
extern crate env_logger;
|
||||
|
||||
use std::thread::spawn;
|
||||
use std::net::TcpListener;
|
||||
use std::thread::spawn;
|
||||
|
||||
use tungstenite::accept_hdr;
|
||||
use tungstenite::handshake::server::Request;
|
||||
|
@ -23,7 +20,10 @@ fn main() {
|
|||
// Let's add an additional header to our response to the client.
|
||||
let extra_headers = vec![
|
||||
(String::from("MyCustomHeader"), String::from(":)")),
|
||||
(String::from("SOME_TUNGSTENITE_HEADER"), String::from("header_value")),
|
||||
(
|
||||
String::from("SOME_TUNGSTENITE_HEADER"),
|
||||
String::from("header_value"),
|
||||
),
|
||||
];
|
||||
Ok(Some(extra_headers))
|
||||
};
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
[package]
|
||||
name = "tungstenite-fuzz"
|
||||
version = "0.0.1"
|
||||
|
|
|
@ -1,36 +1,40 @@
|
|||
//! Methods to connect to an WebSocket as a client.
|
||||
|
||||
use std::net::{TcpStream, SocketAddr, ToSocketAddrs};
|
||||
use std::result::Result as StdResult;
|
||||
use std::io::{Read, Write};
|
||||
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
|
||||
use std::result::Result as StdResult;
|
||||
|
||||
use log::*;
|
||||
use url::Url;
|
||||
|
||||
use handshake::client::Response;
|
||||
use protocol::WebSocketConfig;
|
||||
use crate::handshake::client::Response;
|
||||
use crate::protocol::WebSocketConfig;
|
||||
|
||||
#[cfg(feature="tls")]
|
||||
#[cfg(feature = "tls")]
|
||||
mod encryption {
|
||||
use std::net::TcpStream;
|
||||
use native_tls::{TlsConnector, HandshakeError as TlsHandshakeError};
|
||||
pub use native_tls::TlsStream;
|
||||
use native_tls::{HandshakeError as TlsHandshakeError, TlsConnector};
|
||||
use std::net::TcpStream;
|
||||
|
||||
pub use stream::Stream as StreamSwitcher;
|
||||
pub use crate::stream::Stream as StreamSwitcher;
|
||||
/// TCP stream switcher (plain/TLS).
|
||||
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>;
|
||||
|
||||
use stream::Mode;
|
||||
use error::Result;
|
||||
use crate::error::Result;
|
||||
use crate::stream::Mode;
|
||||
|
||||
pub fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
|
||||
match mode {
|
||||
Mode::Plain => Ok(StreamSwitcher::Plain(stream)),
|
||||
Mode::Tls => {
|
||||
let connector = TlsConnector::builder().build()?;
|
||||
connector.connect(domain, stream)
|
||||
connector
|
||||
.connect(domain, stream)
|
||||
.map_err(|e| match e {
|
||||
TlsHandshakeError::Failure(f) => f.into(),
|
||||
TlsHandshakeError::WouldBlock(_) => panic!("Bug: TLS handshake not blocked"),
|
||||
TlsHandshakeError::WouldBlock(_) => {
|
||||
panic!("Bug: TLS handshake not blocked")
|
||||
}
|
||||
})
|
||||
.map(StreamSwitcher::Tls)
|
||||
}
|
||||
|
@ -38,12 +42,12 @@ mod encryption {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature="tls"))]
|
||||
#[cfg(not(feature = "tls"))]
|
||||
mod encryption {
|
||||
use std::net::TcpStream;
|
||||
|
||||
use stream::Mode;
|
||||
use error::{Error, Result};
|
||||
use stream::Mode;
|
||||
|
||||
/// TLS support is nod compiled in, this is just standard `TcpStream`.
|
||||
pub type AutoStream = TcpStream;
|
||||
|
@ -56,15 +60,14 @@ mod encryption {
|
|||
}
|
||||
}
|
||||
|
||||
pub use self::encryption::AutoStream;
|
||||
use self::encryption::wrap_stream;
|
||||
pub use self::encryption::AutoStream;
|
||||
|
||||
use protocol::WebSocket;
|
||||
use handshake::HandshakeError;
|
||||
use handshake::client::{ClientHandshake, Request};
|
||||
use stream::{NoDelay, Mode};
|
||||
use error::{Error, Result};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::handshake::client::{ClientHandshake, Request};
|
||||
use crate::handshake::HandshakeError;
|
||||
use crate::protocol::WebSocket;
|
||||
use crate::stream::{Mode, NoDelay};
|
||||
|
||||
/// Connect to the given WebSocket in blocking mode.
|
||||
///
|
||||
|
@ -83,13 +86,17 @@ use error::{Error, Result};
|
|||
/// `connect` since it's the only function that uses native_tls.
|
||||
pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
|
||||
request: Req,
|
||||
config: Option<WebSocketConfig>
|
||||
config: Option<WebSocketConfig>,
|
||||
) -> Result<(WebSocket<AutoStream>, Response)> {
|
||||
let request: Request = request.into();
|
||||
let mode = url_mode(&request.url)?;
|
||||
let host = request.url.host()
|
||||
let host = request
|
||||
.url
|
||||
.host()
|
||||
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
|
||||
let port = request.url.port_or_known_default()
|
||||
let port = request
|
||||
.url
|
||||
.port_or_known_default()
|
||||
.ok_or_else(|| Error::Url("No port number in the URL".into()))?;
|
||||
let addrs;
|
||||
let addr;
|
||||
|
@ -109,11 +116,10 @@ pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
|
|||
};
|
||||
let mut stream = connect_to_some(addrs, &request.url, mode)?;
|
||||
NoDelay::set_nodelay(&mut stream, true)?;
|
||||
client_with_config(request, stream, config)
|
||||
.map_err(|e| match e {
|
||||
HandshakeError::Failure(f) => f,
|
||||
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
|
||||
})
|
||||
client_with_config(request, stream, config).map_err(|e| match e {
|
||||
HandshakeError::Failure(f) => f,
|
||||
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
|
||||
})
|
||||
}
|
||||
|
||||
/// Connect to the given WebSocket in blocking mode.
|
||||
|
@ -128,19 +134,21 @@ pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
|
|||
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
|
||||
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
|
||||
/// `connect` since it's the only function that uses native_tls.
|
||||
pub fn connect<'t, Req: Into<Request<'t>>>(request: Req)
|
||||
-> Result<(WebSocket<AutoStream>, Response)>
|
||||
{
|
||||
pub fn connect<'t, Req: Into<Request<'t>>>(
|
||||
request: Req,
|
||||
) -> Result<(WebSocket<AutoStream>, Response)> {
|
||||
connect_with_config(request, None)
|
||||
}
|
||||
|
||||
fn connect_to_some(addrs: &[SocketAddr], url: &Url, mode: Mode) -> Result<AutoStream> {
|
||||
let domain = url.host_str().ok_or_else(|| Error::Url("No host name in the URL".into()))?;
|
||||
let domain = url
|
||||
.host_str()
|
||||
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
|
||||
for addr in addrs {
|
||||
debug!("Trying to contact {} at {}...", url, addr);
|
||||
if let Ok(raw_stream) = TcpStream::connect(addr) {
|
||||
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
|
||||
return Ok(stream)
|
||||
return Ok(stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -155,7 +163,7 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
|
|||
match url.scheme() {
|
||||
"ws" => Ok(Mode::Plain),
|
||||
"wss" => Ok(Mode::Tls),
|
||||
_ => Err(Error::Url("URL scheme not supported".into()))
|
||||
_ => Err(Error::Url("URL scheme not supported".into())),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -182,8 +190,10 @@ where
|
|||
/// Use this function if you need a nonblocking handshake support or if you
|
||||
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
|
||||
/// Any stream supporting `Read + Write` will do.
|
||||
pub fn client<'t, Stream, Req>(request: Req, stream: Stream)
|
||||
-> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
|
||||
pub fn client<'t, Stream, Req>(
|
||||
request: Req,
|
||||
stream: Stream,
|
||||
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
|
||||
where
|
||||
Stream: Read + Write,
|
||||
Req: Into<Request<'t>>,
|
||||
|
|
12
src/error.rs
12
src/error.rs
|
@ -11,9 +11,9 @@ use std::string;
|
|||
|
||||
use httparse;
|
||||
|
||||
use protocol::Message;
|
||||
use crate::protocol::Message;
|
||||
|
||||
#[cfg(feature="tls")]
|
||||
#[cfg(feature = "tls")]
|
||||
pub mod tls {
|
||||
//! TLS error wrapper module, feature-gated.
|
||||
pub use native_tls::Error;
|
||||
|
@ -41,7 +41,7 @@ pub enum Error {
|
|||
AlreadyClosed,
|
||||
/// Input-output error
|
||||
Io(io::Error),
|
||||
#[cfg(feature="tls")]
|
||||
#[cfg(feature = "tls")]
|
||||
/// TLS error
|
||||
Tls(tls::Error),
|
||||
/// Buffer capacity exhausted
|
||||
|
@ -64,7 +64,7 @@ impl fmt::Display for Error {
|
|||
Error::ConnectionClosed => write!(f, "Connection closed normally"),
|
||||
Error::AlreadyClosed => write!(f, "Trying to work with closed connection"),
|
||||
Error::Io(ref err) => write!(f, "IO error: {}", err),
|
||||
#[cfg(feature="tls")]
|
||||
#[cfg(feature = "tls")]
|
||||
Error::Tls(ref err) => write!(f, "TLS error: {}", err),
|
||||
Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg),
|
||||
Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg),
|
||||
|
@ -82,7 +82,7 @@ impl ErrorTrait for Error {
|
|||
Error::ConnectionClosed => "A close handshake is performed",
|
||||
Error::AlreadyClosed => "Trying to read or write after getting close notification",
|
||||
Error::Io(ref err) => err.description(),
|
||||
#[cfg(feature="tls")]
|
||||
#[cfg(feature = "tls")]
|
||||
Error::Tls(ref err) => err.description(),
|
||||
Error::Capacity(ref msg) => msg.borrow(),
|
||||
Error::Protocol(ref msg) => msg.borrow(),
|
||||
|
@ -112,7 +112,7 @@ impl From<string::FromUtf8Error> for Error {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature="tls")]
|
||||
#[cfg(feature = "tls")]
|
||||
impl From<tls::Error> for Error {
|
||||
fn from(err: tls::Error) -> Self {
|
||||
Error::Tls(err)
|
||||
|
|
|
@ -4,17 +4,15 @@ use std::borrow::Cow;
|
|||
use std::io::{Read, Write};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use base64;
|
||||
use httparse::Status;
|
||||
use httparse;
|
||||
use rand;
|
||||
use log::*;
|
||||
use url::Url;
|
||||
|
||||
use error::{Error, Result};
|
||||
use protocol::{WebSocket, WebSocketConfig, Role};
|
||||
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
|
||||
use super::headers::{FromHttparse, Headers, MAX_HEADERS};
|
||||
use super::machine::{HandshakeMachine, StageResult, TryParse};
|
||||
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
|
||||
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::protocol::{Role, WebSocket, WebSocketConfig};
|
||||
|
||||
/// Client request.
|
||||
#[derive(Debug)]
|
||||
|
@ -80,26 +78,32 @@ impl<S: Read + Write> ClientHandshake<S> {
|
|||
pub fn start(
|
||||
stream: S,
|
||||
request: Request,
|
||||
config: Option<WebSocketConfig>
|
||||
config: Option<WebSocketConfig>,
|
||||
) -> MidHandshake<Self> {
|
||||
let key = generate_key();
|
||||
|
||||
let machine = {
|
||||
let mut req = Vec::new();
|
||||
write!(req, "\
|
||||
GET {path} HTTP/1.1\r\n\
|
||||
Host: {host}\r\n\
|
||||
Connection: upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Sec-WebSocket-Version: 13\r\n\
|
||||
Sec-WebSocket-Key: {key}\r\n",
|
||||
host = request.get_host(), path = request.get_path(), key = key).unwrap();
|
||||
write!(
|
||||
req,
|
||||
"\
|
||||
GET {path} HTTP/1.1\r\n\
|
||||
Host: {host}\r\n\
|
||||
Connection: upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Sec-WebSocket-Version: 13\r\n\
|
||||
Sec-WebSocket-Key: {key}\r\n",
|
||||
host = request.get_host(),
|
||||
path = request.get_path(),
|
||||
key = key
|
||||
)
|
||||
.unwrap();
|
||||
if let Some(eh) = request.extra_headers {
|
||||
for (k, v) in eh {
|
||||
write!(req, "{}: {}\r\n", k, v).unwrap();
|
||||
writeln!(req, "{}: {}\r", k, v).unwrap();
|
||||
}
|
||||
}
|
||||
write!(req, "\r\n").unwrap();
|
||||
writeln!(req, "\r").unwrap();
|
||||
HandshakeMachine::start_write(stream, req)
|
||||
};
|
||||
|
||||
|
@ -113,7 +117,10 @@ impl<S: Read + Write> ClientHandshake<S> {
|
|||
};
|
||||
|
||||
trace!("Client handshake initiated.");
|
||||
MidHandshake { role: client, machine }
|
||||
MidHandshake {
|
||||
role: client,
|
||||
machine,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -121,22 +128,23 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
|
|||
type IncomingData = Response;
|
||||
type InternalStream = S;
|
||||
type FinalResult = (WebSocket<S>, Response);
|
||||
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
|
||||
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>
|
||||
{
|
||||
fn stage_finished(
|
||||
&mut self,
|
||||
finish: StageResult<Self::IncomingData, Self::InternalStream>,
|
||||
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
|
||||
Ok(match finish {
|
||||
StageResult::DoneWriting(stream) => {
|
||||
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
|
||||
}
|
||||
StageResult::DoneReading { stream, result, tail, } => {
|
||||
StageResult::DoneReading {
|
||||
stream,
|
||||
result,
|
||||
tail,
|
||||
} => {
|
||||
self.verify_data.verify_response(&result)?;
|
||||
debug!("Client handshake done.");
|
||||
let websocket = WebSocket::from_partially_read(
|
||||
stream,
|
||||
tail,
|
||||
Role::Client,
|
||||
self.config.clone(),
|
||||
);
|
||||
let websocket =
|
||||
WebSocket::from_partially_read(stream, tail, Role::Client, self.config);
|
||||
ProcessingResult::Done((websocket, result))
|
||||
}
|
||||
})
|
||||
|
@ -161,22 +169,37 @@ impl VerifyData {
|
|||
// header field contains a value that is not an ASCII case-
|
||||
// insensitive match for the value "websocket", the client MUST
|
||||
// _Fail the WebSocket Connection_. (RFC 6455)
|
||||
if !response.headers.header_is_ignore_case("Upgrade", "websocket") {
|
||||
return Err(Error::Protocol("No \"Upgrade: websocket\" in server reply".into()));
|
||||
if !response
|
||||
.headers
|
||||
.header_is_ignore_case("Upgrade", "websocket")
|
||||
{
|
||||
return Err(Error::Protocol(
|
||||
"No \"Upgrade: websocket\" in server reply".into(),
|
||||
));
|
||||
}
|
||||
// 3. If the response lacks a |Connection| header field or the
|
||||
// |Connection| header field doesn't contain a token that is an
|
||||
// ASCII case-insensitive match for the value "Upgrade", the client
|
||||
// MUST _Fail the WebSocket Connection_. (RFC 6455)
|
||||
if !response.headers.header_is_ignore_case("Connection", "Upgrade") {
|
||||
return Err(Error::Protocol("No \"Connection: upgrade\" in server reply".into()));
|
||||
if !response
|
||||
.headers
|
||||
.header_is_ignore_case("Connection", "Upgrade")
|
||||
{
|
||||
return Err(Error::Protocol(
|
||||
"No \"Connection: upgrade\" in server reply".into(),
|
||||
));
|
||||
}
|
||||
// 4. If the response lacks a |Sec-WebSocket-Accept| header field or
|
||||
// the |Sec-WebSocket-Accept| contains a value other than the
|
||||
// base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket
|
||||
// Connection_. (RFC 6455)
|
||||
if !response.headers.header_is("Sec-WebSocket-Accept", &self.accept_key) {
|
||||
return Err(Error::Protocol("Key mismatch in Sec-WebSocket-Accept".into()));
|
||||
if !response
|
||||
.headers
|
||||
.header_is("Sec-WebSocket-Accept", &self.accept_key)
|
||||
{
|
||||
return Err(Error::Protocol(
|
||||
"Key mismatch in Sec-WebSocket-Accept".into(),
|
||||
));
|
||||
}
|
||||
// 5. If the response includes a |Sec-WebSocket-Extensions| header
|
||||
// field and this header field indicates the use of an extension
|
||||
|
@ -219,7 +242,9 @@ impl TryParse for Response {
|
|||
impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
|
||||
fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result<Self> {
|
||||
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
|
||||
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
|
||||
return Err(Error::Protocol(
|
||||
"HTTP version should be 1.1 or higher".into(),
|
||||
));
|
||||
}
|
||||
Ok(Response {
|
||||
code: raw.code.expect("Bug: no HTTP response code"),
|
||||
|
@ -238,8 +263,8 @@ fn generate_key() -> String {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{Response, generate_key};
|
||||
use super::super::machine::TryParse;
|
||||
use super::{generate_key, Response};
|
||||
|
||||
#[test]
|
||||
fn random_keys() {
|
||||
|
@ -262,6 +287,9 @@ mod tests {
|
|||
const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
|
||||
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap();
|
||||
assert_eq!(resp.code, 200);
|
||||
assert_eq!(resp.headers.find_first("Content-Type"), Some(&b"text/html"[..]));
|
||||
assert_eq!(
|
||||
resp.headers.find_first("Content-Type"),
|
||||
Some(&b"text/html"[..])
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
//! HTTP Request and response header handling.
|
||||
|
||||
use std::str::from_utf8;
|
||||
use std::slice;
|
||||
use std::str::from_utf8;
|
||||
|
||||
use httparse;
|
||||
use httparse::Status;
|
||||
|
||||
use error::Result;
|
||||
use super::machine::TryParse;
|
||||
use crate::error::Result;
|
||||
|
||||
/// Limit for the number of header lines.
|
||||
pub const MAX_HEADERS: usize = 124;
|
||||
|
@ -19,7 +19,6 @@ pub struct Headers {
|
|||
}
|
||||
|
||||
impl Headers {
|
||||
|
||||
/// Get first header with the given name, if any.
|
||||
pub fn find_first(&self, name: &str) -> Option<&[u8]> {
|
||||
self.find(name).next()
|
||||
|
@ -29,7 +28,7 @@ impl Headers {
|
|||
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> {
|
||||
HeadersIter {
|
||||
name,
|
||||
iter: self.data.iter()
|
||||
iter: self.data.iter(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -42,7 +41,8 @@ impl Headers {
|
|||
|
||||
/// Check if the given header has the given value (case-insensitive).
|
||||
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool {
|
||||
self.find_first(name).ok_or(())
|
||||
self.find_first(name)
|
||||
.ok_or(())
|
||||
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ()))
|
||||
.map(|val| val.eq_ignore_ascii_case(value))
|
||||
.unwrap_or(false)
|
||||
|
@ -52,7 +52,6 @@ impl Headers {
|
|||
pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> {
|
||||
self.data.iter()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// The iterator over headers.
|
||||
|
@ -67,14 +66,13 @@ impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> {
|
|||
fn next(&mut self) -> Option<Self::Item> {
|
||||
while let Some(&(ref name, ref value)) = self.iter.next() {
|
||||
if name.eq_ignore_ascii_case(self.name) {
|
||||
return Some(value)
|
||||
return Some(value);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Trait to convert raw objects into HTTP parseables.
|
||||
pub trait FromHttparse<T>: Sized {
|
||||
/// Convert raw object into parsed HTTP headers.
|
||||
|
@ -94,9 +92,10 @@ impl TryParse for Headers {
|
|||
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
|
||||
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
|
||||
Ok(Headers {
|
||||
data: raw.iter()
|
||||
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
|
||||
.collect(),
|
||||
data: raw
|
||||
.iter()
|
||||
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -104,13 +103,12 @@ impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::Headers;
|
||||
use super::super::machine::TryParse;
|
||||
use super::Headers;
|
||||
|
||||
#[test]
|
||||
fn headers() {
|
||||
const DATA: &'static [u8] =
|
||||
b"Host: foo.com\r\n\
|
||||
const DATA: &'static [u8] = b"Host: foo.com\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
\r\n";
|
||||
|
@ -126,8 +124,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn headers_iter() {
|
||||
const DATA: &'static [u8] =
|
||||
b"Host: foo.com\r\n\
|
||||
const DATA: &'static [u8] = b"Host: foo.com\r\n\
|
||||
Sec-WebSocket-Extensions: permessage-deflate\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\
|
||||
|
@ -142,12 +139,10 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn headers_incomplete() {
|
||||
const DATA: &'static [u8] =
|
||||
b"Host: foo.com\r\n\
|
||||
const DATA: &'static [u8] = b"Host: foo.com\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Upgrade: websocket\r\n";
|
||||
let hdr = Headers::try_parse(DATA).unwrap();
|
||||
assert!(hdr.is_none());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
use std::io::{Cursor, Read, Write};
|
||||
use bytes::Buf;
|
||||
use log::*;
|
||||
use std::io::{Cursor, Read, Write};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::util::NonBlockingResult;
|
||||
use input_buffer::{InputBuffer, MIN_READ};
|
||||
use error::{Error, Result};
|
||||
use util::NonBlockingResult;
|
||||
|
||||
/// A generic handshake state machine.
|
||||
#[derive(Debug)]
|
||||
|
@ -43,16 +44,16 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
|
|||
trace!("Doing handshake round.");
|
||||
match self.state {
|
||||
HandshakeState::Reading(mut buf) => {
|
||||
let read = buf.prepare_reserve(MIN_READ)
|
||||
let read = buf
|
||||
.prepare_reserve(MIN_READ)
|
||||
.with_limit(usize::max_value()) // TODO limit size
|
||||
.map_err(|_| Error::Capacity("Header too long".into()))?
|
||||
.read_from(&mut self.stream).no_block()?;
|
||||
.read_from(&mut self.stream)
|
||||
.no_block()?;
|
||||
match read {
|
||||
Some(0) => {
|
||||
Err(Error::Protocol("Handshake not finished".into()))
|
||||
}
|
||||
Some(_) => {
|
||||
Ok(if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
|
||||
Some(0) => Err(Error::Protocol("Handshake not finished".into())),
|
||||
Some(_) => Ok(
|
||||
if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
|
||||
buf.advance(size);
|
||||
RoundResult::StageFinished(StageResult::DoneReading {
|
||||
result: obj,
|
||||
|
@ -64,14 +65,12 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
|
|||
state: HandshakeState::Reading(buf),
|
||||
..self
|
||||
})
|
||||
})
|
||||
}
|
||||
None => {
|
||||
Ok(RoundResult::WouldBlock(HandshakeMachine {
|
||||
state: HandshakeState::Reading(buf),
|
||||
..self
|
||||
}))
|
||||
}
|
||||
},
|
||||
),
|
||||
None => Ok(RoundResult::WouldBlock(HandshakeMachine {
|
||||
state: HandshakeState::Reading(buf),
|
||||
..self
|
||||
})),
|
||||
}
|
||||
}
|
||||
HandshakeState::Writing(mut buf) => {
|
||||
|
@ -113,7 +112,11 @@ pub enum RoundResult<Obj, Stream> {
|
|||
#[derive(Debug)]
|
||||
pub enum StageResult<Obj, Stream> {
|
||||
/// Reading round finished.
|
||||
DoneReading { result: Obj, stream: Stream, tail: Vec<u8> },
|
||||
DoneReading {
|
||||
result: Obj,
|
||||
stream: Stream,
|
||||
tail: Vec<u8>,
|
||||
},
|
||||
/// Writing round finished.
|
||||
DoneWriting(Stream),
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
//! WebSocket handshake control.
|
||||
|
||||
pub mod headers;
|
||||
pub mod client;
|
||||
pub mod headers;
|
||||
pub mod server;
|
||||
|
||||
mod machine;
|
||||
|
@ -11,10 +11,10 @@ use std::fmt;
|
|||
use std::io::{Read, Write};
|
||||
|
||||
use base64;
|
||||
use sha1::{Sha1, Digest};
|
||||
use sha1::{Digest, Sha1};
|
||||
|
||||
use error::Error;
|
||||
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
|
||||
use crate::error::Error;
|
||||
|
||||
/// A WebSocket handshake.
|
||||
#[derive(Debug)]
|
||||
|
@ -30,15 +30,16 @@ impl<Role: HandshakeRole> MidHandshake<Role> {
|
|||
loop {
|
||||
mach = match mach.single_round()? {
|
||||
RoundResult::WouldBlock(m) => {
|
||||
return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
|
||||
return Err(HandshakeError::Interrupted(MidHandshake {
|
||||
machine: m,
|
||||
..self
|
||||
}))
|
||||
}
|
||||
RoundResult::Incomplete(m) => m,
|
||||
RoundResult::StageFinished(s) => {
|
||||
match self.role.stage_finished(s)? {
|
||||
ProcessingResult::Continue(m) => m,
|
||||
ProcessingResult::Done(result) => return Ok(result),
|
||||
}
|
||||
}
|
||||
RoundResult::StageFinished(s) => match self.role.stage_finished(s)? {
|
||||
ProcessingResult::Continue(m) => m,
|
||||
ProcessingResult::Done(result) => return Ok(result),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -94,8 +95,10 @@ pub trait HandshakeRole {
|
|||
#[doc(hidden)]
|
||||
type FinalResult;
|
||||
#[doc(hidden)]
|
||||
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
|
||||
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
|
||||
fn stage_finished(
|
||||
&mut self,
|
||||
finish: StageResult<Self::IncomingData, Self::InternalStream>,
|
||||
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
|
||||
}
|
||||
|
||||
/// Stage processing result.
|
||||
|
@ -124,8 +127,9 @@ mod tests {
|
|||
#[test]
|
||||
fn key_conversion() {
|
||||
// example from RFC 6455
|
||||
assert_eq!(convert_key(b"dGhlIHNhbXBsZSBub25jZQ==").unwrap(),
|
||||
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
|
||||
assert_eq!(
|
||||
convert_key(b"dGhlIHNhbXBsZSBub25jZQ==").unwrap(),
|
||||
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -5,15 +5,15 @@ use std::io::{Read, Write};
|
|||
use std::marker::PhantomData;
|
||||
use std::result::Result as StdResult;
|
||||
|
||||
use httparse;
|
||||
use httparse::Status;
|
||||
use http::StatusCode;
|
||||
use httparse::Status;
|
||||
use log::*;
|
||||
|
||||
use error::{Error, Result};
|
||||
use protocol::{WebSocket, WebSocketConfig, Role};
|
||||
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
|
||||
use super::headers::{FromHttparse, Headers, MAX_HEADERS};
|
||||
use super::machine::{HandshakeMachine, StageResult, TryParse};
|
||||
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
|
||||
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::protocol::{Role, WebSocket, WebSocketConfig};
|
||||
|
||||
/// Request from the client.
|
||||
#[derive(Debug)]
|
||||
|
@ -27,14 +27,16 @@ pub struct Request {
|
|||
impl Request {
|
||||
/// Reply to the response.
|
||||
pub fn reply(&self, extra_headers: Option<Vec<(String, String)>>) -> Result<Vec<u8>> {
|
||||
let key = self.headers.find_first("Sec-WebSocket-Key")
|
||||
let key = self
|
||||
.headers
|
||||
.find_first("Sec-WebSocket-Key")
|
||||
.ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?;
|
||||
let mut reply = format!(
|
||||
"\
|
||||
HTTP/1.1 101 Switching Protocols\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Sec-WebSocket-Accept: {}\r\n",
|
||||
HTTP/1.1 101 Switching Protocols\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Sec-WebSocket-Accept: {}\r\n",
|
||||
convert_key(key)?
|
||||
);
|
||||
add_headers(&mut reply, extra_headers);
|
||||
|
@ -45,13 +47,12 @@ impl Request {
|
|||
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<ExtraHeaders>) {
|
||||
if let Some(eh) = extra_headers {
|
||||
for (k, v) in eh {
|
||||
write!(reply, "{}: {}\r\n", k, v).unwrap();
|
||||
writeln!(reply, "{}: {}\r", k, v).unwrap();
|
||||
}
|
||||
}
|
||||
write!(reply, "\r\n").unwrap();
|
||||
writeln!(reply, "\r").unwrap();
|
||||
}
|
||||
|
||||
|
||||
impl TryParse for Request {
|
||||
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
|
||||
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
|
||||
|
@ -69,11 +70,13 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
|
|||
return Err(Error::Protocol("Method is not GET".into()));
|
||||
}
|
||||
if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
|
||||
return Err(Error::Protocol("HTTP version should be 1.1 or higher".into()));
|
||||
return Err(Error::Protocol(
|
||||
"HTTP version should be 1.1 or higher".into(),
|
||||
));
|
||||
}
|
||||
Ok(Request {
|
||||
path: raw.path.expect("Bug: no path in header").into(),
|
||||
headers: Headers::from_httparse(raw.headers)?
|
||||
headers: Headers::from_httparse(raw.headers)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -115,7 +118,10 @@ pub trait Callback: Sized {
|
|||
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>;
|
||||
}
|
||||
|
||||
impl<F> Callback for F where F: FnOnce(&Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
|
||||
impl<F> Callback for F
|
||||
where
|
||||
F: FnOnce(&Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>,
|
||||
{
|
||||
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
|
||||
self(request)
|
||||
}
|
||||
|
@ -160,7 +166,7 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
|
|||
callback: Some(callback),
|
||||
config,
|
||||
error_code: None,
|
||||
_marker: PhantomData
|
||||
_marker: PhantomData,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -171,13 +177,18 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
|
|||
type InternalStream = S;
|
||||
type FinalResult = WebSocket<S>;
|
||||
|
||||
fn stage_finished(&mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>)
|
||||
-> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>
|
||||
{
|
||||
fn stage_finished(
|
||||
&mut self,
|
||||
finish: StageResult<Self::IncomingData, Self::InternalStream>,
|
||||
) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
|
||||
Ok(match finish {
|
||||
StageResult::DoneReading { stream, result, tail } => {
|
||||
StageResult::DoneReading {
|
||||
stream,
|
||||
result,
|
||||
tail,
|
||||
} => {
|
||||
if !tail.is_empty() {
|
||||
return Err(Error::Protocol("Junk after client request".into()))
|
||||
return Err(Error::Protocol("Junk after client request".into()));
|
||||
}
|
||||
|
||||
let callback_result = if let Some(callback) = self.callback.take() {
|
||||
|
@ -192,8 +203,12 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
|
|||
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
|
||||
}
|
||||
|
||||
Err(ErrorResponse { error_code, headers, body }) => {
|
||||
self.error_code= Some(error_code.as_u16());
|
||||
Err(ErrorResponse {
|
||||
error_code,
|
||||
headers,
|
||||
body,
|
||||
}) => {
|
||||
self.error_code = Some(error_code.as_u16());
|
||||
let mut response = format!(
|
||||
"HTTP/1.1 {} {}\r\n",
|
||||
error_code.as_str(),
|
||||
|
@ -214,11 +229,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
|
|||
return Err(Error::Http(err));
|
||||
} else {
|
||||
debug!("Server handshake done.");
|
||||
let websocket = WebSocket::from_raw_socket(
|
||||
stream,
|
||||
Role::Server,
|
||||
self.config.clone(),
|
||||
);
|
||||
let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
|
||||
ProcessingResult::Done(websocket)
|
||||
}
|
||||
}
|
||||
|
@ -228,9 +239,9 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::Request;
|
||||
use super::super::machine::TryParse;
|
||||
use super::super::client::Response;
|
||||
use super::super::machine::TryParse;
|
||||
use super::Request;
|
||||
|
||||
#[test]
|
||||
fn request_parsing() {
|
||||
|
@ -253,13 +264,19 @@ mod tests {
|
|||
let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
|
||||
let _ = req.reply(None).unwrap();
|
||||
|
||||
let extra_headers = Some(vec![(String::from("MyCustomHeader"),
|
||||
String::from("MyCustomValue")),
|
||||
(String::from("MyVersion"),
|
||||
String::from("LOL"))]);
|
||||
let extra_headers = Some(vec![
|
||||
(
|
||||
String::from("MyCustomHeader"),
|
||||
String::from("MyCustomValue"),
|
||||
),
|
||||
(String::from("MyVersion"), String::from("LOL")),
|
||||
]);
|
||||
let reply = req.reply(extra_headers).unwrap();
|
||||
let (_, req) = Response::try_parse(&reply).unwrap().unwrap();
|
||||
assert_eq!(req.headers.find_first("MyCustomHeader"), Some(b"MyCustomValue".as_ref()));
|
||||
assert_eq!(
|
||||
req.headers.find_first("MyCustomHeader"),
|
||||
Some(b"MyCustomValue".as_ref())
|
||||
);
|
||||
assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref()));
|
||||
}
|
||||
}
|
||||
|
|
40
src/lib.rs
40
src/lib.rs
|
@ -3,39 +3,29 @@
|
|||
missing_docs,
|
||||
missing_copy_implementations,
|
||||
missing_debug_implementations,
|
||||
trivial_casts, trivial_numeric_casts,
|
||||
trivial_casts,
|
||||
trivial_numeric_casts,
|
||||
unstable_features,
|
||||
unused_must_use,
|
||||
unused_mut,
|
||||
unused_imports,
|
||||
unused_import_braces)]
|
||||
unused_import_braces
|
||||
)]
|
||||
|
||||
#[macro_use] extern crate log;
|
||||
extern crate base64;
|
||||
extern crate byteorder;
|
||||
extern crate bytes;
|
||||
extern crate httparse;
|
||||
extern crate input_buffer;
|
||||
extern crate rand;
|
||||
extern crate sha1;
|
||||
extern crate url;
|
||||
extern crate utf8;
|
||||
#[cfg(feature="tls")] extern crate native_tls;
|
||||
pub use http;
|
||||
|
||||
pub extern crate http;
|
||||
|
||||
pub mod error;
|
||||
pub mod protocol;
|
||||
pub mod client;
|
||||
pub mod server;
|
||||
pub mod error;
|
||||
pub mod handshake;
|
||||
pub mod protocol;
|
||||
pub mod server;
|
||||
pub mod stream;
|
||||
pub mod util;
|
||||
|
||||
pub use client::{connect, client};
|
||||
pub use server::{accept, accept_hdr};
|
||||
pub use error::{Error, Result};
|
||||
pub use protocol::{WebSocket, Message};
|
||||
pub use handshake::HandshakeError;
|
||||
pub use handshake::client::ClientHandshake;
|
||||
pub use handshake::server::ServerHandshake;
|
||||
pub use crate::client::{client, connect};
|
||||
pub use crate::error::{Error, Result};
|
||||
pub use crate::handshake::client::ClientHandshake;
|
||||
pub use crate::handshake::server::ServerHandshake;
|
||||
pub use crate::handshake::HandshakeError;
|
||||
pub use crate::protocol::{Message, WebSocket};
|
||||
pub use crate::server::{accept, accept_hdr};
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
//! Various codes defined in RFC 6455.
|
||||
|
||||
use std::convert::{From, Into};
|
||||
use std::fmt;
|
||||
use std::convert::{Into, From};
|
||||
|
||||
/// WebSocket message opcode as in RFC 6455.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
|
@ -42,8 +42,8 @@ impl fmt::Display for Data {
|
|||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
Data::Continue => write!(f, "CONTINUE"),
|
||||
Data::Text => write!(f, "TEXT"),
|
||||
Data::Binary => write!(f, "BINARY"),
|
||||
Data::Text => write!(f, "TEXT"),
|
||||
Data::Binary => write!(f, "BINARY"),
|
||||
Data::Reserved(x) => write!(f, "RESERVED_DATA_{}", x),
|
||||
}
|
||||
}
|
||||
|
@ -53,8 +53,8 @@ impl fmt::Display for Control {
|
|||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
Control::Close => write!(f, "CLOSE"),
|
||||
Control::Ping => write!(f, "PING"),
|
||||
Control::Pong => write!(f, "PONG"),
|
||||
Control::Ping => write!(f, "PING"),
|
||||
Control::Pong => write!(f, "PONG"),
|
||||
Control::Reserved(x) => write!(f, "RESERVED_CONTROL_{}", x),
|
||||
}
|
||||
}
|
||||
|
@ -71,18 +71,18 @@ impl fmt::Display for OpCode {
|
|||
|
||||
impl Into<u8> for OpCode {
|
||||
fn into(self) -> u8 {
|
||||
use self::Data::{Continue, Text, Binary};
|
||||
use self::Control::{Close, Ping, Pong};
|
||||
use self::Data::{Binary, Continue, Text};
|
||||
use self::OpCode::*;
|
||||
match self {
|
||||
Data(Continue) => 0,
|
||||
Data(Text) => 1,
|
||||
Data(Binary) => 2,
|
||||
Data(Text) => 1,
|
||||
Data(Binary) => 2,
|
||||
Data(self::Data::Reserved(i)) => i,
|
||||
|
||||
Control(Close) => 8,
|
||||
Control(Ping) => 9,
|
||||
Control(Pong) => 10,
|
||||
Control(Ping) => 9,
|
||||
Control(Pong) => 10,
|
||||
Control(self::Control::Reserved(i)) => i,
|
||||
}
|
||||
}
|
||||
|
@ -90,19 +90,19 @@ impl Into<u8> for OpCode {
|
|||
|
||||
impl From<u8> for OpCode {
|
||||
fn from(byte: u8) -> OpCode {
|
||||
use self::Data::{Continue, Text, Binary};
|
||||
use self::Control::{Close, Ping, Pong};
|
||||
use self::Data::{Binary, Continue, Text};
|
||||
use self::OpCode::*;
|
||||
match byte {
|
||||
0 => Data(Continue),
|
||||
1 => Data(Text),
|
||||
2 => Data(Binary),
|
||||
i @ 3 ... 7 => Data(self::Data::Reserved(i)),
|
||||
8 => Control(Close),
|
||||
9 => Control(Ping),
|
||||
10 => Control(Pong),
|
||||
i @ 11 ... 15 => Control(self::Control::Reserved(i)),
|
||||
_ => panic!("Bug: OpCode out of range"),
|
||||
0 => Data(Continue),
|
||||
1 => Data(Text),
|
||||
2 => Data(Binary),
|
||||
i @ 3..=7 => Data(self::Data::Reserved(i)),
|
||||
8 => Control(Close),
|
||||
9 => Control(Ping),
|
||||
10 => Control(Pong),
|
||||
i @ 11..=15 => Control(self::Control::Reserved(i)),
|
||||
_ => panic!("Bug: OpCode out of range"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -183,13 +183,13 @@ pub enum CloseCode {
|
|||
|
||||
impl CloseCode {
|
||||
/// Check if this CloseCode is allowed.
|
||||
pub fn is_allowed(&self) -> bool {
|
||||
match *self {
|
||||
Bad(_) => false,
|
||||
pub fn is_allowed(self) -> bool {
|
||||
match self {
|
||||
Bad(_) => false,
|
||||
Reserved(_) => false,
|
||||
Status => false,
|
||||
Abnormal => false,
|
||||
Tls => false,
|
||||
Status => false,
|
||||
Abnormal => false,
|
||||
Tls => false,
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
@ -205,24 +205,24 @@ impl fmt::Display for CloseCode {
|
|||
impl<'t> Into<u16> for &'t CloseCode {
|
||||
fn into(self) -> u16 {
|
||||
match *self {
|
||||
Normal => 1000,
|
||||
Away => 1001,
|
||||
Protocol => 1002,
|
||||
Unsupported => 1003,
|
||||
Status => 1005,
|
||||
Abnormal => 1006,
|
||||
Invalid => 1007,
|
||||
Policy => 1008,
|
||||
Size => 1009,
|
||||
Extension => 1010,
|
||||
Error => 1011,
|
||||
Restart => 1012,
|
||||
Again => 1013,
|
||||
Tls => 1015,
|
||||
Reserved(code) => code,
|
||||
Iana(code) => code,
|
||||
Library(code) => code,
|
||||
Bad(code) => code,
|
||||
Normal => 1000,
|
||||
Away => 1001,
|
||||
Protocol => 1002,
|
||||
Unsupported => 1003,
|
||||
Status => 1005,
|
||||
Abnormal => 1006,
|
||||
Invalid => 1007,
|
||||
Policy => 1008,
|
||||
Size => 1009,
|
||||
Extension => 1010,
|
||||
Error => 1011,
|
||||
Restart => 1012,
|
||||
Again => 1013,
|
||||
Tls => 1015,
|
||||
Reserved(code) => code,
|
||||
Iana(code) => code,
|
||||
Library(code) => code,
|
||||
Bad(code) => code,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -250,11 +250,11 @@ impl From<u16> for CloseCode {
|
|||
1012 => Restart,
|
||||
1013 => Again,
|
||||
1015 => Tls,
|
||||
1...999 => Bad(code),
|
||||
1000...2999 => Reserved(code),
|
||||
3000...3999 => Iana(code),
|
||||
4000...4999 => Library(code),
|
||||
_ => Bad(code)
|
||||
1..=999 => Bad(code),
|
||||
1016..=2999 => Reserved(code),
|
||||
3000..=3999 => Iana(code),
|
||||
4000..=4999 => Library(code),
|
||||
_ => Bad(code),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
use std::fmt;
|
||||
use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt};
|
||||
use log::*;
|
||||
use std::borrow::Cow;
|
||||
use std::io::{Cursor, Read, Write, ErrorKind};
|
||||
use std::default::Default;
|
||||
use std::string::{String, FromUtf8Error};
|
||||
use std::fmt;
|
||||
use std::io::{Cursor, ErrorKind, Read, Write};
|
||||
use std::result::Result as StdResult;
|
||||
use byteorder::{ByteOrder, ReadBytesExt, WriteBytesExt, NetworkEndian};
|
||||
use std::string::{FromUtf8Error, String};
|
||||
|
||||
use error::{Error, Result};
|
||||
use super::coding::{OpCode, Control, Data, CloseCode};
|
||||
use super::mask::{generate_mask, apply_mask};
|
||||
use super::coding::{CloseCode, Control, Data, OpCode};
|
||||
use super::mask::{apply_mask, generate_mask};
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
/// A struct representing the close command.
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
|
@ -77,15 +78,13 @@ impl FrameHeader {
|
|||
cursor.set_position(initial);
|
||||
ret
|
||||
}
|
||||
ret => ret
|
||||
ret => ret,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the size of the header formatted with given payload length.
|
||||
pub fn len(&self, length: u64) -> usize {
|
||||
2
|
||||
+ LengthFormat::for_length(length).extra_bytes()
|
||||
+ if self.mask.is_some() { 4 } else { 0 }
|
||||
2 + LengthFormat::for_length(length).extra_bytes() + if self.mask.is_some() { 4 } else { 0 }
|
||||
}
|
||||
|
||||
/// Format a header for given payload size.
|
||||
|
@ -93,19 +92,15 @@ impl FrameHeader {
|
|||
let code: u8 = self.opcode.into();
|
||||
|
||||
let one = {
|
||||
code
|
||||
| if self.is_final { 0x80 } else { 0 }
|
||||
| if self.rsv1 { 0x40 } else { 0 }
|
||||
| if self.rsv2 { 0x20 } else { 0 }
|
||||
| if self.rsv3 { 0x10 } else { 0 }
|
||||
code | if self.is_final { 0x80 } else { 0 }
|
||||
| if self.rsv1 { 0x40 } else { 0 }
|
||||
| if self.rsv2 { 0x20 } else { 0 }
|
||||
| if self.rsv3 { 0x10 } else { 0 }
|
||||
};
|
||||
|
||||
let lenfmt = LengthFormat::for_length(length);
|
||||
|
||||
let two = {
|
||||
lenfmt.length_byte()
|
||||
| if self.mask.is_some() { 0x80 } else { 0 }
|
||||
};
|
||||
let two = { lenfmt.length_byte() | if self.mask.is_some() { 0x80 } else { 0 } };
|
||||
|
||||
output.write_all(&[one, two])?;
|
||||
match lenfmt {
|
||||
|
@ -137,7 +132,7 @@ impl FrameHeader {
|
|||
let (first, second) = {
|
||||
let mut head = [0u8; 2];
|
||||
if cursor.read(&mut head)? != 2 {
|
||||
return Ok(None)
|
||||
return Ok(None);
|
||||
}
|
||||
trace!("Parsed headers {:?}", head);
|
||||
(head[0], head[1])
|
||||
|
@ -169,17 +164,17 @@ impl FrameHeader {
|
|||
Err(err) => {
|
||||
return Err(err.into());
|
||||
}
|
||||
Ok(read) => read
|
||||
Ok(read) => read,
|
||||
}
|
||||
} else {
|
||||
length_byte as u64
|
||||
u64::from(length_byte)
|
||||
}
|
||||
};
|
||||
|
||||
let mask = if masked {
|
||||
let mut mask_bytes = [0u8; 4];
|
||||
if cursor.read(&mut mask_bytes)? != 4 {
|
||||
return Ok(None)
|
||||
return Ok(None);
|
||||
} else {
|
||||
Some(mask_bytes)
|
||||
}
|
||||
|
@ -190,9 +185,11 @@ impl FrameHeader {
|
|||
// Disallow bad opcode
|
||||
match opcode {
|
||||
OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
|
||||
return Err(Error::Protocol(format!("Encountered invalid opcode: {}", first & 0x0F).into()))
|
||||
return Err(Error::Protocol(
|
||||
format!("Encountered invalid opcode: {}", first & 0x0F).into(),
|
||||
))
|
||||
}
|
||||
_ => ()
|
||||
_ => (),
|
||||
}
|
||||
|
||||
let hdr = FrameHeader {
|
||||
|
@ -216,7 +213,6 @@ pub struct Frame {
|
|||
}
|
||||
|
||||
impl Frame {
|
||||
|
||||
/// Get the length of the frame.
|
||||
/// This is the length of the header + the length of the payload.
|
||||
#[inline]
|
||||
|
@ -225,6 +221,12 @@ impl Frame {
|
|||
self.header.len(length as u64) + length
|
||||
}
|
||||
|
||||
/// Check if the frame is empty.
|
||||
#[inline]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
/// Get a reference to the frame's header.
|
||||
#[inline]
|
||||
pub fn header(&self) -> &FrameHeader {
|
||||
|
@ -285,7 +287,7 @@ impl Frame {
|
|||
String::from_utf8(self.payload)
|
||||
}
|
||||
|
||||
/// Consume the frame into a closing frame.
|
||||
/// Consume the frame into a closing frame.
|
||||
#[inline]
|
||||
pub(crate) fn into_close(self) -> Result<Option<CloseFrame<'static>>> {
|
||||
match self.payload.len() {
|
||||
|
@ -296,7 +298,10 @@ impl Frame {
|
|||
let code = NetworkEndian::read_u16(&data[0..2]).into();
|
||||
data.drain(0..2);
|
||||
let text = String::from_utf8(data)?;
|
||||
Ok(Some(CloseFrame { code, reason: text.into() }))
|
||||
Ok(Some(CloseFrame {
|
||||
code,
|
||||
reason: text.into(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -304,16 +309,19 @@ impl Frame {
|
|||
/// Create a new data frame.
|
||||
#[inline]
|
||||
pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
|
||||
debug_assert!(match opcode {
|
||||
OpCode::Data(_) => true,
|
||||
_ => false,
|
||||
}, "Invalid opcode for data frame.");
|
||||
debug_assert!(
|
||||
match opcode {
|
||||
OpCode::Data(_) => true,
|
||||
_ => false,
|
||||
},
|
||||
"Invalid opcode for data frame."
|
||||
);
|
||||
|
||||
Frame {
|
||||
header: FrameHeader {
|
||||
is_final,
|
||||
opcode,
|
||||
.. FrameHeader::default()
|
||||
..FrameHeader::default()
|
||||
},
|
||||
payload: data,
|
||||
}
|
||||
|
@ -325,7 +333,7 @@ impl Frame {
|
|||
Frame {
|
||||
header: FrameHeader {
|
||||
opcode: OpCode::Control(Control::Pong),
|
||||
.. FrameHeader::default()
|
||||
..FrameHeader::default()
|
||||
},
|
||||
payload: data,
|
||||
}
|
||||
|
@ -337,7 +345,7 @@ impl Frame {
|
|||
Frame {
|
||||
header: FrameHeader {
|
||||
opcode: OpCode::Control(Control::Ping),
|
||||
.. FrameHeader::default()
|
||||
..FrameHeader::default()
|
||||
},
|
||||
payload: data,
|
||||
}
|
||||
|
@ -363,10 +371,7 @@ impl Frame {
|
|||
|
||||
/// Create a frame from given header and data.
|
||||
pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self {
|
||||
Frame {
|
||||
header,
|
||||
payload,
|
||||
}
|
||||
Frame { header, payload }
|
||||
}
|
||||
|
||||
/// Write a frame out to a buffer
|
||||
|
@ -380,7 +385,8 @@ impl Frame {
|
|||
|
||||
impl fmt::Display for Frame {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f,
|
||||
write!(
|
||||
f,
|
||||
"
|
||||
<FRAME>
|
||||
final: {}
|
||||
|
@ -398,7 +404,11 @@ payload: 0x{}
|
|||
// self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
|
||||
self.len(),
|
||||
self.payload.len(),
|
||||
self.payload.iter().map(|byte| format!("{:x}", byte)).collect::<String>())
|
||||
self.payload
|
||||
.iter()
|
||||
.map(|byte| format!("{:x}", byte))
|
||||
.collect::<String>()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -448,7 +458,7 @@ impl LengthFormat {
|
|||
match byte & 0x7F {
|
||||
126 => LengthFormat::U16,
|
||||
127 => LengthFormat::U64,
|
||||
b => LengthFormat::U8(b)
|
||||
b => LengthFormat::U8(b),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -457,20 +467,22 @@ impl LengthFormat {
|
|||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use super::super::coding::{OpCode, Data};
|
||||
use super::super::coding::{Data, OpCode};
|
||||
use std::io::Cursor;
|
||||
|
||||
#[test]
|
||||
fn parse() {
|
||||
let mut raw: Cursor<Vec<u8>> = Cursor::new(vec![
|
||||
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07
|
||||
]);
|
||||
let mut raw: Cursor<Vec<u8>> =
|
||||
Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
|
||||
let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
|
||||
assert_eq!(length, 7);
|
||||
let mut payload = Vec::new();
|
||||
raw.read_to_end(&mut payload).unwrap();
|
||||
let frame = Frame::from_payload(header, payload);
|
||||
assert_eq!(frame.into_data(), vec![ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 ]);
|
||||
assert_eq!(
|
||||
frame.into_data(),
|
||||
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -487,5 +499,4 @@ mod tests {
|
|||
let view = format!("{}", f);
|
||||
assert!(view.contains("payload:"));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use rand;
|
||||
use std::cmp::min;
|
||||
#[allow(deprecated)]
|
||||
use std::mem::uninitialized;
|
||||
use std::ptr::{copy_nonoverlapping, read_unaligned};
|
||||
use rand;
|
||||
|
||||
/// Generate a random frame mask.
|
||||
#[inline]
|
||||
|
@ -26,11 +27,9 @@ fn apply_mask_fallback(buf: &mut [u8], mask: [u8; 4]) {
|
|||
|
||||
/// Faster version of `apply_mask()` which operates on 4-byte blocks.
|
||||
#[inline]
|
||||
#[allow(dead_code)]
|
||||
#[allow(dead_code, clippy::cast_ptr_alignment)]
|
||||
fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
|
||||
let mask_u32: u32 = unsafe {
|
||||
read_unaligned(mask.as_ptr() as *const u32)
|
||||
};
|
||||
let mask_u32: u32 = unsafe { read_unaligned(mask.as_ptr() as *const u32) };
|
||||
|
||||
let mut ptr = buf.as_mut_ptr();
|
||||
let mut len = buf.len();
|
||||
|
@ -40,7 +39,7 @@ fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
|
|||
let mask_u32 = if head > 0 {
|
||||
unsafe {
|
||||
xor_mem(ptr, mask_u32, head);
|
||||
ptr = ptr.offset(head as isize);
|
||||
ptr = ptr.add(head);
|
||||
}
|
||||
len -= head;
|
||||
if cfg!(target_endian = "big") {
|
||||
|
@ -67,7 +66,9 @@ fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
|
|||
|
||||
// Possible last block.
|
||||
if len > 0 {
|
||||
unsafe { xor_mem(ptr, mask_u32, len); }
|
||||
unsafe {
|
||||
xor_mem(ptr, mask_u32, len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -75,6 +76,7 @@ fn apply_mask_fast32(buf: &mut [u8], mask: [u8; 4]) {
|
|||
// TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so inefficient,
|
||||
// it could be done better. The compiler does not see that len is limited to 3.
|
||||
unsafe fn xor_mem(ptr: *mut u8, mask: u32, len: usize) {
|
||||
#[allow(deprecated)]
|
||||
let mut b: u32 = uninitialized();
|
||||
#[allow(trivial_casts)]
|
||||
copy_nonoverlapping(ptr, &mut b as *mut _ as *mut u8, len);
|
||||
|
@ -90,12 +92,10 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_apply_mask() {
|
||||
let mask = [
|
||||
0x6d, 0xb6, 0xb2, 0x80,
|
||||
];
|
||||
let mask = [0x6d, 0xb6, 0xb2, 0x80];
|
||||
let unmasked = vec![
|
||||
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82,
|
||||
0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9, 0x12, 0x03,
|
||||
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9,
|
||||
0x12, 0x03,
|
||||
];
|
||||
|
||||
// Check masking with proper alignment.
|
||||
|
@ -120,6 +120,4 @@ mod tests {
|
|||
assert_eq!(masked, masked_fast);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -2,16 +2,17 @@
|
|||
|
||||
pub mod coding;
|
||||
|
||||
#[allow(clippy::module_inception)]
|
||||
mod frame;
|
||||
mod mask;
|
||||
|
||||
pub use self::frame::{Frame, FrameHeader};
|
||||
pub use self::frame::CloseFrame;
|
||||
pub use self::frame::{Frame, FrameHeader};
|
||||
|
||||
use std::io::{Read, Write};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use input_buffer::{InputBuffer, MIN_READ};
|
||||
use error::{Error, Result};
|
||||
use log::*;
|
||||
use std::io::{Read, Write};
|
||||
|
||||
/// A reader and writer for WebSocket frames.
|
||||
#[derive(Debug)]
|
||||
|
@ -56,7 +57,8 @@ impl<Stream> FrameSocket<Stream> {
|
|||
}
|
||||
|
||||
impl<Stream> FrameSocket<Stream>
|
||||
where Stream: Read
|
||||
where
|
||||
Stream: Read,
|
||||
{
|
||||
/// Read a frame from stream.
|
||||
pub fn read_frame(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> {
|
||||
|
@ -65,7 +67,8 @@ impl<Stream> FrameSocket<Stream>
|
|||
}
|
||||
|
||||
impl<Stream> FrameSocket<Stream>
|
||||
where Stream: Write
|
||||
where
|
||||
Stream: Write,
|
||||
{
|
||||
/// Write a frame to stream.
|
||||
///
|
||||
|
@ -138,8 +141,8 @@ impl FrameCodec {
|
|||
// is not too big (fits into `usize`).
|
||||
if length > max_size as u64 {
|
||||
return Err(Error::Capacity(
|
||||
format!("Message length too big: {} > {}", length, max_size).into()
|
||||
))
|
||||
format!("Message length too big: {} > {}", length, max_size).into(),
|
||||
));
|
||||
}
|
||||
|
||||
let input_size = cursor.get_ref().len() as u64 - cursor.position();
|
||||
|
@ -149,19 +152,21 @@ impl FrameCodec {
|
|||
if length > 0 {
|
||||
cursor.take(length).read_to_end(&mut payload)?;
|
||||
}
|
||||
break payload
|
||||
break payload;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Not enough data in buffer.
|
||||
let size = self.in_buffer.prepare_reserve(MIN_READ)
|
||||
let size = self
|
||||
.in_buffer
|
||||
.prepare_reserve(MIN_READ)
|
||||
.with_limit(usize::max_value())
|
||||
.map_err(|_| Error::Capacity("Incoming TCP buffer is full".into()))?
|
||||
.read_from(stream)?;
|
||||
if size == 0 {
|
||||
trace!("no frame received");
|
||||
return Ok(None)
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -173,17 +178,15 @@ impl FrameCodec {
|
|||
}
|
||||
|
||||
/// Write a frame to the provided stream.
|
||||
pub(super) fn write_frame<Stream>(
|
||||
&mut self,
|
||||
stream: &mut Stream,
|
||||
frame: Frame,
|
||||
) -> Result<()>
|
||||
pub(super) fn write_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()>
|
||||
where
|
||||
Stream: Write,
|
||||
{
|
||||
trace!("writing frame {}", frame);
|
||||
self.out_buffer.reserve(frame.len());
|
||||
frame.format(&mut self.out_buffer).expect("Bug: can't write to vector");
|
||||
frame
|
||||
.format(&mut self.out_buffer)
|
||||
.expect("Bug: can't write to vector");
|
||||
self.write_pending(stream)
|
||||
}
|
||||
|
||||
|
@ -211,16 +214,19 @@ mod tests {
|
|||
#[test]
|
||||
fn read_frames() {
|
||||
let raw = Cursor::new(vec![
|
||||
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
0x82, 0x03, 0x03, 0x02, 0x01,
|
||||
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x03, 0x03, 0x02, 0x01,
|
||||
0x99,
|
||||
]);
|
||||
let mut sock = FrameSocket::new(raw);
|
||||
|
||||
assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(),
|
||||
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
|
||||
assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(),
|
||||
vec![0x03, 0x02, 0x01]);
|
||||
assert_eq!(
|
||||
sock.read_frame(None).unwrap().unwrap().into_data(),
|
||||
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
|
||||
);
|
||||
assert_eq!(
|
||||
sock.read_frame(None).unwrap().unwrap().into_data(),
|
||||
vec![0x03, 0x02, 0x01]
|
||||
);
|
||||
assert!(sock.read_frame(None).unwrap().is_none());
|
||||
|
||||
let (_, rest) = sock.into_inner();
|
||||
|
@ -229,12 +235,12 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn from_partially_read() {
|
||||
let raw = Cursor::new(vec![
|
||||
0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
]);
|
||||
let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
|
||||
let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
|
||||
assert_eq!(sock.read_frame(None).unwrap().unwrap().into_data(),
|
||||
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
|
||||
assert_eq!(
|
||||
sock.read_frame(None).unwrap().unwrap().into_data(),
|
||||
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -248,17 +254,13 @@ mod tests {
|
|||
sock.write_frame(frame).unwrap();
|
||||
|
||||
let (buf, _) = sock.into_inner();
|
||||
assert_eq!(buf, vec![
|
||||
0x89, 0x02, 0x04, 0x05,
|
||||
0x8a, 0x01, 0x01
|
||||
]);
|
||||
assert_eq!(buf, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_overflow() {
|
||||
let raw = Cursor::new(vec![
|
||||
0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
|
||||
0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
|
||||
]);
|
||||
let mut sock = FrameSocket::new(raw);
|
||||
let _ = sock.read_frame(None); // should not crash
|
||||
|
@ -266,11 +268,10 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn size_limit_hit() {
|
||||
let raw = Cursor::new(vec![
|
||||
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
]);
|
||||
let raw = Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
|
||||
let mut sock = FrameSocket::new(raw);
|
||||
assert_eq!(sock.read_frame(Some(5)).unwrap_err().to_string(),
|
||||
assert_eq!(
|
||||
sock.read_frame(Some(5)).unwrap_err().to_string(),
|
||||
"Space limit exceeded: Message length too big: 7 > 5"
|
||||
);
|
||||
}
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
use std::convert::{From, Into, AsRef};
|
||||
use std::convert::{AsRef, From, Into};
|
||||
use std::fmt;
|
||||
use std::result::Result as StdResult;
|
||||
use std::str;
|
||||
|
||||
use error::{Result, Error};
|
||||
use super::frame::CloseFrame;
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
mod string_collect {
|
||||
|
||||
use utf8;
|
||||
use utf8::DecodeError;
|
||||
|
||||
use error::{Error, Result};
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StringCollector {
|
||||
|
@ -28,7 +28,8 @@ mod string_collect {
|
|||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len()
|
||||
self.data
|
||||
.len()
|
||||
.saturating_add(self.incomplete.map(|i| i.buffer_len as usize).unwrap_or(0))
|
||||
}
|
||||
|
||||
|
@ -41,7 +42,7 @@ mod string_collect {
|
|||
if let Ok(text) = result {
|
||||
self.data.push_str(text);
|
||||
} else {
|
||||
return Err(Error::Utf8)
|
||||
return Err(Error::Utf8);
|
||||
}
|
||||
true
|
||||
} else {
|
||||
|
@ -59,7 +60,10 @@ mod string_collect {
|
|||
self.data.push_str(text);
|
||||
Ok(())
|
||||
}
|
||||
Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => {
|
||||
Err(DecodeError::Incomplete {
|
||||
valid_prefix,
|
||||
incomplete_suffix,
|
||||
}) => {
|
||||
self.data.push_str(valid_prefix);
|
||||
self.incomplete = Some(incomplete_suffix);
|
||||
Ok(())
|
||||
|
@ -82,7 +86,6 @@ mod string_collect {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
use self::string_collect::StringCollector;
|
||||
|
@ -104,11 +107,11 @@ impl IncompleteMessage {
|
|||
pub fn new(message_type: IncompleteMessageType) -> Self {
|
||||
IncompleteMessage {
|
||||
collector: match message_type {
|
||||
IncompleteMessageType::Binary =>
|
||||
IncompleteMessageCollector::Binary(Vec::new()),
|
||||
IncompleteMessageType::Text =>
|
||||
IncompleteMessageCollector::Text(StringCollector::new()),
|
||||
}
|
||||
IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
|
||||
IncompleteMessageType::Text => {
|
||||
IncompleteMessageCollector::Text(StringCollector::new())
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -130,8 +133,12 @@ impl IncompleteMessage {
|
|||
// Be careful about integer overflows here.
|
||||
if my_size > max_size || portion_size > max_size - my_size {
|
||||
return Err(Error::Capacity(
|
||||
format!("Message too big: {} + {} > {}", my_size, portion_size, max_size).into()
|
||||
))
|
||||
format!(
|
||||
"Message too big: {} + {} > {}",
|
||||
my_size, portion_size, max_size
|
||||
)
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
|
||||
match self.collector {
|
||||
|
@ -139,18 +146,14 @@ impl IncompleteMessage {
|
|||
v.extend(tail.as_ref());
|
||||
Ok(())
|
||||
}
|
||||
IncompleteMessageCollector::Text(ref mut t) => {
|
||||
t.extend(tail)
|
||||
}
|
||||
IncompleteMessageCollector::Text(ref mut t) => t.extend(tail),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert an incomplete message into a complete one.
|
||||
pub fn complete(self) -> Result<Message> {
|
||||
match self.collector {
|
||||
IncompleteMessageCollector::Binary(v) => {
|
||||
Ok(Message::Binary(v))
|
||||
}
|
||||
IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v)),
|
||||
IncompleteMessageCollector::Text(t) => {
|
||||
let text = t.into_string()?;
|
||||
Ok(Message::Text(text))
|
||||
|
@ -185,17 +188,18 @@ pub enum Message {
|
|||
}
|
||||
|
||||
impl Message {
|
||||
|
||||
/// Create a new text WebSocket message from a stringable.
|
||||
pub fn text<S>(string: S) -> Message
|
||||
where S: Into<String>
|
||||
where
|
||||
S: Into<String>,
|
||||
{
|
||||
Message::Text(string.into())
|
||||
}
|
||||
|
||||
/// Create a new binary WebSocket message by converting to Vec<u8>.
|
||||
pub fn binary<B>(bin: B) -> Message
|
||||
where B: Into<Vec<u8>>
|
||||
where
|
||||
B: Into<Vec<u8>>,
|
||||
{
|
||||
Message::Binary(bin.into())
|
||||
}
|
||||
|
@ -244,9 +248,9 @@ impl Message {
|
|||
pub fn len(&self) -> usize {
|
||||
match *self {
|
||||
Message::Text(ref string) => string.len(),
|
||||
Message::Binary(ref data) |
|
||||
Message::Ping(ref data) |
|
||||
Message::Pong(ref data) => data.len(),
|
||||
Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
|
||||
data.len()
|
||||
}
|
||||
Message::Close(ref data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0),
|
||||
}
|
||||
}
|
||||
|
@ -261,9 +265,7 @@ impl Message {
|
|||
pub fn into_data(self) -> Vec<u8> {
|
||||
match self {
|
||||
Message::Text(string) => string.into_bytes(),
|
||||
Message::Binary(data) |
|
||||
Message::Ping(data) |
|
||||
Message::Pong(data) => data,
|
||||
Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data,
|
||||
Message::Close(None) => Vec::new(),
|
||||
Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
|
||||
}
|
||||
|
@ -273,10 +275,9 @@ impl Message {
|
|||
pub fn into_text(self) -> Result<String> {
|
||||
match self {
|
||||
Message::Text(string) => Ok(string),
|
||||
Message::Binary(data) |
|
||||
Message::Ping(data) |
|
||||
Message::Pong(data) => Ok(try!(
|
||||
String::from_utf8(data).map_err(|err| err.utf8_error()))),
|
||||
Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => {
|
||||
Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?)
|
||||
}
|
||||
Message::Close(None) => Ok(String::new()),
|
||||
Message::Close(Some(frame)) => Ok(frame.reason.into_owned()),
|
||||
}
|
||||
|
@ -287,14 +288,13 @@ impl Message {
|
|||
pub fn to_text(&self) -> Result<&str> {
|
||||
match *self {
|
||||
Message::Text(ref string) => Ok(string),
|
||||
Message::Binary(ref data) |
|
||||
Message::Ping(ref data) |
|
||||
Message::Pong(ref data) => Ok(try!(str::from_utf8(data))),
|
||||
Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
|
||||
Ok(str::from_utf8(data)?)
|
||||
}
|
||||
Message::Close(None) => Ok(""),
|
||||
Message::Close(Some(ref frame)) => Ok(&frame.reason),
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
impl From<String> for Message {
|
||||
|
@ -358,7 +358,6 @@ mod tests {
|
|||
assert!(msg.into_text().is_err());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn binary_convert_vec() {
|
||||
let bin = vec![6u8, 7, 8, 9, 10, 241];
|
||||
|
|
|
@ -4,18 +4,19 @@ pub mod frame;
|
|||
|
||||
mod message;
|
||||
|
||||
pub use self::message::Message;
|
||||
pub use self::frame::CloseFrame;
|
||||
pub use self::message::Message;
|
||||
|
||||
use log::*;
|
||||
use std::collections::VecDeque;
|
||||
use std::io::{Read, Write, ErrorKind as IoErrorKind};
|
||||
use std::io::{ErrorKind as IoErrorKind, Read, Write};
|
||||
use std::mem::replace;
|
||||
|
||||
use error::{Error, Result};
|
||||
use self::message::{IncompleteMessage, IncompleteMessageType};
|
||||
use self::frame::coding::{CloseCode, Control as OpCtl, Data as OpData, OpCode};
|
||||
use self::frame::{Frame, FrameCodec};
|
||||
use self::frame::coding::{OpCode, Data as OpData, Control as OpCtl, CloseCode};
|
||||
use util::NonBlockingResult;
|
||||
use self::message::{IncompleteMessage, IncompleteMessageType};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::util::NonBlockingResult;
|
||||
|
||||
/// Indicates a Client or Server role of the websocket
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
|
@ -147,7 +148,6 @@ impl<Stream: Read + Write> WebSocket<Stream> {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/// A context for managing WebSocket stream.
|
||||
#[derive(Debug)]
|
||||
pub struct WebSocketContext {
|
||||
|
@ -182,11 +182,7 @@ impl WebSocketContext {
|
|||
}
|
||||
|
||||
/// Create a WebSocket context that manages an post-handshake stream.
|
||||
pub fn from_partially_read(
|
||||
part: Vec<u8>,
|
||||
role: Role,
|
||||
config: Option<WebSocketConfig>,
|
||||
) -> Self {
|
||||
pub fn from_partially_read(part: Vec<u8>, role: Role, config: Option<WebSocketConfig>) -> Self {
|
||||
WebSocketContext {
|
||||
frame: FrameCodec::from_partially_read(part),
|
||||
..WebSocketContext::new(role, config)
|
||||
|
@ -217,7 +213,7 @@ impl WebSocketContext {
|
|||
// Thus if read blocks, just let it return WouldBlock.
|
||||
if let Some(message) = self.read_message_frame(stream)? {
|
||||
trace!("Received message {}", message);
|
||||
return Ok(message)
|
||||
return Ok(message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -251,20 +247,14 @@ impl WebSocketContext {
|
|||
}
|
||||
|
||||
let frame = match message {
|
||||
Message::Text(data) => {
|
||||
Frame::message(data.into(), OpCode::Data(OpData::Text), true)
|
||||
}
|
||||
Message::Binary(data) => {
|
||||
Frame::message(data, OpCode::Data(OpData::Binary), true)
|
||||
}
|
||||
Message::Text(data) => Frame::message(data.into(), OpCode::Data(OpData::Text), true),
|
||||
Message::Binary(data) => Frame::message(data, OpCode::Data(OpData::Binary), true),
|
||||
Message::Ping(data) => Frame::ping(data),
|
||||
Message::Pong(data) => {
|
||||
self.pong = Some(Frame::pong(data));
|
||||
return self.write_pending(stream)
|
||||
}
|
||||
Message::Close(code) => {
|
||||
return self.close(stream, code)
|
||||
return self.write_pending(stream);
|
||||
}
|
||||
Message::Close(code) => return self.close(stream, code),
|
||||
};
|
||||
|
||||
self.send_queue.push_back(frame);
|
||||
|
@ -342,7 +332,6 @@ impl WebSocketContext {
|
|||
Stream: Read + Write,
|
||||
{
|
||||
if let Some(mut frame) = self.frame.read_frame(stream, self.config.max_frame_size)? {
|
||||
|
||||
// MUST be 0 unless an extension is negotiated that defines meanings
|
||||
// for non-zero values. If a nonzero value is received and none of
|
||||
// the negotiated extensions defines the meaning of such a nonzero
|
||||
|
@ -351,7 +340,7 @@ impl WebSocketContext {
|
|||
{
|
||||
let hdr = frame.header();
|
||||
if hdr.rsv1 || hdr.rsv2 || hdr.rsv3 {
|
||||
return Err(Error::Protocol("Reserved bits are non-zero".into()))
|
||||
return Err(Error::Protocol("Reserved bits are non-zero".into()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -364,19 +353,22 @@ impl WebSocketContext {
|
|||
} else {
|
||||
// The server MUST close the connection upon receiving a
|
||||
// frame that is not masked. (RFC 6455)
|
||||
return Err(Error::Protocol("Received an unmasked frame from client".into()))
|
||||
return Err(Error::Protocol(
|
||||
"Received an unmasked frame from client".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Role::Client => {
|
||||
if frame.is_masked() {
|
||||
// A client MUST close a connection if it detects a masked frame. (RFC 6455)
|
||||
return Err(Error::Protocol("Received a masked frame from server".into()))
|
||||
return Err(Error::Protocol(
|
||||
"Received a masked frame from server".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match frame.header().opcode {
|
||||
|
||||
OpCode::Control(ctl) => {
|
||||
match ctl {
|
||||
// All control frames MUST have a payload length of 125 bytes or less
|
||||
|
@ -387,12 +379,10 @@ impl WebSocketContext {
|
|||
_ if frame.payload().len() > 125 => {
|
||||
Err(Error::Protocol("Control frame too big".into()))
|
||||
}
|
||||
OpCtl::Close => {
|
||||
Ok(self.do_close(frame.into_close()?).map(Message::Close))
|
||||
}
|
||||
OpCtl::Reserved(i) => {
|
||||
Err(Error::Protocol(format!("Unknown control frame type {}", i).into()))
|
||||
}
|
||||
OpCtl::Close => Ok(self.do_close(frame.into_close()?).map(Message::Close)),
|
||||
OpCtl::Reserved(i) => Err(Error::Protocol(
|
||||
format!("Unknown control frame type {}", i).into(),
|
||||
)),
|
||||
OpCtl::Ping | OpCtl::Pong if !self.state.is_active() => {
|
||||
// No ping processing while closing.
|
||||
Ok(None)
|
||||
|
@ -402,9 +392,7 @@ impl WebSocketContext {
|
|||
self.pong = Some(Frame::pong(data.clone()));
|
||||
Ok(Some(Message::Ping(data)))
|
||||
}
|
||||
OpCtl::Pong => {
|
||||
Ok(Some(Message::Pong(frame.into_data())))
|
||||
}
|
||||
OpCtl::Pong => Ok(Some(Message::Pong(frame.into_data()))),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -420,7 +408,9 @@ impl WebSocketContext {
|
|||
if let Some(ref mut msg) = self.incomplete {
|
||||
msg.extend(frame.into_data(), self.config.max_message_size)?;
|
||||
} else {
|
||||
return Err(Error::Protocol("Continue frame but nothing to continue".into()))
|
||||
return Err(Error::Protocol(
|
||||
"Continue frame but nothing to continue".into(),
|
||||
));
|
||||
}
|
||||
if fin {
|
||||
Ok(Some(self.incomplete.take().unwrap().complete()?))
|
||||
|
@ -428,11 +418,9 @@ impl WebSocketContext {
|
|||
Ok(None)
|
||||
}
|
||||
}
|
||||
c if self.incomplete.is_some() => {
|
||||
Err(Error::Protocol(
|
||||
format!("Received {} while waiting for more fragments", c).into()
|
||||
))
|
||||
}
|
||||
c if self.incomplete.is_some() => Err(Error::Protocol(
|
||||
format!("Received {} while waiting for more fragments", c).into(),
|
||||
)),
|
||||
OpData::Text | OpData::Binary => {
|
||||
let msg = {
|
||||
let message_type = match data {
|
||||
|
@ -451,28 +439,27 @@ impl WebSocketContext {
|
|||
Ok(None)
|
||||
}
|
||||
}
|
||||
OpData::Reserved(i) => {
|
||||
Err(Error::Protocol(format!("Unknown data frame type {}", i).into()))
|
||||
}
|
||||
OpData::Reserved(i) => Err(Error::Protocol(
|
||||
format!("Unknown data frame type {}", i).into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
} // match opcode
|
||||
|
||||
} else {
|
||||
// Connection closed by peer
|
||||
match replace(&mut self.state, WebSocketState::Terminated) {
|
||||
WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged => {
|
||||
Err(Error::ConnectionClosed)
|
||||
}
|
||||
_ => {
|
||||
Err(Error::Protocol("Connection reset without closing handshake".into()))
|
||||
}
|
||||
_ => Err(Error::Protocol(
|
||||
"Connection reset without closing handshake".into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Received a close frame. Tells if we need to return a close frame to the user.
|
||||
#[allow(clippy::option_option)]
|
||||
fn do_close<'t>(&mut self, close: Option<CloseFrame<'t>>) -> Option<Option<CloseFrame<'t>>> {
|
||||
debug!("Received close frame: {:?}", close);
|
||||
match self.state {
|
||||
|
@ -488,7 +475,7 @@ impl WebSocketContext {
|
|||
} else {
|
||||
Frame::close(Some(CloseFrame {
|
||||
code: CloseCode::Protocol,
|
||||
reason: "Protocol violation".into()
|
||||
reason: "Protocol violation".into(),
|
||||
}))
|
||||
}
|
||||
} else {
|
||||
|
@ -518,8 +505,7 @@ impl WebSocketContext {
|
|||
Stream: Read + Write,
|
||||
{
|
||||
match self.role {
|
||||
Role::Server => {
|
||||
}
|
||||
Role::Server => {}
|
||||
Role::Client => {
|
||||
// 5. If the data is being sent by the client, the frame(s) MUST be
|
||||
// masked as defined in Section 5.3. (RFC 6455)
|
||||
|
@ -535,7 +521,9 @@ impl WebSocketContext {
|
|||
match self.state {
|
||||
WebSocketState::ClosedByPeer | WebSocketState::CloseAcknowledged
|
||||
if err.kind() == IoErrorKind::ConnectionReset =>
|
||||
Error::ConnectionClosed,
|
||||
{
|
||||
Error::ConnectionClosed
|
||||
}
|
||||
_ => Error::Io(err),
|
||||
}
|
||||
}),
|
||||
|
@ -544,7 +532,6 @@ impl WebSocketContext {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/// The current connection state.
|
||||
#[derive(Debug)]
|
||||
enum WebSocketState {
|
||||
|
@ -580,7 +567,7 @@ impl WebSocketState {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{WebSocket, Role, Message, WebSocketConfig};
|
||||
use super::{Message, Role, WebSocket, WebSocketConfig};
|
||||
|
||||
use std::io;
|
||||
use std::io::Cursor;
|
||||
|
@ -602,57 +589,53 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn receive_messages() {
|
||||
let incoming = Cursor::new(vec![
|
||||
0x89, 0x02, 0x01, 0x02,
|
||||
0x8a, 0x01, 0x03,
|
||||
0x01, 0x07,
|
||||
0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20,
|
||||
0x80, 0x06,
|
||||
0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21,
|
||||
0x82, 0x03,
|
||||
0x01, 0x02, 0x03,
|
||||
0x89, 0x02, 0x01, 0x02, 0x8a, 0x01, 0x03, 0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f,
|
||||
0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x82, 0x03, 0x01, 0x02,
|
||||
0x03,
|
||||
]);
|
||||
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, None);
|
||||
assert_eq!(socket.read_message().unwrap(), Message::Ping(vec![1, 2]));
|
||||
assert_eq!(socket.read_message().unwrap(), Message::Pong(vec![3]));
|
||||
assert_eq!(socket.read_message().unwrap(), Message::Text("Hello, World!".into()));
|
||||
assert_eq!(socket.read_message().unwrap(), Message::Binary(vec![0x01, 0x02, 0x03]));
|
||||
assert_eq!(
|
||||
socket.read_message().unwrap(),
|
||||
Message::Text("Hello, World!".into())
|
||||
);
|
||||
assert_eq!(
|
||||
socket.read_message().unwrap(),
|
||||
Message::Binary(vec![0x01, 0x02, 0x03])
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn size_limiting_text_fragmented() {
|
||||
let incoming = Cursor::new(vec![
|
||||
0x01, 0x07,
|
||||
0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20,
|
||||
0x80, 0x06,
|
||||
0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21,
|
||||
0x01, 0x07, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x80, 0x06, 0x57, 0x6f, 0x72,
|
||||
0x6c, 0x64, 0x21,
|
||||
]);
|
||||
let limit = WebSocketConfig {
|
||||
max_message_size: Some(10),
|
||||
.. WebSocketConfig::default()
|
||||
..WebSocketConfig::default()
|
||||
};
|
||||
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
|
||||
assert_eq!(socket.read_message().unwrap_err().to_string(),
|
||||
assert_eq!(
|
||||
socket.read_message().unwrap_err().to_string(),
|
||||
"Space limit exceeded: Message too big: 7 + 6 > 10"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn size_limiting_binary() {
|
||||
let incoming = Cursor::new(vec![
|
||||
0x82, 0x03,
|
||||
0x01, 0x02, 0x03,
|
||||
]);
|
||||
let incoming = Cursor::new(vec![0x82, 0x03, 0x01, 0x02, 0x03]);
|
||||
let limit = WebSocketConfig {
|
||||
max_message_size: Some(2),
|
||||
.. WebSocketConfig::default()
|
||||
..WebSocketConfig::default()
|
||||
};
|
||||
let mut socket = WebSocket::from_raw_socket(WriteMoc(incoming), Role::Client, Some(limit));
|
||||
assert_eq!(socket.read_message().unwrap_err().to_string(),
|
||||
assert_eq!(
|
||||
socket.read_message().unwrap_err().to_string(),
|
||||
"Space limit exceeded: Message too big: 0 + 3 > 2"
|
||||
);
|
||||
}
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
//! Methods to accept an incoming WebSocket connection on a server.
|
||||
|
||||
pub use handshake::server::ServerHandshake;
|
||||
pub use crate::handshake::server::ServerHandshake;
|
||||
|
||||
use handshake::HandshakeError;
|
||||
use handshake::server::{Callback, NoCallback};
|
||||
use crate::handshake::server::{Callback, NoCallback};
|
||||
use crate::handshake::HandshakeError;
|
||||
|
||||
use protocol::{WebSocket, WebSocketConfig};
|
||||
use crate::protocol::{WebSocket, WebSocketConfig};
|
||||
|
||||
use std::io::{Read, Write};
|
||||
|
||||
|
@ -18,9 +18,10 @@ use std::io::{Read, Write};
|
|||
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
|
||||
/// for the stream here. Any `Read + Write` streams are supported, including
|
||||
/// those from `Mio` and others.
|
||||
pub fn accept_with_config<S: Read + Write>(stream: S, config: Option<WebSocketConfig>)
|
||||
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>>
|
||||
{
|
||||
pub fn accept_with_config<S: Read + Write>(
|
||||
stream: S,
|
||||
config: Option<WebSocketConfig>,
|
||||
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> {
|
||||
accept_hdr_with_config(stream, NoCallback, config)
|
||||
}
|
||||
|
||||
|
@ -30,9 +31,9 @@ pub fn accept_with_config<S: Read + Write>(stream: S, config: Option<WebSocketCo
|
|||
/// If you want TLS support, use `native_tls::TlsStream` or `openssl::ssl::SslStream`
|
||||
/// for the stream here. Any `Read + Write` streams are supported, including
|
||||
/// those from `Mio` and others.
|
||||
pub fn accept<S: Read + Write>(stream: S)
|
||||
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>>
|
||||
{
|
||||
pub fn accept<S: Read + Write>(
|
||||
stream: S,
|
||||
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, NoCallback>>> {
|
||||
accept_with_config(stream, None)
|
||||
}
|
||||
|
||||
|
@ -47,7 +48,7 @@ pub fn accept<S: Read + Write>(stream: S)
|
|||
pub fn accept_hdr_with_config<S: Read + Write, C: Callback>(
|
||||
stream: S,
|
||||
callback: C,
|
||||
config: Option<WebSocketConfig>
|
||||
config: Option<WebSocketConfig>,
|
||||
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> {
|
||||
ServerHandshake::start(stream, callback, config).handshake()
|
||||
}
|
||||
|
@ -57,8 +58,9 @@ pub fn accept_hdr_with_config<S: Read + Write, C: Callback>(
|
|||
/// This function does the same as `accept()` but accepts an extra callback
|
||||
/// for header processing. The callback receives headers of the incoming
|
||||
/// requests and is able to add extra headers to the reply.
|
||||
pub fn accept_hdr<S: Read + Write, C: Callback>(stream: S, callback: C)
|
||||
-> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>>
|
||||
{
|
||||
pub fn accept_hdr<S: Read + Write, C: Callback>(
|
||||
stream: S,
|
||||
callback: C,
|
||||
) -> Result<WebSocket<S>, HandshakeError<ServerHandshake<S, C>>> {
|
||||
accept_hdr_with_config(stream, callback, None)
|
||||
}
|
||||
|
|
|
@ -4,11 +4,11 @@
|
|||
//! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard
|
||||
//! `Read + Write` traits.
|
||||
|
||||
use std::io::{Read, Write, Result as IoResult};
|
||||
use std::io::{Read, Result as IoResult, Write};
|
||||
|
||||
use std::net::TcpStream;
|
||||
|
||||
#[cfg(feature="tls")]
|
||||
#[cfg(feature = "tls")]
|
||||
use native_tls::TlsStream;
|
||||
|
||||
/// Stream mode, either plain TCP or TLS.
|
||||
|
@ -32,7 +32,7 @@ impl NoDelay for TcpStream {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature="tls")]
|
||||
#[cfg(feature = "tls")]
|
||||
impl<S: Read + Write + NoDelay> NoDelay for TlsStream<S> {
|
||||
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
|
||||
self.get_mut().set_nodelay(nodelay)
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
use std::io::{Error as IoError, ErrorKind as IoErrorKind};
|
||||
use std::result::Result as StdResult;
|
||||
|
||||
use error::Error;
|
||||
use crate::error::Error;
|
||||
|
||||
/// Non-blocking IO handling.
|
||||
pub trait NonBlockingError: Sized {
|
||||
|
@ -40,7 +40,8 @@ pub trait NonBlockingResult {
|
|||
}
|
||||
|
||||
impl<T, E> NonBlockingResult for StdResult<T, E>
|
||||
where E : NonBlockingError
|
||||
where
|
||||
E: NonBlockingError,
|
||||
{
|
||||
type Result = StdResult<Option<T>, E>;
|
||||
fn no_block(self) -> Self::Result {
|
||||
|
@ -49,7 +50,7 @@ impl<T, E> NonBlockingResult for StdResult<T, E>
|
|||
Err(e) => match e.into_non_blocking() {
|
||||
Some(e) => Err(e),
|
||||
None => Ok(None),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,9 @@
|
|||
//! Verifies that the server returns a `ConnectionClosed` error when the connection
|
||||
//! is closedd from the server's point of view and drop the underlying tcp socket.
|
||||
|
||||
extern crate env_logger;
|
||||
extern crate tungstenite;
|
||||
extern crate url;
|
||||
|
||||
use std::net::TcpListener;
|
||||
use std::process::exit;
|
||||
use std::thread::{spawn, sleep};
|
||||
use std::thread::{sleep, spawn};
|
||||
use std::time::Duration;
|
||||
|
||||
use tungstenite::{accept, connect, Error, Message};
|
||||
|
@ -28,14 +24,16 @@ fn test_close() {
|
|||
let client_thread = spawn(move || {
|
||||
let (mut client, _) = connect(Url::parse("ws://localhost:3012/socket").unwrap()).unwrap();
|
||||
|
||||
client.write_message(Message::Text("Hello WebSocket".into())).unwrap();
|
||||
client
|
||||
.write_message(Message::Text("Hello WebSocket".into()))
|
||||
.unwrap();
|
||||
|
||||
let message = client.read_message().unwrap(); // receive close from server
|
||||
assert!(message.is_close());
|
||||
|
||||
let err = client.read_message().unwrap_err(); // now we should get ConnectionClosed
|
||||
match err {
|
||||
Error::ConnectionClosed => { },
|
||||
Error::ConnectionClosed => {}
|
||||
_ => panic!("unexpected error"),
|
||||
}
|
||||
});
|
||||
|
@ -52,7 +50,7 @@ fn test_close() {
|
|||
|
||||
let err = client_handler.read_message().unwrap_err(); // now we should get ConnectionClosed
|
||||
match err {
|
||||
Error::ConnectionClosed => { },
|
||||
Error::ConnectionClosed => {}
|
||||
_ => panic!("unexpected error"),
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue