Add builder for additional header values (#400)

* ADd builder for additional header values

* Update client.rs

* fix: docs

* feat: add test

* fix: typo

* add

---------

Co-authored-by: n4n5 <56606507+Its-Just-Nans@users.noreply.github.com>
Co-authored-by: n4n5 <its.just.n4n5@gmail.com>
This commit is contained in:
Félix Lescaudey de Maneville 2024-02-12 20:56:15 +01:00 committed by GitHub
parent 2ee05d1080
commit 0fa41973b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 165 additions and 2 deletions

View File

@ -1,12 +1,13 @@
//! Methods to connect to a WebSocket as a client.
use std::{
convert::TryFrom,
io::{Read, Write},
net::{SocketAddr, TcpStream, ToSocketAddrs},
result::Result as StdResult,
};
use http::{request::Parts, Uri};
use http::{request::Parts, HeaderName, Uri};
use log::*;
use url::Url;
@ -265,3 +266,73 @@ impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
Request::from_httparse(self)
}
}
/// Builder for a custom [`IntoClientRequest`] with options to add
/// custom additional headers and sub protocols.
///
/// # Example
///
/// ```rust no_run
/// # use crate::*;
/// use http::Uri;
/// use tungstenite::{connect, ClientRequestBuilder};
///
/// let uri: Uri = "ws://localhost:3012/socket".parse().unwrap();
/// let token = "my_jwt_token";
/// let builder = ClientRequestBuilder::new(uri)
/// .with_header("Authorization", format!("Bearer {token}"))
/// .with_sub_protocol("my_sub_protocol");
/// let socket = connect(builder).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct ClientRequestBuilder {
uri: Uri,
/// Additional [`Request`] handshake headers
additional_headers: Vec<(String, String)>,
/// Handsake subprotocols
subprotocols: Vec<String>,
}
impl ClientRequestBuilder {
/// Initializes an empty request builder
#[must_use]
pub const fn new(uri: Uri) -> Self {
Self { uri, additional_headers: Vec::new(), subprotocols: Vec::new() }
}
/// Adds (`key`, `value`) as an additional header to the handshake request
pub fn with_header<K, V>(mut self, key: K, value: V) -> Self
where
K: Into<String>,
V: Into<String>,
{
self.additional_headers.push((key.into(), value.into()));
self
}
/// Adds `protocol` to the handshake request subprotocols (`Sec-WebSocket-Protocol`)
pub fn with_sub_protocol<P>(mut self, protocol: P) -> Self
where
P: Into<String>,
{
self.subprotocols.push(protocol.into());
self
}
}
impl IntoClientRequest for ClientRequestBuilder {
fn into_client_request(self) -> Result<Request> {
let mut request = self.uri.into_client_request()?;
let headers = request.headers_mut();
for (k, v) in self.additional_headers {
let key = HeaderName::try_from(k)?;
let value = v.parse()?;
headers.append(key, value);
}
if !self.subprotocols.is_empty() {
let protocols = self.subprotocols.join(", ").parse()?;
headers.append("Sec-WebSocket-Protocol", protocols);
}
Ok(request)
}
}

View File

@ -39,7 +39,7 @@ pub use crate::{
#[cfg(feature = "handshake")]
pub use crate::{
client::{client, connect},
client::{client, connect, ClientRequestBuilder},
handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError},
server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config},
};

92
tests/client_headers.rs Normal file
View File

@ -0,0 +1,92 @@
#![cfg(feature = "handshake")]
use http::Uri;
use std::{
net::TcpListener,
process::exit,
thread::{sleep, spawn},
time::Duration,
};
use tungstenite::{
accept_hdr, connect,
handshake::server::{Request, Response},
ClientRequestBuilder, Error, Message,
};
/// Test for write buffering and flushing behaviour.
#[test]
fn test_headers() {
env_logger::init();
let uri: Uri = "ws://127.0.0.1:3013/socket".parse().unwrap();
let token = "my_jwt_token";
let full_token = format!("Bearer {token}");
let sub_protocol = "my_sub_protocol";
let builder = ClientRequestBuilder::new(uri)
.with_header("Authorization", full_token.to_owned())
.with_sub_protocol(sub_protocol.to_owned());
spawn(|| {
sleep(Duration::from_secs(5));
println!("Unit test executed too long, perhaps stuck on WOULDBLOCK...");
exit(1);
});
let server = TcpListener::bind("127.0.0.1:3013").unwrap();
let client_thread = spawn(move || {
let (mut client, _) = connect(builder).unwrap();
client.send(Message::Text("Hello WebSocket".into())).unwrap();
let message = client.read().unwrap(); // receive close from server
assert!(message.is_close());
let err = client.read().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
});
let callback = |req: &Request, response: Response| {
println!("Received a new ws handshake");
println!("The request's path is: {}", req.uri().path());
println!("The request's headers are:");
let authorization_header: String = "authorization".to_ascii_lowercase();
let web_socket_proto: String = "sec-websocket-protocol".to_ascii_lowercase();
for (ref header, value) in req.headers() {
println!("* {}: {}", header, value.to_str().unwrap());
if header.to_string() == authorization_header {
println!("Matching authorization header");
assert_eq!(header.to_string(), authorization_header);
assert_eq!(value.to_str().unwrap(), full_token);
} else if header.to_string() == web_socket_proto {
println!("Matching sec-websocket-protocol header");
assert_eq!(header.to_string(), web_socket_proto);
assert_eq!(value.to_str().unwrap(), sub_protocol);
}
}
Ok(response)
};
let client_handler = server.incoming().next().unwrap();
let mut client_handler = accept_hdr(client_handler.unwrap(), callback).unwrap();
client_handler.close(None).unwrap(); // send close to client
// This read should succeed even though we already initiated a close
let message = client_handler.read().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket");
assert!(client_handler.read().unwrap().is_close()); // receive acknowledgement
let err = client_handler.read().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
drop(client_handler);
client_thread.join().unwrap();
}