diff --git a/src/client.rs b/src/client.rs index 9b30037..a1fd58e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,12 +1,13 @@ //! Methods to connect to a WebSocket as a client. use std::{ + convert::TryFrom, io::{Read, Write}, net::{SocketAddr, TcpStream, ToSocketAddrs}, result::Result as StdResult, }; -use http::{request::Parts, Uri}; +use http::{request::Parts, HeaderName, Uri}; use log::*; use url::Url; @@ -265,3 +266,73 @@ impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> { Request::from_httparse(self) } } + +/// Builder for a custom [`IntoClientRequest`] with options to add +/// custom additional headers and sub protocols. +/// +/// # Example +/// +/// ```rust no_run +/// # use crate::*; +/// use http::Uri; +/// use tungstenite::{connect, ClientRequestBuilder}; +/// +/// let uri: Uri = "ws://localhost:3012/socket".parse().unwrap(); +/// let token = "my_jwt_token"; +/// let builder = ClientRequestBuilder::new(uri) +/// .with_header("Authorization", format!("Bearer {token}")) +/// .with_sub_protocol("my_sub_protocol"); +/// let socket = connect(builder).unwrap(); +/// ``` +#[derive(Debug, Clone)] +pub struct ClientRequestBuilder { + uri: Uri, + /// Additional [`Request`] handshake headers + additional_headers: Vec<(String, String)>, + /// Handsake subprotocols + subprotocols: Vec, +} + +impl ClientRequestBuilder { + /// Initializes an empty request builder + #[must_use] + pub const fn new(uri: Uri) -> Self { + Self { uri, additional_headers: Vec::new(), subprotocols: Vec::new() } + } + + /// Adds (`key`, `value`) as an additional header to the handshake request + pub fn with_header(mut self, key: K, value: V) -> Self + where + K: Into, + V: Into, + { + self.additional_headers.push((key.into(), value.into())); + self + } + + /// Adds `protocol` to the handshake request subprotocols (`Sec-WebSocket-Protocol`) + pub fn with_sub_protocol

(mut self, protocol: P) -> Self + where + P: Into, + { + self.subprotocols.push(protocol.into()); + self + } +} + +impl IntoClientRequest for ClientRequestBuilder { + fn into_client_request(self) -> Result { + let mut request = self.uri.into_client_request()?; + let headers = request.headers_mut(); + for (k, v) in self.additional_headers { + let key = HeaderName::try_from(k)?; + let value = v.parse()?; + headers.append(key, value); + } + if !self.subprotocols.is_empty() { + let protocols = self.subprotocols.join(", ").parse()?; + headers.append("Sec-WebSocket-Protocol", protocols); + } + Ok(request) + } +} diff --git a/src/lib.rs b/src/lib.rs index 4fdf0a6..8c593cd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,7 +39,7 @@ pub use crate::{ #[cfg(feature = "handshake")] pub use crate::{ - client::{client, connect}, + client::{client, connect, ClientRequestBuilder}, handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError}, server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config}, }; diff --git a/tests/client_headers.rs b/tests/client_headers.rs new file mode 100644 index 0000000..1a037af --- /dev/null +++ b/tests/client_headers.rs @@ -0,0 +1,92 @@ +#![cfg(feature = "handshake")] + +use http::Uri; +use std::{ + net::TcpListener, + process::exit, + thread::{sleep, spawn}, + time::Duration, +}; +use tungstenite::{ + accept_hdr, connect, + handshake::server::{Request, Response}, + ClientRequestBuilder, Error, Message, +}; + +/// Test for write buffering and flushing behaviour. +#[test] +fn test_headers() { + env_logger::init(); + let uri: Uri = "ws://127.0.0.1:3013/socket".parse().unwrap(); + let token = "my_jwt_token"; + let full_token = format!("Bearer {token}"); + let sub_protocol = "my_sub_protocol"; + let builder = ClientRequestBuilder::new(uri) + .with_header("Authorization", full_token.to_owned()) + .with_sub_protocol(sub_protocol.to_owned()); + + spawn(|| { + sleep(Duration::from_secs(5)); + println!("Unit test executed too long, perhaps stuck on WOULDBLOCK..."); + exit(1); + }); + + let server = TcpListener::bind("127.0.0.1:3013").unwrap(); + + let client_thread = spawn(move || { + let (mut client, _) = connect(builder).unwrap(); + client.send(Message::Text("Hello WebSocket".into())).unwrap(); + + let message = client.read().unwrap(); // receive close from server + assert!(message.is_close()); + + let err = client.read().unwrap_err(); // now we should get ConnectionClosed + match err { + Error::ConnectionClosed => {} + _ => panic!("unexpected error: {:?}", err), + } + }); + + let callback = |req: &Request, response: Response| { + println!("Received a new ws handshake"); + println!("The request's path is: {}", req.uri().path()); + println!("The request's headers are:"); + let authorization_header: String = "authorization".to_ascii_lowercase(); + let web_socket_proto: String = "sec-websocket-protocol".to_ascii_lowercase(); + + for (ref header, value) in req.headers() { + println!("* {}: {}", header, value.to_str().unwrap()); + if header.to_string() == authorization_header { + println!("Matching authorization header"); + assert_eq!(header.to_string(), authorization_header); + assert_eq!(value.to_str().unwrap(), full_token); + } else if header.to_string() == web_socket_proto { + println!("Matching sec-websocket-protocol header"); + assert_eq!(header.to_string(), web_socket_proto); + assert_eq!(value.to_str().unwrap(), sub_protocol); + } + } + Ok(response) + }; + + let client_handler = server.incoming().next().unwrap(); + let mut client_handler = accept_hdr(client_handler.unwrap(), callback).unwrap(); + + client_handler.close(None).unwrap(); // send close to client + + // This read should succeed even though we already initiated a close + let message = client_handler.read().unwrap(); + assert_eq!(message.into_data(), b"Hello WebSocket"); + + assert!(client_handler.read().unwrap().is_close()); // receive acknowledgement + + let err = client_handler.read().unwrap_err(); // now we should get ConnectionClosed + match err { + Error::ConnectionClosed => {} + _ => panic!("unexpected error: {:?}", err), + } + + drop(client_handler); + + client_thread.join().unwrap(); +}