refactor: make handshake completely async

This commit is contained in:
Alexey Galakhov 2017-03-08 10:50:13 +01:00
parent 334ceab2b0
commit b7557f1baa
13 changed files with 553 additions and 457 deletions

View File

@ -5,9 +5,7 @@ extern crate url;
use url::Url;
use tungstenite::protocol::Message;
use tungstenite::client::connect;
use tungstenite::handshake::Handshake;
use tungstenite::error::{Error, Result};
const AGENT: &'static str = "Tungstenite";
@ -15,17 +13,17 @@ const AGENT: &'static str = "Tungstenite";
fn get_case_count() -> Result<u32> {
let mut socket = connect(
Url::parse("ws://localhost:9001/getCaseCount").unwrap()
)?.handshake_wait()?;
)?;
let msg = socket.read_message()?;
socket.close();
socket.close()?;
Ok(msg.into_text()?.parse::<u32>().unwrap())
}
fn update_reports() -> Result<()> {
let mut socket = connect(
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap()
)?.handshake_wait()?;
socket.close();
)?;
socket.close()?;
Ok(())
}
@ -34,13 +32,11 @@ fn run_test(case: u32) -> Result<()> {
let case_url = Url::parse(
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)
).unwrap();
let mut socket = connect(case_url)?.handshake_wait()?;
let mut socket = connect(case_url)?;
loop {
let msg = socket.read_message()?;
socket.write_message(msg)?;
}
socket.close();
Ok(())
}
fn main() {

View File

@ -6,11 +6,18 @@ use std::net::{TcpListener, TcpStream};
use std::thread::spawn;
use tungstenite::server::accept;
use tungstenite::error::Result;
use tungstenite::handshake::Handshake;
use tungstenite::handshake::HandshakeError;
use tungstenite::error::{Error, Result};
fn must_not_block<Stream, Role>(err: HandshakeError<Stream, Role>) -> Error {
match err {
HandshakeError::Interrupted(_) => panic!("Bug: blocking socket would block"),
HandshakeError::Failure(f) => f,
}
}
fn handle_client(stream: TcpStream) -> Result<()> {
let mut socket = accept(stream).handshake_wait()?;
let mut socket = accept(stream).map_err(must_not_block)?;
loop {
let msg = socket.read_message()?;
socket.write_message(msg)?;

View File

@ -5,15 +5,12 @@ extern crate env_logger;
use url::Url;
use tungstenite::protocol::Message;
use tungstenite::client::connect;
use tungstenite::handshake::Handshake;
fn main() {
env_logger::init().unwrap();
let mut socket = connect(Url::parse("ws://localhost:3012/socket").unwrap())
.expect("Can't connect")
.handshake_wait()
.expect("Handshake error");
.expect("Can't connect");
socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();
loop {

View File

@ -1,75 +1,97 @@
use std::net::{TcpStream, ToSocketAddrs};
use url::{Url, SocketAddrs};
use std::net::{TcpStream, SocketAddr, ToSocketAddrs};
use std::result::Result as StdResult;
use std::io::{Read, Write};
use url::Url;
#[cfg(feature="tls")]
use native_tls::{TlsStream, TlsConnector, HandshakeError as TlsHandshakeError};
use protocol::WebSocket;
use handshake::{Handshake as HandshakeTrait, HandshakeResult};
use handshake::HandshakeError;
use handshake::client::{ClientHandshake, Request};
use stream::Mode;
use error::{Error, Result};
/// Connect to the given WebSocket.
#[cfg(feature="tls")]
use stream::Stream as StreamSwitcher;
#[cfg(feature="tls")]
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>;
#[cfg(not(feature="tls"))]
pub type AutoStream = TcpStream;
/// Connect to the given WebSocket in blocking mode.
///
/// Note that this function may block the current thread while DNS resolution is performed.
pub fn connect(url: Url) -> Result<Handshake> {
let mode = match url.scheme() {
"ws" => Mode::Plain,
#[cfg(feature="tls")]
"wss" => Mode::Tls,
_ => return Err(Error::Url("URL scheme not supported".into()))
};
// Note that this function may block the current thread while resolution is performed.
/// The URL may be either ws:// or wss://.
/// To support wss:// URLs, feature "tls" must be turned on.
pub fn connect(url: Url) -> Result<WebSocket<AutoStream>> {
let mode = url_mode(&url)?;
let addrs = url.to_socket_addrs()?;
Ok(Handshake {
state: HandshakeState::Nothing(url),
alt_addresses: addrs,
})
let stream = connect_to_some(addrs, &url, mode)?;
client(url.clone(), stream)
.map_err(|e| match e {
HandshakeError::Failure(f) => f,
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
})
}
enum Mode {
Plain,
Tls,
}
enum HandshakeState {
Nothing(Url),
WebSocket(ClientHandshake<TcpStream>),
}
pub struct Handshake {
state: HandshakeState,
alt_addresses: SocketAddrs,
}
impl HandshakeTrait for Handshake {
type Stream = WebSocket<TcpStream>;
fn handshake(mut self) -> Result<HandshakeResult<Self>> {
match self.state {
HandshakeState::Nothing(url) => {
if let Some(addr) = self.alt_addresses.next() {
debug!("Trying to contact {} at {}...", url, addr);
let state = {
if let Ok(stream) = TcpStream::connect(addr) {
let hs = ClientHandshake::new(stream, Request { url: url });
HandshakeState::WebSocket(hs)
} else {
HandshakeState::Nothing(url)
}
};
Ok(HandshakeResult::Incomplete(Handshake {
state: state,
..self
}))
} else {
Err(Error::Url(format!("Unable to resolve {}", url).into()))
}
}
HandshakeState::WebSocket(ws) => {
let alt_addresses = self.alt_addresses;
ws.handshake().map(move |r| r.map(move |s| Handshake {
state: HandshakeState::WebSocket(s),
alt_addresses: alt_addresses,
}))
}
#[cfg(feature="tls")]
fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
match mode {
Mode::Plain => Ok(StreamSwitcher::Plain(stream)),
Mode::Tls => {
let connector = TlsConnector::builder()?.build()?;
connector.connect(domain, stream)
.map_err(|e| match e {
TlsHandshakeError::Failure(f) => f.into(),
TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"),
})
.map(|s| StreamSwitcher::Tls(s))
}
}
}
#[cfg(not(feature="tls"))]
fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStream> {
match mode {
Mode::Plain => Ok(stream),
Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())),
}
}
fn connect_to_some<A>(addrs: A, url: &Url, mode: Mode) -> Result<AutoStream>
where A: Iterator<Item=SocketAddr>
{
let domain = url.host_str().ok_or(Error::Url("No host name in the URL".into()))?;
for addr in addrs {
debug!("Trying to contact {} at {}...", url, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) {
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
return Ok(stream)
}
}
}
Err(Error::Url(format!("Unable to connect to {}", url).into()))
}
/// Get the mode of the given URL.
///
/// This function may be used in non-blocking implementations.
pub fn url_mode(url: &Url) -> Result<Mode> {
match url.scheme() {
"ws" => Ok(Mode::Plain),
"wss" => Ok(Mode::Tls),
_ => Err(Error::Url("URL scheme not supported".into()))
}
}
/// Do the client handshake over the given stream.
///
/// Use this function if you need a nonblocking handshake support.
pub fn client<Stream: Read + Write>(url: Url, stream: Stream)
-> StdResult<WebSocket<Stream>, HandshakeError<Stream, ClientHandshake>>
{
let request = Request { url: url };
ClientHandshake::start(stream, request).handshake()
}

View File

@ -11,6 +11,9 @@ use std::string;
use httparse;
#[cfg(feature="tls")]
use native_tls;
pub type Result<T> = result::Result<T, Error>;
/// Possible WebSocket errors
@ -20,6 +23,9 @@ pub enum Error {
ConnectionClosed,
/// Input-output error
Io(io::Error),
#[cfg(feature="tls")]
/// TLS error
Tls(native_tls::Error),
/// Buffer capacity exhausted
Capacity(Cow<'static, str>),
/// Protocol violation
@ -37,6 +43,8 @@ impl fmt::Display for Error {
match *self {
Error::ConnectionClosed => write!(f, "Connection closed"),
Error::Io(ref err) => write!(f, "IO error: {}", err),
#[cfg(feature="tls")]
Error::Tls(ref err) => write!(f, "TLS error: {}", err),
Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg),
Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg),
Error::Utf8 => write!(f, "UTF-8 encoding error"),
@ -51,6 +59,8 @@ impl ErrorTrait for Error {
match *self {
Error::ConnectionClosed => "",
Error::Io(ref err) => err.description(),
#[cfg(feature="tls")]
Error::Tls(ref err) => err.description(),
Error::Capacity(ref msg) => msg.borrow(),
Error::Protocol(ref msg) => msg.borrow(),
Error::Utf8 => "",
@ -78,6 +88,13 @@ impl From<string::FromUtf8Error> for Error {
}
}
#[cfg(feature="tls")]
impl From<native_tls::Error> for Error {
fn from(err: native_tls::Error) -> Self {
Error::Tls(err)
}
}
impl From<httparse::Error> for Error {
fn from(err: httparse::Error) -> Self {
match err {

View File

@ -1,25 +1,16 @@
use std::io::{Read, Write, Cursor};
use base64;
use rand;
use bytes::Buf;
use httparse;
use httparse::Status;
use std::io::Write;
use url::Url;
use input_buffer::{InputBuffer, MIN_READ};
use error::{Error, Result};
use protocol::{
WebSocket, Role,
};
use super::{
Headers,
Httparse, FromHttparse,
Handshake, HandshakeResult,
convert_key,
MAX_HEADERS,
};
use util::NonBlockingResult;
use protocol::{WebSocket, Role};
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
use super::machine::{HandshakeMachine, StageResult, TryParse};
/// Client request.
pub struct Request {
@ -47,77 +38,59 @@ impl Request {
}
}
/// Client handshake.
pub struct ClientHandshake<Stream> {
stream: Stream,
state: HandshakeState,
/// Client handshake role.
pub struct ClientHandshake {
verify_data: VerifyData,
}
impl<Stream: Read + Write> ClientHandshake<Stream> {
/// Initiate a WebSocket handshake over the given stream.
pub fn new(stream: Stream, request: Request) -> Self {
impl ClientHandshake {
/// Initiate a client handshake.
pub fn start<Stream>(stream: Stream, request: Request) -> MidHandshake<Stream, Self> {
let key = generate_key();
let mut req = Vec::new();
write!(req, "\
GET {path} HTTP/1.1\r\n\
Host: {host}\r\n\
Connection: upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n\
\r\n", host = request.get_host(), path = request.get_path(), key = key)
.unwrap();
let machine = {
let mut req = Vec::new();
write!(req, "\
GET {path} HTTP/1.1\r\n\
Host: {host}\r\n\
Connection: upgrade\r\n\
Upgrade: websocket\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Key: {key}\r\n\
\r\n", host = request.get_host(), path = request.get_path(), key = key)
.unwrap();
HandshakeMachine::start_write(stream, req)
};
let accept_key = convert_key(key.as_ref()).unwrap();
let client = {
let accept_key = convert_key(key.as_ref()).unwrap();
ClientHandshake {
verify_data: VerifyData {
accept_key: accept_key,
},
}
};
ClientHandshake {
stream: stream,
state: HandshakeState::SendingRequest(Cursor::new(req)),
verify_data: VerifyData {
accept_key: accept_key,
},
}
debug!("Client handshake initiated.");
MidHandshake { role: client, machine: machine }
}
}
impl<Stream: Read + Write> Handshake for ClientHandshake<Stream> {
type Stream = WebSocket<Stream>;
fn handshake(mut self) -> Result<HandshakeResult<Self>> {
debug!("Performing client handshake...");
match self.state {
HandshakeState::SendingRequest(mut req) => {
let size = self.stream.write(Buf::bytes(&req)).no_block()?.unwrap_or(0);
Buf::advance(&mut req, size);
let state = if req.has_remaining() {
HandshakeState::SendingRequest(req)
} else {
HandshakeState::ReceivingResponse(InputBuffer::with_capacity(MIN_READ))
};
Ok(HandshakeResult::Incomplete(ClientHandshake {
state: state,
..self
}))
impl HandshakeRole for ClientHandshake {
type IncomingData = Response;
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>>
{
Ok(match finish {
StageResult::DoneWriting(stream) => {
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
}
HandshakeState::ReceivingResponse(mut resp_buf) => {
resp_buf.reserve(MIN_READ, usize::max_value())
.map_err(|_| Error::Capacity("Header too long".into()))?;
resp_buf.read_from(&mut self.stream).no_block()?;
if let Some(resp) = Response::parse(&mut resp_buf)? {
self.verify_data.verify_response(&resp)?;
let ws = WebSocket::from_partially_read(self.stream,
resp_buf.into_vec(), Role::Client);
debug!("Client handshake done.");
Ok(HandshakeResult::Done(ws))
} else {
Ok(HandshakeResult::Incomplete(ClientHandshake {
state: HandshakeState::ReceivingResponse(resp_buf),
..self
}))
}
StageResult::DoneReading { stream, result, tail, } => {
self.verify_data.verify_response(&result)?;
debug!("Client handshake done.");
ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client))
}
}
})
}
}
@ -173,27 +146,14 @@ impl VerifyData {
}
}
/// Internal state of the client handshake.
enum HandshakeState {
SendingRequest(Cursor<Vec<u8>>),
ReceivingResponse(InputBuffer),
}
/// Server response.
pub struct Response {
code: u16,
headers: Headers,
}
impl Response {
/// Parse the response from a stream.
pub fn parse<B: Buf>(input: &mut B) -> Result<Option<Self>> {
Response::parse_http(input)
}
}
impl Httparse for Response {
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
impl TryParse for Response {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
let mut req = httparse::Response::new(&mut hbuffer);
Ok(match req.parse(buf)? {

148
src/handshake/headers.rs Normal file
View File

@ -0,0 +1,148 @@
use std::ascii::AsciiExt;
use std::str::from_utf8;
use std::slice;
use httparse;
use error::Result;
// Limit the number of header lines.
pub const MAX_HEADERS: usize = 124;
/// HTTP request or response headers.
#[derive(Debug)]
pub struct Headers {
data: Vec<(String, Box<[u8]>)>,
}
impl Headers {
/// Get first header with the given name, if any.
pub fn find_first(&self, name: &str) -> Option<&[u8]> {
self.find(name).next()
}
/// Iterate over all headers with the given name.
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> {
HeadersIter {
name: name,
iter: self.data.iter()
}
}
/// Check if the given header has the given value.
pub fn header_is(&self, name: &str, value: &str) -> bool {
self.find_first(name)
.map(|v| v == value.as_bytes())
.unwrap_or(false)
}
/// Check if the given header has the given value (case-insensitive).
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool {
self.find_first(name).ok_or(())
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ()))
.map(|val| val.eq_ignore_ascii_case(value))
.unwrap_or(false)
}
}
/// The iterator over headers.
pub struct HeadersIter<'name, 'headers> {
name: &'name str,
iter: slice::Iter<'headers, (String, Box<[u8]>)>,
}
impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> {
type Item = &'headers [u8];
fn next(&mut self) -> Option<Self::Item> {
while let Some(&(ref name, ref value)) = self.iter.next() {
if name.eq_ignore_ascii_case(self.name) {
return Some(value)
}
}
None
}
}
/// Trait to convert raw objects into HTTP parseables.
pub trait FromHttparse<T>: Sized {
fn from_httparse(raw: T) -> Result<Self>;
}
/*
impl TryParse for Headers {
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
Ok(match httparse::parse_headers(buf, &mut hbuffer)? {
Status::Partial => None,
Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)),
})
}
}*/
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
Ok(Headers {
data: raw.iter()
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
.collect(),
})
}
}
#[cfg(test)]
mod tests {
use super::Headers;
use std::io::Cursor;
#[test]
fn headers() {
const data: &'static [u8] =
b"Host: foo.com\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
\r\n";
let mut inp = Cursor::new(data);
let hdr = Headers::parse(&mut inp).unwrap().unwrap();
assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..]));
assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..]));
assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..]));
assert!(hdr.header_is("upgrade", "websocket"));
assert!(!hdr.header_is("upgrade", "Websocket"));
assert!(hdr.header_is_ignore_case("upgrade", "Websocket"));
}
#[test]
fn headers_iter() {
const data: &'static [u8] =
b"Host: foo.com\r\n\
Sec-WebSocket-Extensions: permessage-deflate\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\
Upgrade: websocket\r\n\
\r\n";
let mut inp = Cursor::new(data);
let hdr = Headers::parse(&mut inp).unwrap().unwrap();
let mut iter = hdr.find("Sec-WebSocket-Extensions");
assert_eq!(iter.next(), Some(&b"permessage-deflate"[..]));
assert_eq!(iter.next(), Some(&b"permessage-unknown"[..]));
assert_eq!(iter.next(), None);
}
#[test]
fn headers_incomplete() {
const data: &'static [u8] =
b"Host: foo.com\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n";
let mut inp = Cursor::new(data);
let hdr = Headers::parse(&mut inp).unwrap();
assert!(hdr.is_none());
}
}

119
src/handshake/machine.rs Normal file
View File

@ -0,0 +1,119 @@
use std::io::{Cursor, Read, Write};
use bytes::Buf;
use input_buffer::{InputBuffer, MIN_READ};
use error::{Error, Result};
use util::NonBlockingResult;
/// A generic handshake state machine.
pub struct HandshakeMachine<Stream> {
stream: Stream,
state: HandshakeState,
}
impl<Stream> HandshakeMachine<Stream> {
/// Start reading data from the peer.
pub fn start_read(stream: Stream) -> Self {
HandshakeMachine {
stream: stream,
state: HandshakeState::Reading(InputBuffer::with_capacity(MIN_READ)),
}
}
/// Start writing data to the peer.
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
HandshakeMachine {
stream: stream,
state: HandshakeState::Writing(Cursor::new(data.into())),
}
}
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream {
&self.stream
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut Stream {
&mut self.stream
}
}
impl<Stream: Read + Write> HandshakeMachine<Stream> {
/// Perform a single handshake round.
pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> {
Ok(match self.state {
HandshakeState::Reading(mut buf) => {
buf.reserve(MIN_READ, usize::max_value()) // TODO limit size
.map_err(|_| Error::Capacity("Header too long".into()))?;
if let Some(_) = buf.read_from(&mut self.stream).no_block()? {
if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {
result: obj,
stream: self.stream,
tail: buf.into_vec(),
})
} else {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
})
}
} else {
RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf),
..self
})
}
}
HandshakeState::Writing(mut buf) => {
if let Some(size) = self.stream.write(Buf::bytes(&buf)).no_block()? {
buf.advance(size);
if buf.has_remaining() {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Writing(buf),
..self
})
} else {
RoundResult::StageFinished(StageResult::DoneWriting(self.stream))
}
} else {
RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Writing(buf),
..self
})
}
}
})
}
}
/// The result of the round.
pub enum RoundResult<Obj, Stream> {
/// Round not done, I/O would block.
WouldBlock(HandshakeMachine<Stream>),
/// Round done, state unchanged.
Incomplete(HandshakeMachine<Stream>),
/// Stage complete.
StageFinished(StageResult<Obj, Stream>),
}
/// The result of the stage.
pub enum StageResult<Obj, Stream> {
/// Reading round finished.
DoneReading { result: Obj, stream: Stream, tail: Vec<u8> },
/// Writing round finished.
DoneWriting(Stream),
}
/// The parseable object.
pub trait TryParse: Sized {
/// Return Ok(None) if incomplete, Err on syntax error.
fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>;
}
/// The handshake state.
enum HandshakeState {
/// Reading data from the peer.
Reading(InputBuffer),
/// Sending data to the peer.
Writing(Cursor<Vec<u8>>),
}

View File

@ -1,63 +1,89 @@
pub mod headers;
pub mod client;
pub mod server;
#[cfg(feature="tls")]
pub mod tls;
use std::ascii::AsciiExt;
use std::str::from_utf8;
use std::slice;
mod machine;
use std::io::{Read, Write};
use base64;
use bytes::Buf;
use httparse;
use httparse::Status;
use sha1::Sha1;
use error::Result;
use error::Error;
use protocol::WebSocket;
// Limit the number of header lines.
const MAX_HEADERS: usize = 124;
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
/// A handshake state.
pub trait Handshake: Sized {
/// Resulting stream of this handshake.
type Stream;
/// Perform a single handshake round.
fn handshake(self) -> Result<HandshakeResult<Self>>;
/// Perform handshake to the end in a blocking mode.
fn handshake_wait(self) -> Result<Self::Stream> {
let mut hs = self;
/// A WebSocket handshake.
pub struct MidHandshake<Stream, Role> {
role: Role,
machine: HandshakeMachine<Stream>,
}
impl<Stream, Role> MidHandshake<Stream, Role> {
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &Stream {
self.machine.get_ref()
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut Stream {
self.machine.get_mut()
}
}
impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
/// Restarts the handshake process.
pub fn handshake(self) -> Result<WebSocket<Stream>, HandshakeError<Stream, Role>> {
let mut mach = self.machine;
loop {
hs = match hs.handshake()? {
HandshakeResult::Done(stream) => return Ok(stream),
HandshakeResult::Incomplete(s) => s,
mach = match mach.single_round()? {
RoundResult::WouldBlock(m) => {
return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
}
RoundResult::Incomplete(m) => m,
RoundResult::StageFinished(s) => {
match self.role.stage_finished(s)? {
ProcessingResult::Continue(m) => m,
ProcessingResult::Done(ws) => return Ok(ws),
}
}
}
}
}
}
/// A handshake result.
pub enum HandshakeResult<H: Handshake> {
/// Handshake is done, a WebSocket stream is ready.
Done(H::Stream),
/// Handshake is not done, call handshake() again.
Incomplete(H),
pub enum HandshakeError<Stream, Role> {
/// Handshake was interrupted (would block).
Interrupted(MidHandshake<Stream, Role>),
/// Handshake failed.
Failure(Error),
}
impl<H: Handshake> HandshakeResult<H> {
pub fn map<R, F>(self, func: F) -> HandshakeResult<R>
where R: Handshake<Stream = H::Stream>,
F: FnOnce(H) -> R,
{
match self {
HandshakeResult::Done(s) => HandshakeResult::Done(s),
HandshakeResult::Incomplete(h) => HandshakeResult::Incomplete(func(h)),
}
impl<Stream, Role> From<Error> for HandshakeError<Stream, Role> {
fn from(err: Error) -> Self {
HandshakeError::Failure(err)
}
}
/// Handshake role.
pub trait HandshakeRole {
#[doc(hidden)]
type IncomingData: TryParse;
#[doc(hidden)]
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>, Error>;
}
/// Stage processing result.
#[doc(hidden)]
pub enum ProcessingResult<Stream> {
Continue(HandshakeMachine<Stream>),
Done(WebSocket<Stream>),
}
/// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept.
fn convert_key(input: &[u8]) -> Result<String> {
fn convert_key(input: &[u8]) -> Result<String, Error> {
// ... field is constructed by concatenating /key/ ...
// ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
const WS_GUID: &'static [u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
@ -67,113 +93,10 @@ fn convert_key(input: &[u8]) -> Result<String> {
Ok(base64::encode(&sha1.digest().bytes()))
}
/// HTTP request or response headers.
#[derive(Debug)]
pub struct Headers {
data: Vec<(String, Box<[u8]>)>,
}
impl Headers {
/// Get first header with the given name, if any.
pub fn find_first(&self, name: &str) -> Option<&[u8]> {
self.find(name).next()
}
/// Iterate over all headers with the given name.
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> {
HeadersIter {
name: name,
iter: self.data.iter()
}
}
/// Check if the given header has the given value.
pub fn header_is(&self, name: &str, value: &str) -> bool {
self.find_first(name)
.map(|v| v == value.as_bytes())
.unwrap_or(false)
}
/// Check if the given header has the given value (case-insensitive).
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool {
self.find_first(name).ok_or(())
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ()))
.map(|val| val.eq_ignore_ascii_case(value))
.unwrap_or(false)
}
/// Try to parse data and return headers, if any.
fn parse<B: Buf>(input: &mut B) -> Result<Option<Headers>> {
Headers::parse_http(input)
}
}
/// The iterator over headers.
pub struct HeadersIter<'name, 'headers> {
name: &'name str,
iter: slice::Iter<'headers, (String, Box<[u8]>)>,
}
impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> {
type Item = &'headers [u8];
fn next(&mut self) -> Option<Self::Item> {
while let Some(&(ref name, ref value)) = self.iter.next() {
if name.eq_ignore_ascii_case(self.name) {
return Some(value)
}
}
None
}
}
/// Trait to read HTTP parseable objects.
trait Httparse: Sized {
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>>;
fn parse_http<B: Buf>(input: &mut B) -> Result<Option<Self>> {
Ok(match Self::httparse(input.bytes())? {
Some((size, obj)) => {
input.advance(size);
Some(obj)
},
None => None,
})
}
}
/// Trait to convert raw objects into HTTP parseables.
trait FromHttparse<T>: Sized {
fn from_httparse(raw: T) -> Result<Self>;
}
impl Httparse for Headers {
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
Ok(match httparse::parse_headers(buf, &mut hbuffer)? {
Status::Partial => None,
Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)),
})
}
}
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
Ok(Headers {
data: raw.iter()
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
.collect(),
})
}
}
#[cfg(test)]
mod tests {
use super::{Headers, convert_key};
use std::io::Cursor;
use super::convert_key;
#[test]
fn key_conversion() {
@ -182,50 +105,4 @@ mod tests {
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
}
#[test]
fn headers() {
const data: &'static [u8] =
b"Host: foo.com\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
\r\n";
let mut inp = Cursor::new(data);
let hdr = Headers::parse(&mut inp).unwrap().unwrap();
assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..]));
assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..]));
assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..]));
assert!(hdr.header_is("upgrade", "websocket"));
assert!(!hdr.header_is("upgrade", "Websocket"));
assert!(hdr.header_is_ignore_case("upgrade", "Websocket"));
}
#[test]
fn headers_iter() {
const data: &'static [u8] =
b"Host: foo.com\r\n\
Sec-WebSocket-Extensions: permessage-deflate\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\
Upgrade: websocket\r\n\
\r\n";
let mut inp = Cursor::new(data);
let hdr = Headers::parse(&mut inp).unwrap().unwrap();
let mut iter = hdr.find("Sec-WebSocket-Extensions");
assert_eq!(iter.next(), Some(&b"permessage-deflate"[..]));
assert_eq!(iter.next(), Some(&b"permessage-unknown"[..]));
assert_eq!(iter.next(), None);
}
#[test]
fn headers_incomplete() {
const data: &'static [u8] =
b"Host: foo.com\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n";
let mut inp = Cursor::new(data);
let hdr = Headers::parse(&mut inp).unwrap();
assert!(hdr.is_none());
}
}

View File

@ -1,33 +1,20 @@
use std::io::{Cursor, Read, Write};
use bytes::Buf;
use httparse;
use httparse::Status;
use input_buffer::{InputBuffer, MIN_READ};
//use input_buffer::{InputBuffer, MIN_READ};
use error::{Error, Result};
use protocol::{WebSocket, Role};
use super::{
Handshake,
HandshakeResult,
Headers,
Httparse,
FromHttparse,
convert_key,
MAX_HEADERS
};
use util::NonBlockingResult;
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
/// Request from the client.
pub struct Request {
path: String,
headers: Headers,
pub path: String,
pub headers: Headers,
}
impl Request {
/// Parse the request from a stream.
pub fn parse<B: Buf>(input: &mut B) -> Result<Option<Self>> {
Request::parse_http(input)
}
/// Reply to the response.
pub fn reply(&self) -> Result<Vec<u8>> {
let key = self.headers.find_first("Sec-WebSocket-Key")
@ -42,8 +29,8 @@ impl Request {
}
}
impl Httparse for Request {
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
impl TryParse for Request {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
let mut req = httparse::Request::new(&mut hbuffer);
Ok(match req.parse(buf)? {
@ -68,64 +55,40 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
}
}
/// Server handshake
pub struct ServerHandshake<Stream> {
stream: Stream,
state: HandshakeState,
}
/// Server handshake role.
#[allow(missing_copy_implementations)]
pub struct ServerHandshake;
impl<Stream: Read + Write> ServerHandshake<Stream> {
/// Start a new server handshake on top of given stream.
pub fn new(stream: Stream) -> Self {
ServerHandshake {
stream: stream,
state: HandshakeState::ReceivingRequest(InputBuffer::with_capacity(MIN_READ)),
impl ServerHandshake {
/// Start server handshake.
pub fn start<Stream>(stream: Stream) -> MidHandshake<Stream, Self> {
MidHandshake {
machine: HandshakeMachine::start_read(stream),
role: ServerHandshake,
}
}
}
impl<Stream: Read + Write> Handshake for ServerHandshake<Stream> {
type Stream = WebSocket<Stream>;
fn handshake(mut self) -> Result<HandshakeResult<Self>> {
debug!("Performing server handshake...");
match self.state {
HandshakeState::ReceivingRequest(mut req_buf) => {
req_buf.reserve(MIN_READ, usize::max_value())
.map_err(|_| Error::Capacity("Header too long".into()))?;
req_buf.read_from(&mut self.stream).no_block()?;
let state = if let Some(req) = Request::parse(&mut req_buf)? {
let resp = req.reply()?;
HandshakeState::SendingResponse(Cursor::new(resp))
} else {
HandshakeState::ReceivingRequest(req_buf)
};
Ok(HandshakeResult::Incomplete(ServerHandshake {
state: state,
..self
}))
}
HandshakeState::SendingResponse(mut resp) => {
let size = self.stream.write(Buf::bytes(&resp)).no_block()?.unwrap_or(0);
Buf::advance(&mut resp, size);
if resp.has_remaining() {
Ok(HandshakeResult::Incomplete(ServerHandshake {
state: HandshakeState::SendingResponse(resp),
..self
}))
} else {
let ws = WebSocket::from_raw_socket(self.stream, Role::Server);
Ok(HandshakeResult::Done(ws))
impl HandshakeRole for ServerHandshake {
type IncomingData = Request;
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>)
-> Result<ProcessingResult<Stream>>
{
Ok(match finish {
StageResult::DoneReading { stream, result, tail } => {
if ! tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into()))
}
let response = result.reply()?;
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
}
}
StageResult::DoneWriting(stream) => {
ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server))
}
})
}
}
enum HandshakeState {
ReceivingRequest(InputBuffer),
SendingResponse(Cursor<Vec<u8>>),
}
#[cfg(test)]
mod tests {

View File

@ -1,14 +0,0 @@
use native_tls;
use stream::Stream;
use super::{Handshake, HandshakeResult};
pub struct TlsHandshake {
}
impl Handshale for TlsHandshake {
type Stream = Stream;
fn handshake(self) -> Result<HandshakeResult<Self>> {
}
}

View File

@ -1,8 +1,13 @@
use std::net::TcpStream;
pub use handshake::server::ServerHandshake;
use handshake::server::ServerHandshake;
use handshake::HandshakeError;
use protocol::WebSocket;
use std::io::{Read, Write};
/// Accept the given TcpStream as a WebSocket.
pub fn accept(stream: TcpStream) -> ServerHandshake<TcpStream> {
ServerHandshake::new(stream)
pub fn accept<Stream: Read + Write>(stream: Stream)
-> Result<WebSocket<Stream>, HandshakeError<Stream, ServerHandshake>>
{
ServerHandshake::start(stream).handshake()
}

View File

@ -1,38 +1,37 @@
#[cfg(feature="tls")]
use native_tls::TlsStream;
use std::net::TcpStream;
use std::io::{Read, Write, Result as IoResult};
/// Stream, either plain TCP or TLS.
pub enum Stream {
Plain(TcpStream),
#[cfg(feature="tls")]
Tls(TlsStream<TcpStream>),
/// Stream mode, either plain TCP or TLS.
#[derive(Clone, Copy)]
pub enum Mode {
Plain,
Tls,
}
impl Read for Stream {
/// Stream, either plain TCP or TLS.
pub enum Stream<S, T> {
Plain(S),
Tls(T),
}
impl<S: Read, T: Read> Read for Stream<S, T> {
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
match *self {
Stream::Plain(ref mut s) => s.read(buf),
#[cfg(feature="tls")]
Stream::Tls(ref mut s) => s.read(buf),
}
}
}
impl Write for Stream {
impl<S: Write, T: Write> Write for Stream<S, T> {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
match *self {
Stream::Plain(ref mut s) => s.write(buf),
#[cfg(feature="tls")]
Stream::Tls(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> IoResult<()> {
match *self {
Stream::Plain(ref mut s) => s.flush(),
#[cfg(feature="tls")]
Stream::Tls(ref mut s) => s.flush(),
}
}