tungstenite-rs/src/handshake/mod.rs

136 lines
4.0 KiB
Rust

//! WebSocket handshake control.
pub mod client;
pub mod headers;
pub mod machine;
pub mod server;
use std::{
error::Error as ErrorTrait,
fmt,
io::{Read, Write},
};
use sha1::{Digest, Sha1};
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
use crate::error::Error;
/// A WebSocket handshake that has been started but not yet completed.
///
/// Pairs a role-specific processor with the state machine that performs the
/// actual reads and writes on the underlying stream.
#[derive(Debug)]
pub struct MidHandshake<Role>
where
    Role: HandshakeRole,
{
    /// Role-specific stage processing.
    role: Role,
    /// State machine driving I/O over the role's internal stream.
    machine: HandshakeMachine<Role::InternalStream>,
}
impl<Role> MidHandshake<Role>
where
    Role: HandshakeRole,
{
    /// Borrows the underlying handshake state machine.
    pub fn get_ref(&self) -> &HandshakeMachine<Role::InternalStream> {
        let machine = &self.machine;
        machine
    }

    /// Mutably borrows the underlying handshake state machine.
    pub fn get_mut(&mut self) -> &mut HandshakeMachine<Role::InternalStream> {
        let machine = &mut self.machine;
        machine
    }

    /// Resumes the handshake from the point where it last stopped.
    ///
    /// Drives the state machine one round at a time:
    /// - if a round would block, returns `Err(HandshakeError::Interrupted(..))` holding
    ///   the in-progress handshake so the caller can call `handshake` again later;
    /// - if a stage finishes, hands it to the role, which either continues with a new
    ///   machine or yields the final result as `Ok(..)`;
    /// - any error from a round or from the role is returned as
    ///   `Err(HandshakeError::Failure(..))` via the `From<Error>` conversion.
    pub fn handshake(self) -> Result<Role::FinalResult, HandshakeError<Role>> {
        let MidHandshake { mut role, mut machine } = self;
        loop {
            machine = match machine.single_round()? {
                RoundResult::Incomplete(next) => next,
                RoundResult::WouldBlock(next) => {
                    // Package the partially-completed state back up for a later retry.
                    let pending = MidHandshake { role, machine: next };
                    return Err(HandshakeError::Interrupted(pending));
                }
                RoundResult::StageFinished(stage) => match role.stage_finished(stage)? {
                    ProcessingResult::Continue(next) => next,
                    ProcessingResult::Done(outcome) => return Ok(outcome),
                },
            };
        }
    }
}
/// Error returned when a handshake does not complete in the current call.
pub enum HandshakeError<Role>
where
    Role: HandshakeRole,
{
    /// The handshake would block; the contained [`MidHandshake`] can be resumed
    /// by calling [`MidHandshake::handshake`] again.
    Interrupted(MidHandshake<Role>),
    /// The handshake failed with the contained error.
    Failure(Error),
}
/// Debug output that elides the in-progress handshake state of `Interrupted`.
impl<Role> fmt::Debug for HandshakeError<Role>
where
    Role: HandshakeRole,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Failure(inner) => write!(f, "HandshakeError::Failure({:?})", inner),
            Self::Interrupted(_) => f.write_str("HandshakeError::Interrupted(...)"),
        }
    }
}
/// Human-readable description; a failure is shown as its inner error's message.
impl<Role> fmt::Display for HandshakeError<Role>
where
    Role: HandshakeRole,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Failure(inner) => write!(f, "{}", inner),
            Self::Interrupted(_) => f.write_str("Interrupted handshake (WouldBlock)"),
        }
    }
}
/// Marks `HandshakeError` as a standard error; relies on the default trait methods.
impl<Role> ErrorTrait for HandshakeError<Role> where Role: HandshakeRole {}
impl<Role: HandshakeRole> From<Error> for HandshakeError<Role> {
fn from(err: Error) -> Self {
HandshakeError::Failure(err)
}
}
/// Handshake role: the per-side logic plugged into [`MidHandshake`].
pub trait HandshakeRole {
    /// Type parsed from the stream when a read stage finishes.
    #[doc(hidden)]
    type IncomingData: TryParse;
    /// Stream the handshake machine reads from and writes to.
    #[doc(hidden)]
    type InternalStream: Read + Write;
    /// Value produced once the handshake is done.
    #[doc(hidden)]
    type FinalResult;
    /// Processes a finished machine stage.
    ///
    /// Returns either a machine to keep driving (`ProcessingResult::Continue`) or the
    /// final outcome (`ProcessingResult::Done`); errors abort the handshake.
    #[doc(hidden)]
    fn stage_finished(
        &mut self,
        stage: StageResult<Self::IncomingData, Self::InternalStream>,
    ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
}
/// Outcome of a role processing a finished stage.
#[doc(hidden)]
#[derive(Debug)]
pub enum ProcessingResult<Stream, FinalResult> {
    /// Keep driving the handshake with this machine.
    Continue(HandshakeMachine<Stream>),
    /// The handshake is complete and produced this value.
    Done(FinalResult),
}
/// Computes the `Sec-WebSocket-Accept` response header value for a `Sec-WebSocket-Key`.
///
/// The value is the Base64 encoding of the SHA-1 digest of `request_key` followed by
/// the fixed GUID defined in RFC 6455. This function can be used to perform a
/// handshake manually before passing a raw TCP stream to
/// [`WebSocket::from_raw_socket`][crate::protocol::WebSocket::from_raw_socket].
pub fn derive_accept_key(request_key: &[u8]) -> String {
    // RFC 6455 §4.2.2: the key is concatenated with this GUID before hashing.
    const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut hasher = Sha1::default();
    hasher.update(request_key);
    hasher.update(WS_GUID);

    let digest = hasher.finalize();
    data_encoding::BASE64.encode(&digest)
}
#[cfg(test)]
mod tests {
    use super::derive_accept_key;

    #[test]
    fn key_conversion() {
        // Sample key/accept pair from RFC 6455, section 1.3.
        let request_key = b"dGhlIHNhbXBsZSBub25jZQ==";
        let expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";
        assert_eq!(derive_accept_key(request_key), expected_accept);
    }
}