server: let callback return HTTP error messages

Signed-off-by: Alexey Galakhov <agalakhov@snapview.de>
This commit is contained in:
Alexey Galakhov 2019-05-04 02:15:43 +02:00
parent 8ed73fd28a
commit 6f132208ee
4 changed files with 115 additions and 26 deletions

View File

@ -19,6 +19,7 @@ tls = ["native-tls"]
base64 = "0.10.0"
byteorder = "1.2.3"
bytes = "0.4.8"
http = "0.1.17"
httparse = "1.3.1"
input_buffer = "0.2.0"
log = "0.4.2"

View File

@ -0,0 +1,24 @@
extern crate tungstenite;
use std::thread::spawn;
use std::net::TcpListener;
use tungstenite::accept_hdr;
use tungstenite::handshake::server::{Request, ErrorResponse};
use tungstenite::http::StatusCode;
fn main() {
let server = TcpListener::bind("127.0.0.1:3012").unwrap();
for stream in server.incoming() {
spawn(move || {
let callback = |_req: &Request| {
Err(ErrorResponse {
error_code: StatusCode::FORBIDDEN,
headers: None,
body: Some("Access denied".into()),
})
};
accept_hdr(stream.unwrap(), callback).unwrap_err();
});
}
}

View File

@ -3,9 +3,11 @@
use std::fmt::Write as FmtWrite;
use std::io::{Read, Write};
use std::marker::PhantomData;
use std::result::Result as StdResult;
use httparse;
use httparse::Status;
use http::StatusCode;
use error::{Error, Result};
use protocol::{WebSocket, WebSocketConfig, Role};
@ -35,16 +37,21 @@ impl Request {
Sec-WebSocket-Accept: {}\r\n",
convert_key(key)?
);
if let Some(eh) = extra_headers {
for (k, v) in eh {
write!(reply, "{}: {}\r\n", k, v).unwrap();
}
}
write!(reply, "\r\n").unwrap();
add_headers(&mut reply, extra_headers);
Ok(reply.into())
}
}
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<ExtraHeaders>) {
if let Some(eh) = extra_headers {
for (k, v) in eh {
write!(reply, "{}: {}\r\n", k, v).unwrap();
}
}
write!(reply, "\r\n").unwrap();
}
impl TryParse for Request {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
@ -71,6 +78,30 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
}
}
/// Extra headers for responses.
pub type ExtraHeaders = Vec<(String, String)>;
/// An error response sent to the client.
#[derive(Debug)]
pub struct ErrorResponse {
/// HTTP error code.
pub error_code: StatusCode,
/// Extra response headers, if any.
pub headers: Option<ExtraHeaders>,
/// REsponse body, if any.
pub body: Option<String>,
}
impl From<StatusCode> for ErrorResponse {
fn from(error_code: StatusCode) -> Self {
ErrorResponse {
error_code,
headers: None,
body: None,
}
}
}
/// The callback trait.
///
/// The callback is called when the server receives an incoming WebSocket
@ -81,11 +112,11 @@ pub trait Callback: Sized {
/// Called whenever the server read the request from the client and is ready to reply to it.
/// May return additional reply headers.
/// Returning an error resulting in rejecting the incoming connection.
fn on_request(self, request: &Request) -> Result<Option<Vec<(String, String)>>>;
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>;
}
impl<F> Callback for F where F: FnOnce(&Request) -> Result<Option<Vec<(String, String)>>> {
fn on_request(self, request: &Request) -> Result<Option<Vec<(String, String)>>> {
impl<F> Callback for F where F: FnOnce(&Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
self(request)
}
}
@ -95,7 +126,7 @@ impl<F> Callback for F where F: FnOnce(&Request) -> Result<Option<Vec<(String, S
pub struct NoCallback;
impl Callback for NoCallback {
fn on_request(self, _request: &Request) -> Result<Option<Vec<(String, String)>>> {
fn on_request(self, _request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
Ok(None)
}
}
@ -110,6 +141,8 @@ pub struct ServerHandshake<S, C> {
callback: Option<C>,
/// WebSocket configuration.
config: Option<WebSocketConfig>,
/// Error code/flag. If set, an error will be returned after sending response to the client.
error_code: Option<u16>,
/// Internal stream type.
_marker: PhantomData<S>,
}
@ -123,7 +156,12 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
trace!("Server handshake initiated.");
MidHandshake {
machine: HandshakeMachine::start_read(stream),
role: ServerHandshake { callback: Some(callback), config, _marker: PhantomData },
role: ServerHandshake {
callback: Some(callback),
config,
error_code: None,
_marker: PhantomData
},
}
}
}
@ -141,24 +179,48 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
if !tail.is_empty() {
return Err(Error::Protocol("Junk after client request".into()))
}
let extra_headers = {
if let Some(callback) = self.callback.take() {
callback.on_request(&result)?
} else {
None
}
let callback_result = if let Some(callback) = self.callback.take() {
callback.on_request(&result)
} else {
Ok(None)
};
let response = result.reply(extra_headers)?;
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
match callback_result {
Ok(extra_headers) => {
let response = result.reply(extra_headers)?;
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
}
Err(ErrorResponse { error_code, headers, body }) => {
self.error_code= Some(error_code.as_u16());
let mut response = format!(
"HTTP/1.1 {} {}\r\n",
error_code.as_str(),
error_code.canonical_reason().unwrap_or("")
);
add_headers(&mut response, headers);
if let Some(body) = body {
response += &body;
}
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
}
}
}
StageResult::DoneWriting(stream) => {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(
stream,
Role::Server,
self.config.clone(),
);
ProcessingResult::Done(websocket)
if let Some(err) = self.error_code.take() {
debug!("Server handshake failed.");
return Err(Error::Http(err));
} else {
debug!("Server handshake done.");
let websocket = WebSocket::from_raw_socket(
stream,
Role::Server,
self.config.clone(),
);
ProcessingResult::Done(websocket)
}
}
})
}

View File

@ -22,6 +22,8 @@ extern crate url;
extern crate utf8;
#[cfg(feature="tls")] extern crate native_tls;
pub extern crate http;
pub mod error;
pub mod protocol;
pub mod client;