From 9ff94f8bb3b425d9b173d74598ced2a708548477 Mon Sep 17 00:00:00 2001 From: "R. Tyler Croy" Date: Thu, 17 Sep 2020 15:34:22 -0700 Subject: [PATCH] Broadcast websocket messages around (lol) --- src/main.rs | 74 +++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 60 insertions(+), 14 deletions(-) diff --git a/src/main.rs b/src/main.rs index 751f0d8..45e2a46 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,8 +2,11 @@ extern crate serde_json; use async_std::task; -use async_std::net::{TcpListener, TcpStream}; +use async_std::net::{SocketAddr, TcpListener, TcpStream}; +use async_tungstenite::tungstenite::protocol::Message; +use futures::pin_mut; use futures::prelude::*; +use futures::channel::mpsc::{unbounded, UnboundedSender}; use glob::glob; use handlebars::Handlebars; use log::*; @@ -11,9 +14,19 @@ use rodio::Source; use serde::Deserialize; use tide::{Body, Request, StatusCode}; +use std::collections::HashMap; use std::fs::File; use std::io::BufReader; +use std::sync::{Arc, Mutex}; + +type Tx = UnboundedSender; +type PeerMap = Arc>>; + + +/** + * Simple struct to deserialize some query parameters + */ #[derive(Debug, Deserialize)] struct Query { admin: String, @@ -68,22 +81,54 @@ async fn play(req: Request<()>) -> tide::Result { } } -async fn handle_websocket(stream: TcpStream) { - let addr = stream - .peer_addr() - .expect("connected streams should have a peer address"); - info!("Peer address: {}", addr); +async fn handle_websocket(peer_map: PeerMap, raw_stream: TcpStream, addr: SocketAddr) { + println!("Incoming TCP connection from: {}", addr); - let ws_stream = async_tungstenite::accept_async(stream) + let ws_stream = async_tungstenite::accept_async(raw_stream) .await .expect("Error during the websocket handshake occurred"); + println!("WebSocket connection established: {}", addr); - info!("New WebSocket connection: {}", addr); + // Insert the write part of this peer to the peer map. + let (tx, rx) = unbounded(); + peer_map.lock().unwrap().insert(addr, tx); - let (write, read) = ws_stream.split(); - read.forward(write) - .await - .expect("Failed to forward message") + let (outgoing, incoming) = ws_stream.split(); + + let broadcast_incoming = incoming + .try_filter(|msg| { + // Broadcasting a Close message from one client + // will close the other clients. + future::ready(!msg.is_close()) + }) + .try_for_each(|msg| { + println!( + "Received a message from {}: {}", + addr, + msg.to_text().unwrap() + ); + let peers = peer_map.lock().unwrap(); + + // We want to broadcast the message to everyone except ourselves. + let broadcast_recipients = peers + .iter() + .filter(|(peer_addr, _)| peer_addr != &&addr) + .map(|(_, ws_sink)| ws_sink); + + for recp in broadcast_recipients { + recp.unbounded_send(msg.clone()).unwrap(); + } + + future::ok(()) + }); + + let receive_from_others = rx.map(Ok).forward(outgoing); + + pin_mut!(broadcast_incoming, receive_from_others); + future::select(broadcast_incoming, receive_from_others).await; + + println!("{} disconnected", &addr); + peer_map.lock().unwrap().remove(&addr); } #[async_std::main] @@ -97,13 +142,14 @@ async fn main() -> Result<(), tide::Error> { task::spawn(async move { let addr = "0.0.0.0:9078"; + let state = PeerMap::new(Mutex::new(HashMap::new())); // Create the event loop and TCP listener we'll accept connections on. let try_socket = TcpListener::bind(&addr).await; let listener = try_socket.expect("Failed to bind"); info!("Listening on: {}", addr); - while let Ok((stream, _)) = listener.accept().await { - task::spawn(handle_websocket(stream)); + while let Ok((stream, addr)) = listener.accept().await { + task::spawn(handle_websocket(state.clone(), stream, addr)); } });