server: let callback return HTTP error messages
Signed-off-by: Alexey Galakhov <agalakhov@snapview.de>
This commit is contained in:
parent
8ed73fd28a
commit
6f132208ee
|
@ -19,6 +19,7 @@ tls = ["native-tls"]
|
|||
base64 = "0.10.0"
|
||||
byteorder = "1.2.3"
|
||||
bytes = "0.4.8"
|
||||
http = "0.1.17"
|
||||
httparse = "1.3.1"
|
||||
input_buffer = "0.2.0"
|
||||
log = "0.4.2"
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
extern crate tungstenite;
|
||||
|
||||
use std::thread::spawn;
|
||||
use std::net::TcpListener;
|
||||
|
||||
use tungstenite::accept_hdr;
|
||||
use tungstenite::handshake::server::{Request, ErrorResponse};
|
||||
use tungstenite::http::StatusCode;
|
||||
|
||||
fn main() {
|
||||
let server = TcpListener::bind("127.0.0.1:3012").unwrap();
|
||||
for stream in server.incoming() {
|
||||
spawn(move || {
|
||||
let callback = |_req: &Request| {
|
||||
Err(ErrorResponse {
|
||||
error_code: StatusCode::FORBIDDEN,
|
||||
headers: None,
|
||||
body: Some("Access denied".into()),
|
||||
})
|
||||
};
|
||||
accept_hdr(stream.unwrap(), callback).unwrap_err();
|
||||
});
|
||||
}
|
||||
}
|
|
@ -3,9 +3,11 @@
|
|||
use std::fmt::Write as FmtWrite;
|
||||
use std::io::{Read, Write};
|
||||
use std::marker::PhantomData;
|
||||
use std::result::Result as StdResult;
|
||||
|
||||
use httparse;
|
||||
use httparse::Status;
|
||||
use http::StatusCode;
|
||||
|
||||
use error::{Error, Result};
|
||||
use protocol::{WebSocket, WebSocketConfig, Role};
|
||||
|
@ -35,16 +37,21 @@ impl Request {
|
|||
Sec-WebSocket-Accept: {}\r\n",
|
||||
convert_key(key)?
|
||||
);
|
||||
if let Some(eh) = extra_headers {
|
||||
for (k, v) in eh {
|
||||
write!(reply, "{}: {}\r\n", k, v).unwrap();
|
||||
}
|
||||
}
|
||||
write!(reply, "\r\n").unwrap();
|
||||
add_headers(&mut reply, extra_headers);
|
||||
Ok(reply.into())
|
||||
}
|
||||
}
|
||||
|
||||
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<ExtraHeaders>) {
|
||||
if let Some(eh) = extra_headers {
|
||||
for (k, v) in eh {
|
||||
write!(reply, "{}: {}\r\n", k, v).unwrap();
|
||||
}
|
||||
}
|
||||
write!(reply, "\r\n").unwrap();
|
||||
}
|
||||
|
||||
|
||||
impl TryParse for Request {
|
||||
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
|
||||
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
|
||||
|
@ -71,6 +78,30 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
|
|||
}
|
||||
}
|
||||
|
||||
/// Extra headers for responses.
|
||||
pub type ExtraHeaders = Vec<(String, String)>;
|
||||
|
||||
/// An error response sent to the client.
|
||||
#[derive(Debug)]
|
||||
pub struct ErrorResponse {
|
||||
/// HTTP error code.
|
||||
pub error_code: StatusCode,
|
||||
/// Extra response headers, if any.
|
||||
pub headers: Option<ExtraHeaders>,
|
||||
/// REsponse body, if any.
|
||||
pub body: Option<String>,
|
||||
}
|
||||
|
||||
impl From<StatusCode> for ErrorResponse {
|
||||
fn from(error_code: StatusCode) -> Self {
|
||||
ErrorResponse {
|
||||
error_code,
|
||||
headers: None,
|
||||
body: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The callback trait.
|
||||
///
|
||||
/// The callback is called when the server receives an incoming WebSocket
|
||||
|
@ -81,11 +112,11 @@ pub trait Callback: Sized {
|
|||
/// Called whenever the server read the request from the client and is ready to reply to it.
|
||||
/// May return additional reply headers.
|
||||
/// Returning an error resulting in rejecting the incoming connection.
|
||||
fn on_request(self, request: &Request) -> Result<Option<Vec<(String, String)>>>;
|
||||
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse>;
|
||||
}
|
||||
|
||||
impl<F> Callback for F where F: FnOnce(&Request) -> Result<Option<Vec<(String, String)>>> {
|
||||
fn on_request(self, request: &Request) -> Result<Option<Vec<(String, String)>>> {
|
||||
impl<F> Callback for F where F: FnOnce(&Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
|
||||
fn on_request(self, request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
|
||||
self(request)
|
||||
}
|
||||
}
|
||||
|
@ -95,7 +126,7 @@ impl<F> Callback for F where F: FnOnce(&Request) -> Result<Option<Vec<(String, S
|
|||
pub struct NoCallback;
|
||||
|
||||
impl Callback for NoCallback {
|
||||
fn on_request(self, _request: &Request) -> Result<Option<Vec<(String, String)>>> {
|
||||
fn on_request(self, _request: &Request) -> StdResult<Option<ExtraHeaders>, ErrorResponse> {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
@ -110,6 +141,8 @@ pub struct ServerHandshake<S, C> {
|
|||
callback: Option<C>,
|
||||
/// WebSocket configuration.
|
||||
config: Option<WebSocketConfig>,
|
||||
/// Error code/flag. If set, an error will be returned after sending response to the client.
|
||||
error_code: Option<u16>,
|
||||
/// Internal stream type.
|
||||
_marker: PhantomData<S>,
|
||||
}
|
||||
|
@ -123,7 +156,12 @@ impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
|
|||
trace!("Server handshake initiated.");
|
||||
MidHandshake {
|
||||
machine: HandshakeMachine::start_read(stream),
|
||||
role: ServerHandshake { callback: Some(callback), config, _marker: PhantomData },
|
||||
role: ServerHandshake {
|
||||
callback: Some(callback),
|
||||
config,
|
||||
error_code: None,
|
||||
_marker: PhantomData
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -141,24 +179,48 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
|
|||
if !tail.is_empty() {
|
||||
return Err(Error::Protocol("Junk after client request".into()))
|
||||
}
|
||||
let extra_headers = {
|
||||
if let Some(callback) = self.callback.take() {
|
||||
callback.on_request(&result)?
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
||||
let callback_result = if let Some(callback) = self.callback.take() {
|
||||
callback.on_request(&result)
|
||||
} else {
|
||||
Ok(None)
|
||||
};
|
||||
let response = result.reply(extra_headers)?;
|
||||
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
|
||||
|
||||
match callback_result {
|
||||
Ok(extra_headers) => {
|
||||
let response = result.reply(extra_headers)?;
|
||||
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
|
||||
}
|
||||
|
||||
Err(ErrorResponse { error_code, headers, body }) => {
|
||||
self.error_code= Some(error_code.as_u16());
|
||||
let mut response = format!(
|
||||
"HTTP/1.1 {} {}\r\n",
|
||||
error_code.as_str(),
|
||||
error_code.canonical_reason().unwrap_or("")
|
||||
);
|
||||
add_headers(&mut response, headers);
|
||||
if let Some(body) = body {
|
||||
response += &body;
|
||||
}
|
||||
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StageResult::DoneWriting(stream) => {
|
||||
debug!("Server handshake done.");
|
||||
let websocket = WebSocket::from_raw_socket(
|
||||
stream,
|
||||
Role::Server,
|
||||
self.config.clone(),
|
||||
);
|
||||
ProcessingResult::Done(websocket)
|
||||
if let Some(err) = self.error_code.take() {
|
||||
debug!("Server handshake failed.");
|
||||
return Err(Error::Http(err));
|
||||
} else {
|
||||
debug!("Server handshake done.");
|
||||
let websocket = WebSocket::from_raw_socket(
|
||||
stream,
|
||||
Role::Server,
|
||||
self.config.clone(),
|
||||
);
|
||||
ProcessingResult::Done(websocket)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -22,6 +22,8 @@ extern crate url;
|
|||
extern crate utf8;
|
||||
#[cfg(feature="tls")] extern crate native_tls;
|
||||
|
||||
pub extern crate http;
|
||||
|
||||
pub mod error;
|
||||
pub mod protocol;
|
||||
pub mod client;
|
||||
|
|
Loading…
Reference in New Issue