refactor: make handshake completely async
This commit is contained in:
parent
334ceab2b0
commit
b7557f1baa
|
@ -5,9 +5,7 @@ extern crate url;
|
|||
|
||||
use url::Url;
|
||||
|
||||
use tungstenite::protocol::Message;
|
||||
use tungstenite::client::connect;
|
||||
use tungstenite::handshake::Handshake;
|
||||
use tungstenite::error::{Error, Result};
|
||||
|
||||
const AGENT: &'static str = "Tungstenite";
|
||||
|
@ -15,17 +13,17 @@ const AGENT: &'static str = "Tungstenite";
|
|||
fn get_case_count() -> Result<u32> {
|
||||
let mut socket = connect(
|
||||
Url::parse("ws://localhost:9001/getCaseCount").unwrap()
|
||||
)?.handshake_wait()?;
|
||||
)?;
|
||||
let msg = socket.read_message()?;
|
||||
socket.close();
|
||||
socket.close()?;
|
||||
Ok(msg.into_text()?.parse::<u32>().unwrap())
|
||||
}
|
||||
|
||||
fn update_reports() -> Result<()> {
|
||||
let mut socket = connect(
|
||||
Url::parse(&format!("ws://localhost:9001/updateReports?agent={}", AGENT)).unwrap()
|
||||
)?.handshake_wait()?;
|
||||
socket.close();
|
||||
)?;
|
||||
socket.close()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -34,13 +32,11 @@ fn run_test(case: u32) -> Result<()> {
|
|||
let case_url = Url::parse(
|
||||
&format!("ws://localhost:9001/runCase?case={}&agent={}", case, AGENT)
|
||||
).unwrap();
|
||||
let mut socket = connect(case_url)?.handshake_wait()?;
|
||||
let mut socket = connect(case_url)?;
|
||||
loop {
|
||||
let msg = socket.read_message()?;
|
||||
socket.write_message(msg)?;
|
||||
}
|
||||
socket.close();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() {
|
||||
|
|
|
@ -6,11 +6,18 @@ use std::net::{TcpListener, TcpStream};
|
|||
use std::thread::spawn;
|
||||
|
||||
use tungstenite::server::accept;
|
||||
use tungstenite::error::Result;
|
||||
use tungstenite::handshake::Handshake;
|
||||
use tungstenite::handshake::HandshakeError;
|
||||
use tungstenite::error::{Error, Result};
|
||||
|
||||
fn must_not_block<Stream, Role>(err: HandshakeError<Stream, Role>) -> Error {
|
||||
match err {
|
||||
HandshakeError::Interrupted(_) => panic!("Bug: blocking socket would block"),
|
||||
HandshakeError::Failure(f) => f,
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_client(stream: TcpStream) -> Result<()> {
|
||||
let mut socket = accept(stream).handshake_wait()?;
|
||||
let mut socket = accept(stream).map_err(must_not_block)?;
|
||||
loop {
|
||||
let msg = socket.read_message()?;
|
||||
socket.write_message(msg)?;
|
||||
|
|
|
@ -5,15 +5,12 @@ extern crate env_logger;
|
|||
use url::Url;
|
||||
use tungstenite::protocol::Message;
|
||||
use tungstenite::client::connect;
|
||||
use tungstenite::handshake::Handshake;
|
||||
|
||||
fn main() {
|
||||
env_logger::init().unwrap();
|
||||
|
||||
let mut socket = connect(Url::parse("ws://localhost:3012/socket").unwrap())
|
||||
.expect("Can't connect")
|
||||
.handshake_wait()
|
||||
.expect("Handshake error");
|
||||
.expect("Can't connect");
|
||||
|
||||
socket.write_message(Message::Text("Hello WebSocket".into())).unwrap();
|
||||
loop {
|
||||
|
|
148
src/client.rs
148
src/client.rs
|
@ -1,75 +1,97 @@
|
|||
use std::net::{TcpStream, ToSocketAddrs};
|
||||
use url::{Url, SocketAddrs};
|
||||
use std::net::{TcpStream, SocketAddr, ToSocketAddrs};
|
||||
use std::result::Result as StdResult;
|
||||
use std::io::{Read, Write};
|
||||
|
||||
use url::Url;
|
||||
|
||||
#[cfg(feature="tls")]
|
||||
use native_tls::{TlsStream, TlsConnector, HandshakeError as TlsHandshakeError};
|
||||
|
||||
use protocol::WebSocket;
|
||||
use handshake::{Handshake as HandshakeTrait, HandshakeResult};
|
||||
use handshake::HandshakeError;
|
||||
use handshake::client::{ClientHandshake, Request};
|
||||
use stream::Mode;
|
||||
use error::{Error, Result};
|
||||
|
||||
/// Connect to the given WebSocket.
|
||||
#[cfg(feature="tls")]
|
||||
use stream::Stream as StreamSwitcher;
|
||||
|
||||
#[cfg(feature="tls")]
|
||||
pub type AutoStream = StreamSwitcher<TcpStream, TlsStream<TcpStream>>;
|
||||
#[cfg(not(feature="tls"))]
|
||||
pub type AutoStream = TcpStream;
|
||||
|
||||
/// Connect to the given WebSocket in blocking mode.
|
||||
///
|
||||
/// Note that this function may block the current thread while DNS resolution is performed.
|
||||
pub fn connect(url: Url) -> Result<Handshake> {
|
||||
let mode = match url.scheme() {
|
||||
"ws" => Mode::Plain,
|
||||
#[cfg(feature="tls")]
|
||||
"wss" => Mode::Tls,
|
||||
_ => return Err(Error::Url("URL scheme not supported".into()))
|
||||
};
|
||||
|
||||
// Note that this function may block the current thread while resolution is performed.
|
||||
/// The URL may be either ws:// or wss://.
|
||||
/// To support wss:// URLs, feature "tls" must be turned on.
|
||||
pub fn connect(url: Url) -> Result<WebSocket<AutoStream>> {
|
||||
let mode = url_mode(&url)?;
|
||||
let addrs = url.to_socket_addrs()?;
|
||||
Ok(Handshake {
|
||||
state: HandshakeState::Nothing(url),
|
||||
alt_addresses: addrs,
|
||||
})
|
||||
let stream = connect_to_some(addrs, &url, mode)?;
|
||||
client(url.clone(), stream)
|
||||
.map_err(|e| match e {
|
||||
HandshakeError::Failure(f) => f,
|
||||
HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
|
||||
})
|
||||
}
|
||||
|
||||
enum Mode {
|
||||
Plain,
|
||||
Tls,
|
||||
}
|
||||
|
||||
enum HandshakeState {
|
||||
Nothing(Url),
|
||||
WebSocket(ClientHandshake<TcpStream>),
|
||||
}
|
||||
|
||||
pub struct Handshake {
|
||||
state: HandshakeState,
|
||||
alt_addresses: SocketAddrs,
|
||||
}
|
||||
|
||||
impl HandshakeTrait for Handshake {
|
||||
type Stream = WebSocket<TcpStream>;
|
||||
fn handshake(mut self) -> Result<HandshakeResult<Self>> {
|
||||
match self.state {
|
||||
HandshakeState::Nothing(url) => {
|
||||
if let Some(addr) = self.alt_addresses.next() {
|
||||
debug!("Trying to contact {} at {}...", url, addr);
|
||||
let state = {
|
||||
if let Ok(stream) = TcpStream::connect(addr) {
|
||||
let hs = ClientHandshake::new(stream, Request { url: url });
|
||||
HandshakeState::WebSocket(hs)
|
||||
} else {
|
||||
HandshakeState::Nothing(url)
|
||||
}
|
||||
};
|
||||
Ok(HandshakeResult::Incomplete(Handshake {
|
||||
state: state,
|
||||
..self
|
||||
}))
|
||||
} else {
|
||||
Err(Error::Url(format!("Unable to resolve {}", url).into()))
|
||||
}
|
||||
}
|
||||
HandshakeState::WebSocket(ws) => {
|
||||
let alt_addresses = self.alt_addresses;
|
||||
ws.handshake().map(move |r| r.map(move |s| Handshake {
|
||||
state: HandshakeState::WebSocket(s),
|
||||
alt_addresses: alt_addresses,
|
||||
}))
|
||||
}
|
||||
#[cfg(feature="tls")]
|
||||
fn wrap_stream(stream: TcpStream, domain: &str, mode: Mode) -> Result<AutoStream> {
|
||||
match mode {
|
||||
Mode::Plain => Ok(StreamSwitcher::Plain(stream)),
|
||||
Mode::Tls => {
|
||||
let connector = TlsConnector::builder()?.build()?;
|
||||
connector.connect(domain, stream)
|
||||
.map_err(|e| match e {
|
||||
TlsHandshakeError::Failure(f) => f.into(),
|
||||
TlsHandshakeError::Interrupted(_) => panic!("Bug: TLS handshake not blocked"),
|
||||
})
|
||||
.map(|s| StreamSwitcher::Tls(s))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature="tls"))]
|
||||
fn wrap_stream(stream: TcpStream, _domain: &str, mode: Mode) -> Result<AutoStream> {
|
||||
match mode {
|
||||
Mode::Plain => Ok(stream),
|
||||
Mode::Tls => Err(Error::Url("TLS support not compiled in.".into())),
|
||||
}
|
||||
}
|
||||
|
||||
fn connect_to_some<A>(addrs: A, url: &Url, mode: Mode) -> Result<AutoStream>
|
||||
where A: Iterator<Item=SocketAddr>
|
||||
{
|
||||
let domain = url.host_str().ok_or(Error::Url("No host name in the URL".into()))?;
|
||||
for addr in addrs {
|
||||
debug!("Trying to contact {} at {}...", url, addr);
|
||||
if let Ok(raw_stream) = TcpStream::connect(addr) {
|
||||
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
|
||||
return Ok(stream)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(Error::Url(format!("Unable to connect to {}", url).into()))
|
||||
}
|
||||
|
||||
/// Get the mode of the given URL.
|
||||
///
|
||||
/// This function may be used in non-blocking implementations.
|
||||
pub fn url_mode(url: &Url) -> Result<Mode> {
|
||||
match url.scheme() {
|
||||
"ws" => Ok(Mode::Plain),
|
||||
"wss" => Ok(Mode::Tls),
|
||||
_ => Err(Error::Url("URL scheme not supported".into()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Do the client handshake over the given stream.
|
||||
///
|
||||
/// Use this function if you need a nonblocking handshake support.
|
||||
pub fn client<Stream: Read + Write>(url: Url, stream: Stream)
|
||||
-> StdResult<WebSocket<Stream>, HandshakeError<Stream, ClientHandshake>>
|
||||
{
|
||||
let request = Request { url: url };
|
||||
ClientHandshake::start(stream, request).handshake()
|
||||
}
|
||||
|
|
17
src/error.rs
17
src/error.rs
|
@ -11,6 +11,9 @@ use std::string;
|
|||
|
||||
use httparse;
|
||||
|
||||
#[cfg(feature="tls")]
|
||||
use native_tls;
|
||||
|
||||
pub type Result<T> = result::Result<T, Error>;
|
||||
|
||||
/// Possible WebSocket errors
|
||||
|
@ -20,6 +23,9 @@ pub enum Error {
|
|||
ConnectionClosed,
|
||||
/// Input-output error
|
||||
Io(io::Error),
|
||||
#[cfg(feature="tls")]
|
||||
/// TLS error
|
||||
Tls(native_tls::Error),
|
||||
/// Buffer capacity exhausted
|
||||
Capacity(Cow<'static, str>),
|
||||
/// Protocol violation
|
||||
|
@ -37,6 +43,8 @@ impl fmt::Display for Error {
|
|||
match *self {
|
||||
Error::ConnectionClosed => write!(f, "Connection closed"),
|
||||
Error::Io(ref err) => write!(f, "IO error: {}", err),
|
||||
#[cfg(feature="tls")]
|
||||
Error::Tls(ref err) => write!(f, "TLS error: {}", err),
|
||||
Error::Capacity(ref msg) => write!(f, "Space limit exceeded: {}", msg),
|
||||
Error::Protocol(ref msg) => write!(f, "WebSocket protocol error: {}", msg),
|
||||
Error::Utf8 => write!(f, "UTF-8 encoding error"),
|
||||
|
@ -51,6 +59,8 @@ impl ErrorTrait for Error {
|
|||
match *self {
|
||||
Error::ConnectionClosed => "",
|
||||
Error::Io(ref err) => err.description(),
|
||||
#[cfg(feature="tls")]
|
||||
Error::Tls(ref err) => err.description(),
|
||||
Error::Capacity(ref msg) => msg.borrow(),
|
||||
Error::Protocol(ref msg) => msg.borrow(),
|
||||
Error::Utf8 => "",
|
||||
|
@ -78,6 +88,13 @@ impl From<string::FromUtf8Error> for Error {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature="tls")]
|
||||
impl From<native_tls::Error> for Error {
|
||||
fn from(err: native_tls::Error) -> Self {
|
||||
Error::Tls(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<httparse::Error> for Error {
|
||||
fn from(err: httparse::Error) -> Self {
|
||||
match err {
|
||||
|
|
|
@ -1,25 +1,16 @@
|
|||
use std::io::{Read, Write, Cursor};
|
||||
|
||||
use base64;
|
||||
use rand;
|
||||
use bytes::Buf;
|
||||
use httparse;
|
||||
use httparse::Status;
|
||||
use std::io::Write;
|
||||
use url::Url;
|
||||
|
||||
use input_buffer::{InputBuffer, MIN_READ};
|
||||
use error::{Error, Result};
|
||||
use protocol::{
|
||||
WebSocket, Role,
|
||||
};
|
||||
use super::{
|
||||
Headers,
|
||||
Httparse, FromHttparse,
|
||||
Handshake, HandshakeResult,
|
||||
convert_key,
|
||||
MAX_HEADERS,
|
||||
};
|
||||
use util::NonBlockingResult;
|
||||
use protocol::{WebSocket, Role};
|
||||
|
||||
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
|
||||
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
|
||||
use super::machine::{HandshakeMachine, StageResult, TryParse};
|
||||
|
||||
/// Client request.
|
||||
pub struct Request {
|
||||
|
@ -47,77 +38,59 @@ impl Request {
|
|||
}
|
||||
}
|
||||
|
||||
/// Client handshake.
|
||||
pub struct ClientHandshake<Stream> {
|
||||
stream: Stream,
|
||||
state: HandshakeState,
|
||||
/// Client handshake role.
|
||||
pub struct ClientHandshake {
|
||||
verify_data: VerifyData,
|
||||
}
|
||||
|
||||
impl<Stream: Read + Write> ClientHandshake<Stream> {
|
||||
/// Initiate a WebSocket handshake over the given stream.
|
||||
pub fn new(stream: Stream, request: Request) -> Self {
|
||||
impl ClientHandshake {
|
||||
/// Initiate a client handshake.
|
||||
pub fn start<Stream>(stream: Stream, request: Request) -> MidHandshake<Stream, Self> {
|
||||
let key = generate_key();
|
||||
|
||||
let mut req = Vec::new();
|
||||
write!(req, "\
|
||||
GET {path} HTTP/1.1\r\n\
|
||||
Host: {host}\r\n\
|
||||
Connection: upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Sec-WebSocket-Version: 13\r\n\
|
||||
Sec-WebSocket-Key: {key}\r\n\
|
||||
\r\n", host = request.get_host(), path = request.get_path(), key = key)
|
||||
.unwrap();
|
||||
let machine = {
|
||||
let mut req = Vec::new();
|
||||
write!(req, "\
|
||||
GET {path} HTTP/1.1\r\n\
|
||||
Host: {host}\r\n\
|
||||
Connection: upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Sec-WebSocket-Version: 13\r\n\
|
||||
Sec-WebSocket-Key: {key}\r\n\
|
||||
\r\n", host = request.get_host(), path = request.get_path(), key = key)
|
||||
.unwrap();
|
||||
HandshakeMachine::start_write(stream, req)
|
||||
};
|
||||
|
||||
let accept_key = convert_key(key.as_ref()).unwrap();
|
||||
let client = {
|
||||
let accept_key = convert_key(key.as_ref()).unwrap();
|
||||
ClientHandshake {
|
||||
verify_data: VerifyData {
|
||||
accept_key: accept_key,
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
ClientHandshake {
|
||||
stream: stream,
|
||||
state: HandshakeState::SendingRequest(Cursor::new(req)),
|
||||
verify_data: VerifyData {
|
||||
accept_key: accept_key,
|
||||
},
|
||||
}
|
||||
debug!("Client handshake initiated.");
|
||||
MidHandshake { role: client, machine: machine }
|
||||
}
|
||||
}
|
||||
|
||||
impl<Stream: Read + Write> Handshake for ClientHandshake<Stream> {
|
||||
type Stream = WebSocket<Stream>;
|
||||
fn handshake(mut self) -> Result<HandshakeResult<Self>> {
|
||||
debug!("Performing client handshake...");
|
||||
match self.state {
|
||||
HandshakeState::SendingRequest(mut req) => {
|
||||
let size = self.stream.write(Buf::bytes(&req)).no_block()?.unwrap_or(0);
|
||||
Buf::advance(&mut req, size);
|
||||
let state = if req.has_remaining() {
|
||||
HandshakeState::SendingRequest(req)
|
||||
} else {
|
||||
HandshakeState::ReceivingResponse(InputBuffer::with_capacity(MIN_READ))
|
||||
};
|
||||
Ok(HandshakeResult::Incomplete(ClientHandshake {
|
||||
state: state,
|
||||
..self
|
||||
}))
|
||||
impl HandshakeRole for ClientHandshake {
|
||||
type IncomingData = Response;
|
||||
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>)
|
||||
-> Result<ProcessingResult<Stream>>
|
||||
{
|
||||
Ok(match finish {
|
||||
StageResult::DoneWriting(stream) => {
|
||||
ProcessingResult::Continue(HandshakeMachine::start_read(stream))
|
||||
}
|
||||
HandshakeState::ReceivingResponse(mut resp_buf) => {
|
||||
resp_buf.reserve(MIN_READ, usize::max_value())
|
||||
.map_err(|_| Error::Capacity("Header too long".into()))?;
|
||||
resp_buf.read_from(&mut self.stream).no_block()?;
|
||||
if let Some(resp) = Response::parse(&mut resp_buf)? {
|
||||
self.verify_data.verify_response(&resp)?;
|
||||
let ws = WebSocket::from_partially_read(self.stream,
|
||||
resp_buf.into_vec(), Role::Client);
|
||||
debug!("Client handshake done.");
|
||||
Ok(HandshakeResult::Done(ws))
|
||||
} else {
|
||||
Ok(HandshakeResult::Incomplete(ClientHandshake {
|
||||
state: HandshakeState::ReceivingResponse(resp_buf),
|
||||
..self
|
||||
}))
|
||||
}
|
||||
StageResult::DoneReading { stream, result, tail, } => {
|
||||
self.verify_data.verify_response(&result)?;
|
||||
debug!("Client handshake done.");
|
||||
ProcessingResult::Done(WebSocket::from_partially_read(stream, tail, Role::Client))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -173,27 +146,14 @@ impl VerifyData {
|
|||
}
|
||||
}
|
||||
|
||||
/// Internal state of the client handshake.
|
||||
enum HandshakeState {
|
||||
SendingRequest(Cursor<Vec<u8>>),
|
||||
ReceivingResponse(InputBuffer),
|
||||
}
|
||||
|
||||
/// Server response.
|
||||
pub struct Response {
|
||||
code: u16,
|
||||
headers: Headers,
|
||||
}
|
||||
|
||||
impl Response {
|
||||
/// Parse the response from a stream.
|
||||
pub fn parse<B: Buf>(input: &mut B) -> Result<Option<Self>> {
|
||||
Response::parse_http(input)
|
||||
}
|
||||
}
|
||||
|
||||
impl Httparse for Response {
|
||||
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
|
||||
impl TryParse for Response {
|
||||
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
|
||||
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
|
||||
let mut req = httparse::Response::new(&mut hbuffer);
|
||||
Ok(match req.parse(buf)? {
|
||||
|
|
|
@ -0,0 +1,148 @@
|
|||
use std::ascii::AsciiExt;
|
||||
use std::str::from_utf8;
|
||||
use std::slice;
|
||||
|
||||
use httparse;
|
||||
|
||||
use error::Result;
|
||||
|
||||
// Limit the number of header lines.
|
||||
pub const MAX_HEADERS: usize = 124;
|
||||
|
||||
/// HTTP request or response headers.
|
||||
#[derive(Debug)]
|
||||
pub struct Headers {
|
||||
data: Vec<(String, Box<[u8]>)>,
|
||||
}
|
||||
|
||||
impl Headers {
|
||||
|
||||
/// Get first header with the given name, if any.
|
||||
pub fn find_first(&self, name: &str) -> Option<&[u8]> {
|
||||
self.find(name).next()
|
||||
}
|
||||
|
||||
/// Iterate over all headers with the given name.
|
||||
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> {
|
||||
HeadersIter {
|
||||
name: name,
|
||||
iter: self.data.iter()
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the given header has the given value.
|
||||
pub fn header_is(&self, name: &str, value: &str) -> bool {
|
||||
self.find_first(name)
|
||||
.map(|v| v == value.as_bytes())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Check if the given header has the given value (case-insensitive).
|
||||
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool {
|
||||
self.find_first(name).ok_or(())
|
||||
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ()))
|
||||
.map(|val| val.eq_ignore_ascii_case(value))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// The iterator over headers.
|
||||
pub struct HeadersIter<'name, 'headers> {
|
||||
name: &'name str,
|
||||
iter: slice::Iter<'headers, (String, Box<[u8]>)>,
|
||||
}
|
||||
|
||||
impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> {
|
||||
type Item = &'headers [u8];
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
while let Some(&(ref name, ref value)) = self.iter.next() {
|
||||
if name.eq_ignore_ascii_case(self.name) {
|
||||
return Some(value)
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Trait to convert raw objects into HTTP parseables.
|
||||
pub trait FromHttparse<T>: Sized {
|
||||
fn from_httparse(raw: T) -> Result<Self>;
|
||||
}
|
||||
|
||||
/*
|
||||
impl TryParse for Headers {
|
||||
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
|
||||
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
|
||||
Ok(match httparse::parse_headers(buf, &mut hbuffer)? {
|
||||
Status::Partial => None,
|
||||
Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)),
|
||||
})
|
||||
}
|
||||
}*/
|
||||
|
||||
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
|
||||
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
|
||||
Ok(Headers {
|
||||
data: raw.iter()
|
||||
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::Headers;
|
||||
|
||||
use std::io::Cursor;
|
||||
|
||||
#[test]
|
||||
fn headers() {
|
||||
const data: &'static [u8] =
|
||||
b"Host: foo.com\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
\r\n";
|
||||
let mut inp = Cursor::new(data);
|
||||
let hdr = Headers::parse(&mut inp).unwrap().unwrap();
|
||||
assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..]));
|
||||
assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..]));
|
||||
assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..]));
|
||||
|
||||
assert!(hdr.header_is("upgrade", "websocket"));
|
||||
assert!(!hdr.header_is("upgrade", "Websocket"));
|
||||
assert!(hdr.header_is_ignore_case("upgrade", "Websocket"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn headers_iter() {
|
||||
const data: &'static [u8] =
|
||||
b"Host: foo.com\r\n\
|
||||
Sec-WebSocket-Extensions: permessage-deflate\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
\r\n";
|
||||
let mut inp = Cursor::new(data);
|
||||
let hdr = Headers::parse(&mut inp).unwrap().unwrap();
|
||||
let mut iter = hdr.find("Sec-WebSocket-Extensions");
|
||||
assert_eq!(iter.next(), Some(&b"permessage-deflate"[..]));
|
||||
assert_eq!(iter.next(), Some(&b"permessage-unknown"[..]));
|
||||
assert_eq!(iter.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn headers_incomplete() {
|
||||
const data: &'static [u8] =
|
||||
b"Host: foo.com\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Upgrade: websocket\r\n";
|
||||
let mut inp = Cursor::new(data);
|
||||
let hdr = Headers::parse(&mut inp).unwrap();
|
||||
assert!(hdr.is_none());
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
use std::io::{Cursor, Read, Write};
|
||||
use bytes::Buf;
|
||||
|
||||
use input_buffer::{InputBuffer, MIN_READ};
|
||||
use error::{Error, Result};
|
||||
use util::NonBlockingResult;
|
||||
|
||||
/// A generic handshake state machine.
|
||||
pub struct HandshakeMachine<Stream> {
|
||||
stream: Stream,
|
||||
state: HandshakeState,
|
||||
}
|
||||
|
||||
impl<Stream> HandshakeMachine<Stream> {
|
||||
/// Start reading data from the peer.
|
||||
pub fn start_read(stream: Stream) -> Self {
|
||||
HandshakeMachine {
|
||||
stream: stream,
|
||||
state: HandshakeState::Reading(InputBuffer::with_capacity(MIN_READ)),
|
||||
}
|
||||
}
|
||||
/// Start writing data to the peer.
|
||||
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
|
||||
HandshakeMachine {
|
||||
stream: stream,
|
||||
state: HandshakeState::Writing(Cursor::new(data.into())),
|
||||
}
|
||||
}
|
||||
/// Returns a shared reference to the inner stream.
|
||||
pub fn get_ref(&self) -> &Stream {
|
||||
&self.stream
|
||||
}
|
||||
/// Returns a mutable reference to the inner stream.
|
||||
pub fn get_mut(&mut self) -> &mut Stream {
|
||||
&mut self.stream
|
||||
}
|
||||
}
|
||||
|
||||
impl<Stream: Read + Write> HandshakeMachine<Stream> {
|
||||
/// Perform a single handshake round.
|
||||
pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> {
|
||||
Ok(match self.state {
|
||||
HandshakeState::Reading(mut buf) => {
|
||||
buf.reserve(MIN_READ, usize::max_value()) // TODO limit size
|
||||
.map_err(|_| Error::Capacity("Header too long".into()))?;
|
||||
if let Some(_) = buf.read_from(&mut self.stream).no_block()? {
|
||||
if let Some((size, obj)) = Obj::try_parse(Buf::bytes(&buf))? {
|
||||
buf.advance(size);
|
||||
RoundResult::StageFinished(StageResult::DoneReading {
|
||||
result: obj,
|
||||
stream: self.stream,
|
||||
tail: buf.into_vec(),
|
||||
})
|
||||
} else {
|
||||
RoundResult::Incomplete(HandshakeMachine {
|
||||
state: HandshakeState::Reading(buf),
|
||||
..self
|
||||
})
|
||||
}
|
||||
} else {
|
||||
RoundResult::WouldBlock(HandshakeMachine {
|
||||
state: HandshakeState::Reading(buf),
|
||||
..self
|
||||
})
|
||||
}
|
||||
}
|
||||
HandshakeState::Writing(mut buf) => {
|
||||
if let Some(size) = self.stream.write(Buf::bytes(&buf)).no_block()? {
|
||||
buf.advance(size);
|
||||
if buf.has_remaining() {
|
||||
RoundResult::Incomplete(HandshakeMachine {
|
||||
state: HandshakeState::Writing(buf),
|
||||
..self
|
||||
})
|
||||
} else {
|
||||
RoundResult::StageFinished(StageResult::DoneWriting(self.stream))
|
||||
}
|
||||
} else {
|
||||
RoundResult::WouldBlock(HandshakeMachine {
|
||||
state: HandshakeState::Writing(buf),
|
||||
..self
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// The result of the round.
|
||||
pub enum RoundResult<Obj, Stream> {
|
||||
/// Round not done, I/O would block.
|
||||
WouldBlock(HandshakeMachine<Stream>),
|
||||
/// Round done, state unchanged.
|
||||
Incomplete(HandshakeMachine<Stream>),
|
||||
/// Stage complete.
|
||||
StageFinished(StageResult<Obj, Stream>),
|
||||
}
|
||||
|
||||
/// The result of the stage.
|
||||
pub enum StageResult<Obj, Stream> {
|
||||
/// Reading round finished.
|
||||
DoneReading { result: Obj, stream: Stream, tail: Vec<u8> },
|
||||
/// Writing round finished.
|
||||
DoneWriting(Stream),
|
||||
}
|
||||
|
||||
/// The parseable object.
|
||||
pub trait TryParse: Sized {
|
||||
/// Return Ok(None) if incomplete, Err on syntax error.
|
||||
fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>;
|
||||
}
|
||||
|
||||
/// The handshake state.
|
||||
enum HandshakeState {
|
||||
/// Reading data from the peer.
|
||||
Reading(InputBuffer),
|
||||
/// Sending data to the peer.
|
||||
Writing(Cursor<Vec<u8>>),
|
||||
}
|
|
@ -1,63 +1,89 @@
|
|||
pub mod headers;
|
||||
pub mod client;
|
||||
pub mod server;
|
||||
#[cfg(feature="tls")]
|
||||
pub mod tls;
|
||||
|
||||
use std::ascii::AsciiExt;
|
||||
use std::str::from_utf8;
|
||||
use std::slice;
|
||||
mod machine;
|
||||
|
||||
use std::io::{Read, Write};
|
||||
|
||||
use base64;
|
||||
use bytes::Buf;
|
||||
use httparse;
|
||||
use httparse::Status;
|
||||
use sha1::Sha1;
|
||||
|
||||
use error::Result;
|
||||
use error::Error;
|
||||
use protocol::WebSocket;
|
||||
|
||||
// Limit the number of header lines.
|
||||
const MAX_HEADERS: usize = 124;
|
||||
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
|
||||
|
||||
/// A handshake state.
|
||||
pub trait Handshake: Sized {
|
||||
/// Resulting stream of this handshake.
|
||||
type Stream;
|
||||
/// Perform a single handshake round.
|
||||
fn handshake(self) -> Result<HandshakeResult<Self>>;
|
||||
/// Perform handshake to the end in a blocking mode.
|
||||
fn handshake_wait(self) -> Result<Self::Stream> {
|
||||
let mut hs = self;
|
||||
/// A WebSocket handshake.
|
||||
pub struct MidHandshake<Stream, Role> {
|
||||
role: Role,
|
||||
machine: HandshakeMachine<Stream>,
|
||||
}
|
||||
|
||||
impl<Stream, Role> MidHandshake<Stream, Role> {
|
||||
/// Returns a shared reference to the inner stream.
|
||||
pub fn get_ref(&self) -> &Stream {
|
||||
self.machine.get_ref()
|
||||
}
|
||||
/// Returns a mutable reference to the inner stream.
|
||||
pub fn get_mut(&mut self) -> &mut Stream {
|
||||
self.machine.get_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Stream: Read + Write, Role: HandshakeRole> MidHandshake<Stream, Role> {
|
||||
/// Restarts the handshake process.
|
||||
pub fn handshake(self) -> Result<WebSocket<Stream>, HandshakeError<Stream, Role>> {
|
||||
let mut mach = self.machine;
|
||||
loop {
|
||||
hs = match hs.handshake()? {
|
||||
HandshakeResult::Done(stream) => return Ok(stream),
|
||||
HandshakeResult::Incomplete(s) => s,
|
||||
mach = match mach.single_round()? {
|
||||
RoundResult::WouldBlock(m) => {
|
||||
return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
|
||||
}
|
||||
RoundResult::Incomplete(m) => m,
|
||||
RoundResult::StageFinished(s) => {
|
||||
match self.role.stage_finished(s)? {
|
||||
ProcessingResult::Continue(m) => m,
|
||||
ProcessingResult::Done(ws) => return Ok(ws),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A handshake result.
|
||||
pub enum HandshakeResult<H: Handshake> {
|
||||
/// Handshake is done, a WebSocket stream is ready.
|
||||
Done(H::Stream),
|
||||
/// Handshake is not done, call handshake() again.
|
||||
Incomplete(H),
|
||||
pub enum HandshakeError<Stream, Role> {
|
||||
/// Handshake was interrupted (would block).
|
||||
Interrupted(MidHandshake<Stream, Role>),
|
||||
/// Handshake failed.
|
||||
Failure(Error),
|
||||
}
|
||||
|
||||
impl<H: Handshake> HandshakeResult<H> {
|
||||
pub fn map<R, F>(self, func: F) -> HandshakeResult<R>
|
||||
where R: Handshake<Stream = H::Stream>,
|
||||
F: FnOnce(H) -> R,
|
||||
{
|
||||
match self {
|
||||
HandshakeResult::Done(s) => HandshakeResult::Done(s),
|
||||
HandshakeResult::Incomplete(h) => HandshakeResult::Incomplete(func(h)),
|
||||
}
|
||||
impl<Stream, Role> From<Error> for HandshakeError<Stream, Role> {
|
||||
fn from(err: Error) -> Self {
|
||||
HandshakeError::Failure(err)
|
||||
}
|
||||
}
|
||||
|
||||
/// Handshake role.
|
||||
pub trait HandshakeRole {
|
||||
#[doc(hidden)]
|
||||
type IncomingData: TryParse;
|
||||
#[doc(hidden)]
|
||||
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>)
|
||||
-> Result<ProcessingResult<Stream>, Error>;
|
||||
}
|
||||
|
||||
/// Stage processing result.
|
||||
#[doc(hidden)]
|
||||
pub enum ProcessingResult<Stream> {
|
||||
Continue(HandshakeMachine<Stream>),
|
||||
Done(WebSocket<Stream>),
|
||||
}
|
||||
|
||||
/// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept.
|
||||
fn convert_key(input: &[u8]) -> Result<String> {
|
||||
fn convert_key(input: &[u8]) -> Result<String, Error> {
|
||||
// ... field is constructed by concatenating /key/ ...
|
||||
// ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
|
||||
const WS_GUID: &'static [u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
||||
|
@ -67,113 +93,10 @@ fn convert_key(input: &[u8]) -> Result<String> {
|
|||
Ok(base64::encode(&sha1.digest().bytes()))
|
||||
}
|
||||
|
||||
/// HTTP request or response headers.
|
||||
#[derive(Debug)]
|
||||
pub struct Headers {
|
||||
data: Vec<(String, Box<[u8]>)>,
|
||||
}
|
||||
|
||||
impl Headers {
|
||||
|
||||
/// Get first header with the given name, if any.
|
||||
pub fn find_first(&self, name: &str) -> Option<&[u8]> {
|
||||
self.find(name).next()
|
||||
}
|
||||
|
||||
/// Iterate over all headers with the given name.
|
||||
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> {
|
||||
HeadersIter {
|
||||
name: name,
|
||||
iter: self.data.iter()
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the given header has the given value.
|
||||
pub fn header_is(&self, name: &str, value: &str) -> bool {
|
||||
self.find_first(name)
|
||||
.map(|v| v == value.as_bytes())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Check if the given header has the given value (case-insensitive).
|
||||
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool {
|
||||
self.find_first(name).ok_or(())
|
||||
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ()))
|
||||
.map(|val| val.eq_ignore_ascii_case(value))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Try to parse data and return headers, if any.
|
||||
fn parse<B: Buf>(input: &mut B) -> Result<Option<Headers>> {
|
||||
Headers::parse_http(input)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// The iterator over headers.
|
||||
pub struct HeadersIter<'name, 'headers> {
|
||||
name: &'name str,
|
||||
iter: slice::Iter<'headers, (String, Box<[u8]>)>,
|
||||
}
|
||||
|
||||
impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> {
|
||||
type Item = &'headers [u8];
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
while let Some(&(ref name, ref value)) = self.iter.next() {
|
||||
if name.eq_ignore_ascii_case(self.name) {
|
||||
return Some(value)
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Trait to read HTTP parseable objects.
|
||||
trait Httparse: Sized {
|
||||
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>>;
|
||||
fn parse_http<B: Buf>(input: &mut B) -> Result<Option<Self>> {
|
||||
Ok(match Self::httparse(input.bytes())? {
|
||||
Some((size, obj)) => {
|
||||
input.advance(size);
|
||||
Some(obj)
|
||||
},
|
||||
None => None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait to convert raw objects into HTTP parseables.
|
||||
trait FromHttparse<T>: Sized {
|
||||
fn from_httparse(raw: T) -> Result<Self>;
|
||||
}
|
||||
|
||||
impl Httparse for Headers {
|
||||
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
|
||||
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
|
||||
Ok(match httparse::parse_headers(buf, &mut hbuffer)? {
|
||||
Status::Partial => None,
|
||||
Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
|
||||
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
|
||||
Ok(Headers {
|
||||
data: raw.iter()
|
||||
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::{Headers, convert_key};
|
||||
|
||||
use std::io::Cursor;
|
||||
use super::convert_key;
|
||||
|
||||
#[test]
|
||||
fn key_conversion() {
|
||||
|
@ -182,50 +105,4 @@ mod tests {
|
|||
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn headers() {
|
||||
const data: &'static [u8] =
|
||||
b"Host: foo.com\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
\r\n";
|
||||
let mut inp = Cursor::new(data);
|
||||
let hdr = Headers::parse(&mut inp).unwrap().unwrap();
|
||||
assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..]));
|
||||
assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..]));
|
||||
assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..]));
|
||||
|
||||
assert!(hdr.header_is("upgrade", "websocket"));
|
||||
assert!(!hdr.header_is("upgrade", "Websocket"));
|
||||
assert!(hdr.header_is_ignore_case("upgrade", "Websocket"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn headers_iter() {
|
||||
const data: &'static [u8] =
|
||||
b"Host: foo.com\r\n\
|
||||
Sec-WebSocket-Extensions: permessage-deflate\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
\r\n";
|
||||
let mut inp = Cursor::new(data);
|
||||
let hdr = Headers::parse(&mut inp).unwrap().unwrap();
|
||||
let mut iter = hdr.find("Sec-WebSocket-Extensions");
|
||||
assert_eq!(iter.next(), Some(&b"permessage-deflate"[..]));
|
||||
assert_eq!(iter.next(), Some(&b"permessage-unknown"[..]));
|
||||
assert_eq!(iter.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn headers_incomplete() {
|
||||
const data: &'static [u8] =
|
||||
b"Host: foo.com\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Upgrade: websocket\r\n";
|
||||
let mut inp = Cursor::new(data);
|
||||
let hdr = Headers::parse(&mut inp).unwrap();
|
||||
assert!(hdr.is_none());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,33 +1,20 @@
|
|||
use std::io::{Cursor, Read, Write};
|
||||
use bytes::Buf;
|
||||
use httparse;
|
||||
use httparse::Status;
|
||||
|
||||
use input_buffer::{InputBuffer, MIN_READ};
|
||||
//use input_buffer::{InputBuffer, MIN_READ};
|
||||
use error::{Error, Result};
|
||||
use protocol::{WebSocket, Role};
|
||||
use super::{
|
||||
Handshake,
|
||||
HandshakeResult,
|
||||
Headers,
|
||||
Httparse,
|
||||
FromHttparse,
|
||||
convert_key,
|
||||
MAX_HEADERS
|
||||
};
|
||||
use util::NonBlockingResult;
|
||||
use super::headers::{Headers, FromHttparse, MAX_HEADERS};
|
||||
use super::machine::{HandshakeMachine, StageResult, TryParse};
|
||||
use super::{MidHandshake, HandshakeRole, ProcessingResult, convert_key};
|
||||
|
||||
/// Request from the client.
|
||||
pub struct Request {
|
||||
path: String,
|
||||
headers: Headers,
|
||||
pub path: String,
|
||||
pub headers: Headers,
|
||||
}
|
||||
|
||||
impl Request {
|
||||
/// Parse the request from a stream.
|
||||
pub fn parse<B: Buf>(input: &mut B) -> Result<Option<Self>> {
|
||||
Request::parse_http(input)
|
||||
}
|
||||
/// Reply to the response.
|
||||
pub fn reply(&self) -> Result<Vec<u8>> {
|
||||
let key = self.headers.find_first("Sec-WebSocket-Key")
|
||||
|
@ -42,8 +29,8 @@ impl Request {
|
|||
}
|
||||
}
|
||||
|
||||
impl Httparse for Request {
|
||||
fn httparse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
|
||||
impl TryParse for Request {
|
||||
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
|
||||
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
|
||||
let mut req = httparse::Request::new(&mut hbuffer);
|
||||
Ok(match req.parse(buf)? {
|
||||
|
@ -68,64 +55,40 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
|
|||
}
|
||||
}
|
||||
|
||||
/// Server handshake
|
||||
pub struct ServerHandshake<Stream> {
|
||||
stream: Stream,
|
||||
state: HandshakeState,
|
||||
}
|
||||
/// Server handshake role.
|
||||
#[allow(missing_copy_implementations)]
|
||||
pub struct ServerHandshake;
|
||||
|
||||
impl<Stream: Read + Write> ServerHandshake<Stream> {
|
||||
/// Start a new server handshake on top of given stream.
|
||||
pub fn new(stream: Stream) -> Self {
|
||||
ServerHandshake {
|
||||
stream: stream,
|
||||
state: HandshakeState::ReceivingRequest(InputBuffer::with_capacity(MIN_READ)),
|
||||
impl ServerHandshake {
|
||||
/// Start server handshake.
|
||||
pub fn start<Stream>(stream: Stream) -> MidHandshake<Stream, Self> {
|
||||
MidHandshake {
|
||||
machine: HandshakeMachine::start_read(stream),
|
||||
role: ServerHandshake,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Stream: Read + Write> Handshake for ServerHandshake<Stream> {
|
||||
type Stream = WebSocket<Stream>;
|
||||
fn handshake(mut self) -> Result<HandshakeResult<Self>> {
|
||||
debug!("Performing server handshake...");
|
||||
match self.state {
|
||||
HandshakeState::ReceivingRequest(mut req_buf) => {
|
||||
req_buf.reserve(MIN_READ, usize::max_value())
|
||||
.map_err(|_| Error::Capacity("Header too long".into()))?;
|
||||
req_buf.read_from(&mut self.stream).no_block()?;
|
||||
let state = if let Some(req) = Request::parse(&mut req_buf)? {
|
||||
let resp = req.reply()?;
|
||||
HandshakeState::SendingResponse(Cursor::new(resp))
|
||||
} else {
|
||||
HandshakeState::ReceivingRequest(req_buf)
|
||||
};
|
||||
Ok(HandshakeResult::Incomplete(ServerHandshake {
|
||||
state: state,
|
||||
..self
|
||||
}))
|
||||
}
|
||||
HandshakeState::SendingResponse(mut resp) => {
|
||||
let size = self.stream.write(Buf::bytes(&resp)).no_block()?.unwrap_or(0);
|
||||
Buf::advance(&mut resp, size);
|
||||
if resp.has_remaining() {
|
||||
Ok(HandshakeResult::Incomplete(ServerHandshake {
|
||||
state: HandshakeState::SendingResponse(resp),
|
||||
..self
|
||||
}))
|
||||
} else {
|
||||
let ws = WebSocket::from_raw_socket(self.stream, Role::Server);
|
||||
Ok(HandshakeResult::Done(ws))
|
||||
impl HandshakeRole for ServerHandshake {
|
||||
type IncomingData = Request;
|
||||
fn stage_finished<Stream>(&self, finish: StageResult<Self::IncomingData, Stream>)
|
||||
-> Result<ProcessingResult<Stream>>
|
||||
{
|
||||
Ok(match finish {
|
||||
StageResult::DoneReading { stream, result, tail } => {
|
||||
if ! tail.is_empty() {
|
||||
return Err(Error::Protocol("Junk after client request".into()))
|
||||
}
|
||||
let response = result.reply()?;
|
||||
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
|
||||
}
|
||||
}
|
||||
StageResult::DoneWriting(stream) => {
|
||||
ProcessingResult::Done(WebSocket::from_raw_socket(stream, Role::Server))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
enum HandshakeState {
|
||||
ReceivingRequest(InputBuffer),
|
||||
SendingResponse(Cursor<Vec<u8>>),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
use native_tls;
|
||||
|
||||
use stream::Stream;
|
||||
use super::{Handshake, HandshakeResult};
|
||||
|
||||
pub struct TlsHandshake {
|
||||
|
||||
}
|
||||
|
||||
impl Handshale for TlsHandshake {
|
||||
type Stream = Stream;
|
||||
fn handshake(self) -> Result<HandshakeResult<Self>> {
|
||||
}
|
||||
}
|
|
@ -1,8 +1,13 @@
|
|||
use std::net::TcpStream;
|
||||
pub use handshake::server::ServerHandshake;
|
||||
|
||||
use handshake::server::ServerHandshake;
|
||||
use handshake::HandshakeError;
|
||||
use protocol::WebSocket;
|
||||
|
||||
use std::io::{Read, Write};
|
||||
|
||||
/// Accept the given TcpStream as a WebSocket.
|
||||
pub fn accept(stream: TcpStream) -> ServerHandshake<TcpStream> {
|
||||
ServerHandshake::new(stream)
|
||||
pub fn accept<Stream: Read + Write>(stream: Stream)
|
||||
-> Result<WebSocket<Stream>, HandshakeError<Stream, ServerHandshake>>
|
||||
{
|
||||
ServerHandshake::start(stream).handshake()
|
||||
}
|
||||
|
|
|
@ -1,38 +1,37 @@
|
|||
#[cfg(feature="tls")]
|
||||
use native_tls::TlsStream;
|
||||
|
||||
use std::net::TcpStream;
|
||||
use std::io::{Read, Write, Result as IoResult};
|
||||
|
||||
/// Stream, either plain TCP or TLS.
|
||||
pub enum Stream {
|
||||
Plain(TcpStream),
|
||||
#[cfg(feature="tls")]
|
||||
Tls(TlsStream<TcpStream>),
|
||||
/// Stream mode, either plain TCP or TLS.
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum Mode {
|
||||
Plain,
|
||||
Tls,
|
||||
}
|
||||
|
||||
impl Read for Stream {
|
||||
/// Stream, either plain TCP or TLS.
|
||||
pub enum Stream<S, T> {
|
||||
Plain(S),
|
||||
Tls(T),
|
||||
}
|
||||
|
||||
impl<S: Read, T: Read> Read for Stream<S, T> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
|
||||
match *self {
|
||||
Stream::Plain(ref mut s) => s.read(buf),
|
||||
#[cfg(feature="tls")]
|
||||
Stream::Tls(ref mut s) => s.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for Stream {
|
||||
impl<S: Write, T: Write> Write for Stream<S, T> {
|
||||
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
|
||||
match *self {
|
||||
Stream::Plain(ref mut s) => s.write(buf),
|
||||
#[cfg(feature="tls")]
|
||||
Stream::Tls(ref mut s) => s.write(buf),
|
||||
}
|
||||
}
|
||||
fn flush(&mut self) -> IoResult<()> {
|
||||
match *self {
|
||||
Stream::Plain(ref mut s) => s.flush(),
|
||||
#[cfg(feature="tls")]
|
||||
Stream::Tls(ref mut s) => s.flush(),
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue