Compare commits
3 Commits
1d28846056
...
fc653b67d9
Author | SHA1 | Date |
---|---|---|
Conor | fc653b67d9 | |
Bartel Sielski | 60c50cdea2 | |
conorbros | a2ec84b822 |
20
src/error.rs
20
src/error.rs
|
@ -149,6 +149,23 @@ pub enum CapacityError {
|
|||
},
|
||||
}
|
||||
|
||||
/// Indicates the specific type/cause of a subprotocol header error.
|
||||
#[derive(Error, Clone, PartialEq, Eq, Debug, Copy)]
|
||||
pub enum SubProtocolError {
|
||||
/// The server sent a subprotocol to a client handshake request but none was requested
|
||||
#[error("Server sent a subprotocol but none was requested")]
|
||||
ServerSentSubProtocolNoneRequested,
|
||||
|
||||
/// The server sent an invalid subprotocol to a client handhshake request
|
||||
#[error("Server sent an invalid subprotocol")]
|
||||
InvalidSubProtocol,
|
||||
|
||||
/// The server sent no subprotocol to a client handshake request that requested one or more
|
||||
/// subprotocols
|
||||
#[error("Server sent no subprotocol")]
|
||||
NoSubProtocol,
|
||||
}
|
||||
|
||||
/// Indicates the specific type/cause of a protocol error.
|
||||
#[allow(missing_copy_implementations)]
|
||||
#[derive(Error, Debug, PartialEq, Eq, Clone)]
|
||||
|
@ -174,6 +191,9 @@ pub enum ProtocolError {
|
|||
/// The `Sec-WebSocket-Accept` header is either not present or does not specify the correct key value.
|
||||
#[error("Key mismatch in \"Sec-WebSocket-Accept\" header")]
|
||||
SecWebSocketAcceptKeyMismatch,
|
||||
/// The `Sec-WebSocket-Protocol` header was invalid
|
||||
#[error("SubProtocol error: {0}")]
|
||||
SecWebSocketSubProtocolError(SubProtocolError),
|
||||
/// Garbage data encountered after client request.
|
||||
#[error("Junk after client request")]
|
||||
JunkAfterRequest,
|
||||
|
|
|
@ -18,7 +18,7 @@ use super::{
|
|||
HandshakeRole, MidHandshake, ProcessingResult,
|
||||
};
|
||||
use crate::{
|
||||
error::{Error, ProtocolError, Result, UrlError},
|
||||
error::{Error, ProtocolError, Result, SubProtocolError, UrlError},
|
||||
protocol::{Role, WebSocket, WebSocketConfig},
|
||||
};
|
||||
|
||||
|
@ -54,6 +54,8 @@ impl<S: Read + Write> ClientHandshake<S> {
|
|||
// Check the URI scheme: only ws or wss are supported
|
||||
let _ = crate::client::uri_mode(request.uri())?;
|
||||
|
||||
let subprotocols = extract_subprotocols_from_request(&request)?;
|
||||
|
||||
// Convert and verify the `http::Request` and turn it into the request as per RFC.
|
||||
// Also extract the key from it (it must be present in a correct request).
|
||||
let (request, key) = generate_request(request)?;
|
||||
|
@ -62,7 +64,11 @@ impl<S: Read + Write> ClientHandshake<S> {
|
|||
|
||||
let client = {
|
||||
let accept_key = derive_accept_key(key.as_ref());
|
||||
ClientHandshake { verify_data: VerifyData { accept_key }, config, _marker: PhantomData }
|
||||
ClientHandshake {
|
||||
verify_data: VerifyData { accept_key, subprotocols },
|
||||
config,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
};
|
||||
|
||||
trace!("Client handshake initiated.");
|
||||
|
@ -178,11 +184,22 @@ pub fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
|
|||
Ok((req, key))
|
||||
}
|
||||
|
||||
fn extract_subprotocols_from_request(request: &Request) -> Result<Option<Vec<String>>> {
|
||||
if let Some(subprotocols) = request.headers().get("Sec-WebSocket-Protocol") {
|
||||
Ok(Some(subprotocols.to_str()?.split(",").map(|s| s.to_string()).collect()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Information for handshake verification.
|
||||
#[derive(Debug)]
|
||||
struct VerifyData {
|
||||
/// Accepted server key.
|
||||
accept_key: String,
|
||||
|
||||
/// Accepted subprotocols
|
||||
subprotocols: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl VerifyData {
|
||||
|
@ -238,7 +255,27 @@ impl VerifyData {
|
|||
// not present in the client's handshake (the server has indicated a
|
||||
// subprotocol not requested by the client), the client MUST _Fail
|
||||
// the WebSocket Connection_. (RFC 6455)
|
||||
// TODO
|
||||
if headers.get("Sec-WebSocket-Protocol").is_none() && self.subprotocols.is_some() {
|
||||
return Err(Error::Protocol(ProtocolError::SecWebSocketSubProtocolError(
|
||||
SubProtocolError::NoSubProtocol,
|
||||
)));
|
||||
}
|
||||
|
||||
if headers.get("Sec-WebSocket-Protocol").is_some() && self.subprotocols.is_none() {
|
||||
return Err(Error::Protocol(ProtocolError::SecWebSocketSubProtocolError(
|
||||
SubProtocolError::ServerSentSubProtocolNoneRequested,
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(returned_subprotocol) = headers.get("Sec-WebSocket-Protocol") {
|
||||
if let Some(accepted_subprotocols) = &self.subprotocols {
|
||||
if !accepted_subprotocols.contains(&returned_subprotocol.to_str()?.to_string()) {
|
||||
return Err(Error::Protocol(ProtocolError::SecWebSocketSubProtocolError(
|
||||
SubProtocolError::InvalidSubProtocol,
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
|
|
@ -86,10 +86,10 @@ pub fn create_response(request: &Request) -> Result<Response> {
|
|||
}
|
||||
|
||||
/// Create a response for the request with a custom body.
|
||||
pub fn create_response_with_body<T>(
|
||||
request: &HttpRequest<T>,
|
||||
generate_body: impl FnOnce() -> T,
|
||||
) -> Result<HttpResponse<T>> {
|
||||
pub fn create_response_with_body<T1, T2>(
|
||||
request: &HttpRequest<T1>,
|
||||
generate_body: impl FnOnce() -> T2,
|
||||
) -> Result<HttpResponse<T2>> {
|
||||
Ok(create_parts(request)?.body(generate_body())?)
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
use std::net::TcpListener;
|
||||
use std::thread::spawn;
|
||||
use tungstenite::error::{Error, ProtocolError, SubProtocolError};
|
||||
use tungstenite::handshake::client::generate_key;
|
||||
use tungstenite::handshake::server::{Request, Response};
|
||||
use tungstenite::{accept_hdr, connect};
|
||||
|
||||
fn create_http_request(uri: &str, subprotocols: Option<Vec<String>>) -> http::Request<()> {
|
||||
let uri = uri.parse::<http::Uri>().unwrap();
|
||||
|
||||
let authority = uri.authority().unwrap().as_str();
|
||||
let host =
|
||||
authority.find('@').map(|idx| authority.split_at(idx + 1).1).unwrap_or_else(|| authority);
|
||||
|
||||
if host.is_empty() {
|
||||
panic!("Empty host name");
|
||||
}
|
||||
|
||||
let mut builder = http::Request::builder()
|
||||
.method("GET")
|
||||
.header("Host", host)
|
||||
.header("Connection", "Upgrade")
|
||||
.header("Upgrade", "websocket")
|
||||
.header("Sec-WebSocket-Version", "13")
|
||||
.header("Sec-WebSocket-Key", generate_key());
|
||||
|
||||
if let Some(subprotocols) = subprotocols {
|
||||
builder = builder.header("Sec-WebSocket-Protocol", subprotocols.join(","));
|
||||
}
|
||||
|
||||
builder.uri(uri).body(()).unwrap()
|
||||
}
|
||||
|
||||
fn server_thread(port: u16, server_subprotocols: Option<Vec<String>>) {
|
||||
spawn(move || {
|
||||
let server = TcpListener::bind(("127.0.0.1", port))
|
||||
.expect("Can't listen, is this port already in use?");
|
||||
let client_handler = server.incoming().next().unwrap();
|
||||
|
||||
let callback = |_request: &Request, mut response: Response| {
|
||||
if let Some(subprotocols) = server_subprotocols {
|
||||
let headers = response.headers_mut();
|
||||
headers.append("Sec-WebSocket-Protocol", subprotocols.join(",").parse().unwrap());
|
||||
}
|
||||
Ok(response)
|
||||
};
|
||||
|
||||
let _client_handler = accept_hdr(client_handler.unwrap(), callback).unwrap();
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_server_send_no_subprotocol() {
|
||||
server_thread(3012, None);
|
||||
|
||||
let err =
|
||||
connect(create_http_request("ws://127.0.0.1:3012", Some(vec!["my-sub-protocol".into()])))
|
||||
.unwrap_err();
|
||||
|
||||
assert!(matches!(
|
||||
err,
|
||||
Error::Protocol(ProtocolError::SecWebSocketSubProtocolError(
|
||||
SubProtocolError::NoSubProtocol
|
||||
))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_server_sent_subprotocol_none_requested() {
|
||||
server_thread(3013, Some(vec!["my-sub-protocol".to_string()]));
|
||||
|
||||
let err = connect(create_http_request("ws://127.0.0.1:3013", None)).unwrap_err();
|
||||
|
||||
assert!(matches!(
|
||||
err,
|
||||
Error::Protocol(ProtocolError::SecWebSocketSubProtocolError(
|
||||
SubProtocolError::ServerSentSubProtocolNoneRequested
|
||||
))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_subprotocol() {
|
||||
server_thread(3014, Some(vec!["invalid-sub-protocol".to_string()]));
|
||||
|
||||
let err = connect(create_http_request(
|
||||
"ws://127.0.0.1:3014",
|
||||
Some(vec!["my-sub-protocol".to_string()]),
|
||||
))
|
||||
.unwrap_err();
|
||||
|
||||
assert!(matches!(
|
||||
err,
|
||||
Error::Protocol(ProtocolError::SecWebSocketSubProtocolError(
|
||||
SubProtocolError::InvalidSubProtocol
|
||||
))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_multiple_subprotocols() {
|
||||
server_thread(3015, Some(vec!["my-sub-protocol".to_string()]));
|
||||
|
||||
let (_, response) = connect(create_http_request(
|
||||
"ws://127.0.0.1:3015",
|
||||
Some(vec![
|
||||
"my-sub-protocol".to_string(),
|
||||
"my-sub-protocol-1".to_string(),
|
||||
"my-sub-protocol-2".to_string(),
|
||||
]),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
response.headers().get("Sec-WebSocket-Protocol").unwrap(),
|
||||
"my-sub-protocol".parse::<http::HeaderValue>().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_single_subprotocol() {
|
||||
server_thread(3016, Some(vec!["my-sub-protocol".to_string()]));
|
||||
|
||||
let (_, response) = connect(create_http_request(
|
||||
"ws://127.0.0.1:3016",
|
||||
Some(vec!["my-sub-protocol".to_string()]),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
response.headers().get("Sec-WebSocket-Protocol").unwrap(),
|
||||
"my-sub-protocol".parse::<http::HeaderValue>().unwrap()
|
||||
);
|
||||
}
|
Loading…
Reference in New Issue