Add builder for additional header values (#400)
* ADd builder for additional header values * Update client.rs * fix: docs * feat: add test * fix: typo * add --------- Co-authored-by: n4n5 <56606507+Its-Just-Nans@users.noreply.github.com> Co-authored-by: n4n5 <its.just.n4n5@gmail.com>
This commit is contained in:
parent
2ee05d1080
commit
0fa41973b4
|
@ -1,12 +1,13 @@
|
|||
//! Methods to connect to a WebSocket as a client.
|
||||
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
io::{Read, Write},
|
||||
net::{SocketAddr, TcpStream, ToSocketAddrs},
|
||||
result::Result as StdResult,
|
||||
};
|
||||
|
||||
use http::{request::Parts, Uri};
|
||||
use http::{request::Parts, HeaderName, Uri};
|
||||
use log::*;
|
||||
|
||||
use url::Url;
|
||||
|
@ -265,3 +266,73 @@ impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
|
|||
Request::from_httparse(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for a custom [`IntoClientRequest`] with options to add
|
||||
/// custom additional headers and sub protocols.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust no_run
|
||||
/// # use crate::*;
|
||||
/// use http::Uri;
|
||||
/// use tungstenite::{connect, ClientRequestBuilder};
|
||||
///
|
||||
/// let uri: Uri = "ws://localhost:3012/socket".parse().unwrap();
|
||||
/// let token = "my_jwt_token";
|
||||
/// let builder = ClientRequestBuilder::new(uri)
|
||||
/// .with_header("Authorization", format!("Bearer {token}"))
|
||||
/// .with_sub_protocol("my_sub_protocol");
|
||||
/// let socket = connect(builder).unwrap();
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClientRequestBuilder {
|
||||
uri: Uri,
|
||||
/// Additional [`Request`] handshake headers
|
||||
additional_headers: Vec<(String, String)>,
|
||||
/// Handsake subprotocols
|
||||
subprotocols: Vec<String>,
|
||||
}
|
||||
|
||||
impl ClientRequestBuilder {
|
||||
/// Initializes an empty request builder
|
||||
#[must_use]
|
||||
pub const fn new(uri: Uri) -> Self {
|
||||
Self { uri, additional_headers: Vec::new(), subprotocols: Vec::new() }
|
||||
}
|
||||
|
||||
/// Adds (`key`, `value`) as an additional header to the handshake request
|
||||
pub fn with_header<K, V>(mut self, key: K, value: V) -> Self
|
||||
where
|
||||
K: Into<String>,
|
||||
V: Into<String>,
|
||||
{
|
||||
self.additional_headers.push((key.into(), value.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds `protocol` to the handshake request subprotocols (`Sec-WebSocket-Protocol`)
|
||||
pub fn with_sub_protocol<P>(mut self, protocol: P) -> Self
|
||||
where
|
||||
P: Into<String>,
|
||||
{
|
||||
self.subprotocols.push(protocol.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoClientRequest for ClientRequestBuilder {
|
||||
fn into_client_request(self) -> Result<Request> {
|
||||
let mut request = self.uri.into_client_request()?;
|
||||
let headers = request.headers_mut();
|
||||
for (k, v) in self.additional_headers {
|
||||
let key = HeaderName::try_from(k)?;
|
||||
let value = v.parse()?;
|
||||
headers.append(key, value);
|
||||
}
|
||||
if !self.subprotocols.is_empty() {
|
||||
let protocols = self.subprotocols.join(", ").parse()?;
|
||||
headers.append("Sec-WebSocket-Protocol", protocols);
|
||||
}
|
||||
Ok(request)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,7 +39,7 @@ pub use crate::{
|
|||
|
||||
#[cfg(feature = "handshake")]
|
||||
pub use crate::{
|
||||
client::{client, connect},
|
||||
client::{client, connect, ClientRequestBuilder},
|
||||
handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError},
|
||||
server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config},
|
||||
};
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
#![cfg(feature = "handshake")]
|
||||
|
||||
use http::Uri;
|
||||
use std::{
|
||||
net::TcpListener,
|
||||
process::exit,
|
||||
thread::{sleep, spawn},
|
||||
time::Duration,
|
||||
};
|
||||
use tungstenite::{
|
||||
accept_hdr, connect,
|
||||
handshake::server::{Request, Response},
|
||||
ClientRequestBuilder, Error, Message,
|
||||
};
|
||||
|
||||
/// Test for write buffering and flushing behaviour.
|
||||
#[test]
|
||||
fn test_headers() {
|
||||
env_logger::init();
|
||||
let uri: Uri = "ws://127.0.0.1:3013/socket".parse().unwrap();
|
||||
let token = "my_jwt_token";
|
||||
let full_token = format!("Bearer {token}");
|
||||
let sub_protocol = "my_sub_protocol";
|
||||
let builder = ClientRequestBuilder::new(uri)
|
||||
.with_header("Authorization", full_token.to_owned())
|
||||
.with_sub_protocol(sub_protocol.to_owned());
|
||||
|
||||
spawn(|| {
|
||||
sleep(Duration::from_secs(5));
|
||||
println!("Unit test executed too long, perhaps stuck on WOULDBLOCK...");
|
||||
exit(1);
|
||||
});
|
||||
|
||||
let server = TcpListener::bind("127.0.0.1:3013").unwrap();
|
||||
|
||||
let client_thread = spawn(move || {
|
||||
let (mut client, _) = connect(builder).unwrap();
|
||||
client.send(Message::Text("Hello WebSocket".into())).unwrap();
|
||||
|
||||
let message = client.read().unwrap(); // receive close from server
|
||||
assert!(message.is_close());
|
||||
|
||||
let err = client.read().unwrap_err(); // now we should get ConnectionClosed
|
||||
match err {
|
||||
Error::ConnectionClosed => {}
|
||||
_ => panic!("unexpected error: {:?}", err),
|
||||
}
|
||||
});
|
||||
|
||||
let callback = |req: &Request, response: Response| {
|
||||
println!("Received a new ws handshake");
|
||||
println!("The request's path is: {}", req.uri().path());
|
||||
println!("The request's headers are:");
|
||||
let authorization_header: String = "authorization".to_ascii_lowercase();
|
||||
let web_socket_proto: String = "sec-websocket-protocol".to_ascii_lowercase();
|
||||
|
||||
for (ref header, value) in req.headers() {
|
||||
println!("* {}: {}", header, value.to_str().unwrap());
|
||||
if header.to_string() == authorization_header {
|
||||
println!("Matching authorization header");
|
||||
assert_eq!(header.to_string(), authorization_header);
|
||||
assert_eq!(value.to_str().unwrap(), full_token);
|
||||
} else if header.to_string() == web_socket_proto {
|
||||
println!("Matching sec-websocket-protocol header");
|
||||
assert_eq!(header.to_string(), web_socket_proto);
|
||||
assert_eq!(value.to_str().unwrap(), sub_protocol);
|
||||
}
|
||||
}
|
||||
Ok(response)
|
||||
};
|
||||
|
||||
let client_handler = server.incoming().next().unwrap();
|
||||
let mut client_handler = accept_hdr(client_handler.unwrap(), callback).unwrap();
|
||||
|
||||
client_handler.close(None).unwrap(); // send close to client
|
||||
|
||||
// This read should succeed even though we already initiated a close
|
||||
let message = client_handler.read().unwrap();
|
||||
assert_eq!(message.into_data(), b"Hello WebSocket");
|
||||
|
||||
assert!(client_handler.read().unwrap().is_close()); // receive acknowledgement
|
||||
|
||||
let err = client_handler.read().unwrap_err(); // now we should get ConnectionClosed
|
||||
match err {
|
||||
Error::ConnectionClosed => {}
|
||||
_ => panic!("unexpected error: {:?}", err),
|
||||
}
|
||||
|
||||
drop(client_handler);
|
||||
|
||||
client_thread.join().unwrap();
|
||||
}
|
Loading…
Reference in New Issue