diff --git a/examples/hyper-client.rs b/examples/hyper-client.rs index bd00997..bb09f2a 100644 --- a/examples/hyper-client.rs +++ b/examples/hyper-client.rs @@ -8,7 +8,6 @@ use std::convert::TryInto; use std::pin::Pin; -use std::sync::Arc; use std::task::{Context, Poll}; use anyhow::{bail, Context as _, Result}; @@ -23,7 +22,7 @@ use smol_macros::main; /// Sends a request and fetches the response. async fn fetch( - ex: &Arc>, + ex: &Executor<'static>, req: Request>, ) -> Result> { // Connect to the HTTP server. @@ -66,7 +65,7 @@ async fn fetch( } #[apply(main!)] -async fn main(ex: Arc>) -> Result<()> { +async fn main(ex: &Executor<'static>) -> Result<()> { // Create a request. let url: hyper::Uri = "https://www.rust-lang.org".try_into()?; let req = Request::builder() @@ -78,7 +77,7 @@ async fn main(ex: Arc>) -> Result<()> { .body(Empty::new())?; // Fetch the response. - let resp = fetch(&ex, req).await?; + let resp = fetch(ex, req).await?; println!("{:#?}", resp); // Read the message body. diff --git a/examples/hyper-server.rs b/examples/hyper-server.rs index d086d21..90473ae 100644 --- a/examples/hyper-server.rs +++ b/examples/hyper-server.rs @@ -13,24 +13,54 @@ //! //! Refer to `README.md` to see how to the TLS certificate was generated. -use std::net::{Shutdown, TcpListener, TcpStream}; +use std::net::{TcpListener, TcpStream}; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; -use anyhow::{Error, Result}; +use anyhow::Result; use async_native_tls::{Identity, TlsAcceptor, TlsStream}; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Request, Response, Server}; -use smol::{future, io, prelude::*, Async}; +use http_body_util::Full; +use hyper::body::Incoming; +use hyper::service::service_fn; +use hyper::{Request, Response}; +use macro_rules_attribute::apply; +use smol::{future, io, prelude::*, Async, Executor}; +use smol_hyper::rt::{FuturesIo, SmolTimer}; +use smol_macros::main; /// Serves a request and returns a response. -async fn serve(req: Request, host: String) -> Result> { - println!("Serving {}{}", host, req.uri()); - Ok(Response::new(Body::from("Hello from hyper!"))) +async fn serve(req: Request) -> Result>> { + println!("Serving {}", req.uri()); + Ok(Response::new(Full::new("Hello from hyper!".as_bytes()))) +} + +/// Handle a new client. +async fn handle_client(client: Async, tls: Option) -> Result<()> { + // Wrap it in TLS if necessary. + let client = match &tls { + None => SmolStream::Plain(client), + Some(tls) => { + // In case of HTTPS, establish a secure TLS connection. + SmolStream::Tls(tls.accept(client).await?) + } + }; + + // Build the server. + hyper::server::conn::http1::Builder::new() + .timer(SmolTimer::new()) + .serve_connection(FuturesIo::new(client), service_fn(serve)) + .await?; + + Ok(()) } /// Listens for incoming connections and serves them. -async fn listen(listener: Async, tls: Option) -> Result<()> { +async fn listen( + ex: &Arc>, + listener: Async, + tls: Option, +) -> Result<()> { // Format the full host address. let host = &match tls { None => format!("http://{}", listener.get_ref().local_addr()?), @@ -38,86 +68,42 @@ async fn listen(listener: Async, tls: Option) -> Resul }; println!("Listening on {}", host); - // Start a hyper server. - Server::builder(SmolListener::new(&listener, tls)) - .executor(SmolExecutor) - .serve(make_service_fn(move |_| { - let host = host.clone(); - async { Ok::<_, Error>(service_fn(move |req| serve(req, host.clone()))) } - })) - .await?; + loop { + // Wait for a new client. + let (client, _) = listener.accept().await?; - Ok(()) + // Spawn a task to handle this connection. + ex.spawn({ + let tls = tls.clone(); + async move { + if let Err(e) = handle_client(client, tls).await { + println!("Error while handling client: {}", e); + } + } + }) + .detach(); + } } -fn main() -> Result<()> { +#[apply(main!)] +async fn main(ex: &Arc>) -> Result<()> { // Initialize TLS with the local certificate, private key, and password. let identity = Identity::from_pkcs12(include_bytes!("identity.pfx"), "password")?; let tls = TlsAcceptor::from(native_tls::TlsAcceptor::new(identity)?); // Start HTTP and HTTPS servers. - smol::block_on(async { - let http = listen(Async::::bind(([127, 0, 0, 1], 8000))?, None); - let https = listen( - Async::::bind(([127, 0, 0, 1], 8001))?, - Some(tls), - ); - future::try_zip(http, https).await?; - Ok(()) - }) -} - -/// Spawns futures. -#[derive(Clone)] -struct SmolExecutor; - -impl hyper::rt::Executor for SmolExecutor { - fn execute(&self, fut: F) { - smol::spawn(async { drop(fut.await) }).detach(); - } -} - -/// Listens for incoming connections. -struct SmolListener<'a> { - tls: Option, - incoming: Pin>> + Send + 'a>>, -} - -impl<'a> SmolListener<'a> { - fn new(listener: &'a Async, tls: Option) -> Self { - Self { - incoming: Box::pin(listener.incoming()), - tls, - } - } -} - -impl hyper::server::accept::Accept for SmolListener<'_> { - type Conn = SmolStream; - type Error = Error; - - fn poll_accept( - mut self: Pin<&mut Self>, - cx: &mut Context, - ) -> Poll>> { - let stream = smol::ready!(self.incoming.as_mut().poll_next(cx)).unwrap()?; - - let stream = match &self.tls { - None => SmolStream::Plain(stream), - Some(tls) => { - // In case of HTTPS, start establishing a secure TLS connection. - let tls = tls.clone(); - SmolStream::Handshake(Box::pin(async move { - tls.accept(stream).await.map_err(|err| { - println!("Failed to establish secure TLS connection: {:#?}", err); - io::Error::new(io::ErrorKind::Other, Box::new(err)) - }) - })) - } - }; - - Poll::Ready(Some(Ok(stream))) - } + let http = listen( + ex, + Async::::bind(([127, 0, 0, 1], 8000))?, + None, + ); + let https = listen( + ex, + Async::::bind(([127, 0, 0, 1], 8001))?, + Some(tls), + ); + future::try_zip(http, https).await?; + Ok(()) } /// A TCP or TCP+TLS connection. @@ -127,83 +113,44 @@ enum SmolStream { /// A TCP connection secured by TLS. Tls(TlsStream>), - - /// A TCP connection that is in process of getting secured by TLS. - #[allow(clippy::type_complexity)] - Handshake(Pin>>> + Send>>), } -impl hyper::client::connect::Connection for SmolStream { - fn connected(&self) -> hyper::client::connect::Connected { - hyper::client::connect::Connected::new() - } -} - -impl tokio::io::AsyncRead for SmolStream { +impl AsyncRead for SmolStream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - loop { - match &mut *self { - SmolStream::Plain(s) => { - return Pin::new(s) - .poll_read(cx, buf.initialize_unfilled()) - .map_ok(|size| { - buf.advance(size); - }); - } - SmolStream::Tls(s) => { - return Pin::new(s) - .poll_read(cx, buf.initialize_unfilled()) - .map_ok(|size| { - buf.advance(size); - }); - } - SmolStream::Handshake(f) => { - let s = smol::ready!(f.as_mut().poll(cx))?; - *self = SmolStream::Tls(s); - } - } + buf: &mut [u8], + ) -> Poll> { + match &mut *self { + Self::Plain(s) => Pin::new(s).poll_read(cx, buf), + Self::Tls(s) => Pin::new(s).poll_read(cx, buf), } } } -impl tokio::io::AsyncWrite for SmolStream { +impl AsyncWrite for SmolStream { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - loop { - match &mut *self { - SmolStream::Plain(s) => return Pin::new(s).poll_write(cx, buf), - SmolStream::Tls(s) => return Pin::new(s).poll_write(cx, buf), - SmolStream::Handshake(f) => { - let s = smol::ready!(f.as_mut().poll(cx))?; - *self = SmolStream::Tls(s); - } - } + match &mut *self { + Self::Plain(s) => Pin::new(s).poll_write(cx, buf), + Self::Tls(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + Self::Plain(s) => Pin::new(s).poll_close(cx), + Self::Tls(s) => Pin::new(s).poll_close(cx), } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { - SmolStream::Plain(s) => Pin::new(s).poll_flush(cx), - SmolStream::Tls(s) => Pin::new(s).poll_flush(cx), - SmolStream::Handshake(_) => Poll::Ready(Ok(())), - } - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut *self { - SmolStream::Plain(s) => { - s.get_ref().shutdown(Shutdown::Write)?; - Poll::Ready(Ok(())) - } - SmolStream::Tls(s) => Pin::new(s).poll_close(cx), - SmolStream::Handshake(_) => Poll::Ready(Ok(())), + Self::Plain(s) => Pin::new(s).poll_close(cx), + Self::Tls(s) => Pin::new(s).poll_close(cx), } } }