client: overhaul of the request generation
This commit is contained in:
parent
1b999136ef
commit
d661f57224
|
@ -12,7 +12,7 @@ use log::*;
|
|||
use url::Url;
|
||||
|
||||
use crate::{
|
||||
handshake::client::{Request, Response},
|
||||
handshake::client::{generate_key, Request, Response},
|
||||
protocol::WebSocketConfig,
|
||||
stream::MaybeTlsStream,
|
||||
};
|
||||
|
@ -178,7 +178,11 @@ where
|
|||
/// Trait for converting various types into HTTP requests used for a client connection.
|
||||
///
|
||||
/// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and
|
||||
/// `http::Request<()>`.
|
||||
/// `http::Request<()>`. Note that the implementation for `http::Request<()>` is trivial and will
|
||||
/// simply take your request and pass it as is further without altering any headers or URLs, so
|
||||
/// be aware of this. If you just want to connect to the endpoint with a certain URL, better pass
|
||||
/// a regular string containing the URL in which case `tungstenite-rs` will take care for generating
|
||||
/// the proper `http::Request<()>` for you.
|
||||
pub trait IntoClientRequest {
|
||||
/// Convert into a `Request` that can be used for a client connection.
|
||||
fn into_client_request(self) -> Result<Request>;
|
||||
|
@ -210,7 +214,26 @@ impl<'a> IntoClientRequest for &'a Uri {
|
|||
|
||||
impl IntoClientRequest for Uri {
|
||||
fn into_client_request(self) -> Result<Request> {
|
||||
Ok(Request::get(self).body(())?)
|
||||
let authority = self.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
|
||||
let host = authority
|
||||
.find('@')
|
||||
.map(|idx| authority.split_at(idx + 1).1)
|
||||
.unwrap_or_else(|| authority);
|
||||
|
||||
if host.is_empty() {
|
||||
return Err(Error::Url(UrlError::EmptyHostName));
|
||||
}
|
||||
|
||||
let req = Request::builder()
|
||||
.method("GET")
|
||||
.header("Host", host)
|
||||
.header("Connection", "Upgrade")
|
||||
.header("Upgrade", "websocket")
|
||||
.header("Sec-WebSocket-Version", "13")
|
||||
.header("Sec-WebSocket-Key", generate_key())
|
||||
.uri(self)
|
||||
.body(())?;
|
||||
Ok(req)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -167,9 +167,8 @@ pub enum ProtocolError {
|
|||
/// Custom responses must be unsuccessful.
|
||||
#[error("Custom response must not be successful")]
|
||||
CustomResponseSuccessful,
|
||||
/// Invalid header is passed. This header is formed by the library automatically
|
||||
/// and must not be overwritten by the user.
|
||||
#[error("Not allowed to pass overwrite the standard header {0}")]
|
||||
/// Invalid header is passed. Or the header is missing in the request. Or not present at all. Check the request that you pass.
|
||||
#[error("Missing, duplicated or incorrect header {0}")]
|
||||
InvalidHeader(HeaderName),
|
||||
/// No more data while still performing handshake.
|
||||
#[error("Handshake not finished")]
|
||||
|
|
|
@ -5,7 +5,9 @@ use std::{
|
|||
marker::PhantomData,
|
||||
};
|
||||
|
||||
use http::{header, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode};
|
||||
use http::{
|
||||
header::HeaderName, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
|
||||
};
|
||||
use httparse::Status;
|
||||
use log::*;
|
||||
|
||||
|
@ -52,12 +54,11 @@ impl<S: Read + Write> ClientHandshake<S> {
|
|||
// Check the URI scheme: only ws or wss are supported
|
||||
let _ = crate::client::uri_mode(request.uri())?;
|
||||
|
||||
let key = generate_key();
|
||||
// Convert and verify the `http::Request` and turn it into the request as per RFC.
|
||||
// Also extract the key from it (it must be present in a correct request).
|
||||
let (request, key) = generate_request(request)?;
|
||||
|
||||
let machine = {
|
||||
let req = generate_request(request, &key)?;
|
||||
HandshakeMachine::start_write(stream, req)
|
||||
};
|
||||
let machine = HandshakeMachine::start_write(stream, request);
|
||||
|
||||
let client = {
|
||||
let accept_key = derive_accept_key(key.as_ref());
|
||||
|
@ -92,56 +93,73 @@ impl<S: Read + Write> HandshakeRole for ClientHandshake<S> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Generate client request.
|
||||
fn generate_request(request: Request, key: &str) -> Result<Vec<u8>> {
|
||||
/// Verifies and generates a client WebSocket request from the original request and extracts a WebSocket key from it.
|
||||
fn generate_request(mut request: Request) -> Result<(Vec<u8>, String)> {
|
||||
let mut req = Vec::new();
|
||||
let uri = request.uri();
|
||||
|
||||
let authority = uri.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
|
||||
let host = if let Some(idx) = authority.find('@') {
|
||||
// handle possible name:password@
|
||||
authority.split_at(idx + 1).1
|
||||
} else {
|
||||
authority
|
||||
};
|
||||
if authority.is_empty() {
|
||||
return Err(Error::Url(UrlError::EmptyHostName));
|
||||
}
|
||||
|
||||
write!(
|
||||
req,
|
||||
"\
|
||||
GET {path} {version:?}\r\n\
|
||||
Host: {host}\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Sec-WebSocket-Version: 13\r\n\
|
||||
Sec-WebSocket-Key: {key}\r\n",
|
||||
version = request.version(),
|
||||
host = host,
|
||||
path = uri.path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(),
|
||||
key = key
|
||||
"GET {path} {version:?}\r\n",
|
||||
path = request.uri().path_and_query().ok_or(Error::Url(UrlError::NoPathOrQuery))?.as_str(),
|
||||
version = request.version()
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
for (k, v) in request.headers() {
|
||||
if k == header::CONNECTION
|
||||
|| k == header::UPGRADE
|
||||
|| k == header::SEC_WEBSOCKET_VERSION
|
||||
|| k == header::SEC_WEBSOCKET_KEY
|
||||
|| k == header::HOST
|
||||
{
|
||||
// Headers that must be present in a correct request.
|
||||
const KEY_HEADERNAME: &str = "Sec-WebSocket-Key";
|
||||
const WEBSOCKET_HEADERS: [&str; 5] =
|
||||
["Host", "Connection", "Upgrade", "Sec-WebSocket-Version", KEY_HEADERNAME];
|
||||
|
||||
// We must extract a WebSocket key from a properly formed request or fail if it's not present.
|
||||
let key = request
|
||||
.headers()
|
||||
.get(KEY_HEADERNAME)
|
||||
.ok_or_else(|| {
|
||||
Error::Protocol(ProtocolError::InvalidHeader(HeaderName::from_static(KEY_HEADERNAME)))
|
||||
})?
|
||||
.to_str()?
|
||||
.to_owned();
|
||||
|
||||
// We must check that all necessary headers for a valid request are present. Note that we have to
|
||||
// deal with the fact that some apps seem to have a case-sensitive check for headers which is not
|
||||
// correct and should not considered the correct behavior, but it seems like some apps ignore it.
|
||||
// `http` by default writes all headers in lower-case which is fine (and does not violate the RFC)
|
||||
// but some servers seem to be poorely written and ignore RFC.
|
||||
//
|
||||
// See similar problem in `hyper`: https://github.com/hyperium/hyper/issues/1492
|
||||
let headers = request.headers_mut();
|
||||
for header in WEBSOCKET_HEADERS {
|
||||
let value = headers.remove(header).ok_or_else(|| {
|
||||
Error::Protocol(ProtocolError::InvalidHeader(HeaderName::from_static(header)))
|
||||
})?;
|
||||
write!(req, "{header}: {value}\r\n", value = value.to_str()?).unwrap();
|
||||
}
|
||||
|
||||
// Now we must ensure that the headers that we've written once are not anymore present in the map.
|
||||
// If they do, then the request is invalid (some headers are duplicated there for some reason).
|
||||
let insensitive: Vec<String> =
|
||||
WEBSOCKET_HEADERS.iter().map(|h| h.to_ascii_lowercase()).collect();
|
||||
for (k, v) in headers {
|
||||
let mut name = k.as_str();
|
||||
|
||||
// We have already written the necessary headers once (above) and removed them from the map.
|
||||
// If we encounter them again, then the request is considered invalid and error is returned.
|
||||
// Note that we can't use `.contains()`, since `&str` does not coerce to `&String` in Rust.
|
||||
if insensitive.iter().any(|x| x == name) {
|
||||
return Err(Error::Protocol(ProtocolError::InvalidHeader(k.clone())));
|
||||
}
|
||||
let mut k = k.as_str();
|
||||
if k == "sec-websocket-protocol" {
|
||||
k = "Sec-WebSocket-Protocol";
|
||||
|
||||
// Relates to the issue of some servers treating headers in a case-sensitive way, please see:
|
||||
// https://github.com/snapview/tungstenite-rs/pull/119 (original fix of the problem)
|
||||
if name == "sec-websocket-protocol" {
|
||||
name = "Sec-WebSocket-Protocol";
|
||||
}
|
||||
writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap();
|
||||
|
||||
writeln!(req, "{}: {}\r", name, v.to_str()?).unwrap();
|
||||
}
|
||||
|
||||
writeln!(req, "\r").unwrap();
|
||||
trace!("Request: {:?}", String::from_utf8_lossy(&req));
|
||||
Ok(req)
|
||||
Ok((req, key))
|
||||
}
|
||||
|
||||
/// Information for handshake verification.
|
||||
|
@ -241,7 +259,7 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
|
|||
}
|
||||
|
||||
/// Generate a random key for the `Sec-WebSocket-Key` header.
|
||||
fn generate_key() -> String {
|
||||
pub fn generate_key() -> String {
|
||||
// a base64-encoded (see Section 4 of [RFC4648]) value that,
|
||||
// when decoded, is 16 bytes in length (RFC 6455)
|
||||
let r: [u8; 16] = rand::random();
|
||||
|
@ -269,54 +287,41 @@ mod tests {
|
|||
assert!(k2[..22].find('=').is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_formatting() {
|
||||
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
|
||||
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
|
||||
let correct = b"\
|
||||
fn construct_expected(host: &str, key: &str) -> Vec<u8> {
|
||||
format!(
|
||||
"\
|
||||
GET /getCaseCount HTTP/1.1\r\n\
|
||||
Host: localhost\r\n\
|
||||
Host: {host}\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Sec-WebSocket-Version: 13\r\n\
|
||||
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
|
||||
\r\n";
|
||||
let request = generate_request(request, key).unwrap();
|
||||
println!("Request: {}", String::from_utf8_lossy(&request));
|
||||
Sec-WebSocket-Key: {key}\r\n\
|
||||
\r\n"
|
||||
)
|
||||
.into_bytes()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_formatting() {
|
||||
let request = "ws://localhost/getCaseCount".into_client_request().unwrap();
|
||||
let (request, key) = generate_request(request).unwrap();
|
||||
let correct = construct_expected("localhost", &key);
|
||||
assert_eq!(&request[..], &correct[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_formatting_with_host() {
|
||||
let request = "wss://localhost:9001/getCaseCount".into_client_request().unwrap();
|
||||
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
|
||||
let correct = b"\
|
||||
GET /getCaseCount HTTP/1.1\r\n\
|
||||
Host: localhost:9001\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Sec-WebSocket-Version: 13\r\n\
|
||||
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
|
||||
\r\n";
|
||||
let request = generate_request(request, key).unwrap();
|
||||
println!("Request: {}", String::from_utf8_lossy(&request));
|
||||
let (request, key) = generate_request(request).unwrap();
|
||||
let correct = construct_expected("localhost:9001", &key);
|
||||
assert_eq!(&request[..], &correct[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_formatting_with_at() {
|
||||
let request = "wss://user:pass@localhost:9001/getCaseCount".into_client_request().unwrap();
|
||||
let key = "A70tsIbeMZUbJHh5BWFw6Q==";
|
||||
let correct = b"\
|
||||
GET /getCaseCount HTTP/1.1\r\n\
|
||||
Host: localhost:9001\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Sec-WebSocket-Version: 13\r\n\
|
||||
Sec-WebSocket-Key: A70tsIbeMZUbJHh5BWFw6Q==\r\n\
|
||||
\r\n";
|
||||
let request = generate_request(request, key).unwrap();
|
||||
println!("Request: {}", String::from_utf8_lossy(&request));
|
||||
let (request, key) = generate_request(request).unwrap();
|
||||
let correct = construct_expected("localhost:9001", &key);
|
||||
assert_eq!(&request[..], &correct[..]);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue