From 34f4774ce04c3f7f9c170945685b6d020faacf8c Mon Sep 17 00:00:00 2001 From: Stjepan Glavina Date: Sat, 20 Jun 2020 17:33:45 +0200 Subject: [PATCH] Add readable/writable --- src/async_io.rs | 157 ++++++++++++++++++++++++++++-------------------- src/reactor.rs | 123 ++++++++++++++++++++++--------------- 2 files changed, 169 insertions(+), 111 deletions(-) diff --git a/src/async_io.rs b/src/async_io.rs index 453ae03..83292aa 100644 --- a/src/async_io.rs +++ b/src/async_io.rs @@ -134,7 +134,7 @@ impl Async { /// use std::net::TcpListener; /// /// # smol::run(async { - /// let listener = TcpListener::bind("127.0.0.1:80")?; + /// let listener = TcpListener::bind("127.0.0.1:0")?; /// let listener = Async::new(listener)?; /// # std::io::Result::Ok(()) }); /// ``` @@ -159,6 +159,7 @@ impl IntoRawFd for Async { self.into_inner().unwrap().into_raw_fd() } } + #[cfg(windows)] impl Async { /// Creates an async I/O handle. @@ -190,7 +191,7 @@ impl Async { /// use std::net::TcpListener; /// /// # smol::run(async { - /// let listener = TcpListener::bind("127.0.0.1:80")?; + /// let listener = TcpListener::bind("127.0.0.1:0")?; /// let listener = Async::new(listener)?; /// # std::io::Result::Ok(()) }); /// ``` @@ -231,7 +232,7 @@ impl Async { /// use std::net::TcpListener; /// /// # smol::run(async { - /// let listener = Async::::bind("127.0.0.1:80")?; + /// let listener = Async::::bind("127.0.0.1:0")?; /// let inner = listener.get_ref(); /// # std::io::Result::Ok(()) }); /// ``` @@ -248,7 +249,7 @@ impl Async { /// use std::net::TcpListener; /// /// # smol::run(async { - /// let mut listener = Async::::bind("127.0.0.1:80")?; + /// let mut listener = Async::::bind("127.0.0.1:0")?; /// let inner = listener.get_mut(); /// # std::io::Result::Ok(()) }); /// ``` @@ -265,7 +266,7 @@ impl Async { /// use std::net::TcpListener; /// /// # smol::run(async { - /// let listener = Async::::bind("127.0.0.1:80")?; + /// let listener = Async::::bind("127.0.0.1:0")?; /// let inner = listener.into_inner()?; /// # std::io::Result::Ok(()) }); /// ``` @@ -315,6 +316,48 @@ impl Async { } } + /// Waits until the I/O handle is readable. + /// + /// This function completes when a read operation on this I/O handle wouldn't block. + /// + /// # Examples + /// + /// ```no_run + /// use smol::Async; + /// use std::net::TcpListener; + /// + /// # smol::run(async { + /// let mut listener = Async::::bind("127.0.0.1:0")?; + /// + /// // Wait until a client can be accepted. + /// listener.readable().await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn readable(&self) -> io::Result<()> { + self.source.readable().await + } + + /// Waits until the I/O handle is writable. + /// + /// This function completes when a write operation on this I/O handle wouldn't block. + /// + /// # Examples + /// + /// ```no_run + /// use smol::Async; + /// use std::net::TcpStream; + /// + /// # smol::run(async { + /// let stream = Async::::connect("example.com:80").await?; + /// + /// // Wait until the stream is writable. + /// stream.writable().await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn writable(&self) -> io::Result<()> { + self.source.writable().await + } + /// Performs a read operation asynchronously. /// /// The I/O handle is registered in the reactor and put in non-blocking mode. This function @@ -331,7 +374,7 @@ impl Async { /// use std::net::TcpListener; /// /// # smol::run(async { - /// let listener = Async::::bind("127.0.0.1:80")?; + /// let listener = Async::::bind("127.0.0.1:0")?; /// /// // Accept a new client asynchronously. /// let (stream, addr) = listener.read_with(|l| l.accept()).await?; @@ -339,13 +382,15 @@ impl Async { /// ``` pub async fn read_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { let mut op = op; - loop { + future::poll_fn(|cx| { match op(self.get_ref()) { Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, + res => return Poll::Ready(res), } - self.source.readable().await?; - } + futures_util::ready!(poll_future(cx, self.readable()))?; + Poll::Pending + }) + .await } /// Performs a read operation asynchronously. @@ -364,7 +409,7 @@ impl Async { /// use std::net::TcpListener; /// /// # smol::run(async { - /// let mut listener = Async::::bind("127.0.0.1:80")?; + /// let mut listener = Async::::bind("127.0.0.1:0")?; /// /// // Accept a new client asynchronously. /// let (stream, addr) = listener.read_with_mut(|l| l.accept()).await?; @@ -375,13 +420,15 @@ impl Async { op: impl FnMut(&mut T) -> io::Result, ) -> io::Result { let mut op = op; - loop { + future::poll_fn(|cx| { match op(self.get_mut()) { Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, + res => return Poll::Ready(res), } - self.source.readable().await?; - } + futures_util::ready!(poll_future(cx, self.readable()))?; + Poll::Pending + }) + .await } /// Performs a write operation asynchronously. @@ -409,13 +456,15 @@ impl Async { /// ``` pub async fn write_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { let mut op = op; - loop { + future::poll_fn(|cx| { match op(self.get_ref()) { Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, + res => return Poll::Ready(res), } - self.source.writable().await?; - } + futures_util::ready!(poll_future(cx, self.writable()))?; + Poll::Pending + }) + .await } /// Performs a write operation asynchronously. @@ -446,13 +495,15 @@ impl Async { op: impl FnMut(&mut T) -> io::Result, ) -> io::Result { let mut op = op; - loop { + future::poll_fn(|cx| { match op(self.get_mut()) { Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - res => return res, + res => return Poll::Ready(res), } - self.source.writable().await?; - } + futures_util::ready!(poll_future(cx, self.writable()))?; + Poll::Pending + }) + .await } } @@ -468,12 +519,6 @@ impl Drop for Async { } } -/// Pins a future and then polls it. -fn poll_future(cx: &mut Context<'_>, fut: impl Future) -> Poll { - futures_util::pin_mut!(fut); - fut.poll(cx) -} - impl AsyncRead for Async { fn poll_read( mut self: Pin<&mut Self>, @@ -580,7 +625,7 @@ impl Async { /// use std::net::TcpListener; /// /// # smol::run(async { - /// let listener = Async::::bind("127.0.0.1:80")?; + /// let listener = Async::::bind("127.0.0.1:0")?; /// println!("Listening on {}", listener.get_ref().local_addr()?); /// # std::io::Result::Ok(()) }); /// ``` @@ -604,7 +649,7 @@ impl Async { /// use std::net::TcpListener; /// /// # smol::run(async { - /// let listener = Async::::bind("127.0.0.1:80")?; + /// let listener = Async::::bind("127.0.0.1:0")?; /// let (stream, addr) = listener.accept().await?; /// println!("Accepted client: {}", addr); /// # std::io::Result::Ok(()) }); @@ -626,7 +671,7 @@ impl Async { /// use std::net::TcpListener; /// /// # smol::run(async { - /// let listener = Async::::bind("127.0.0.1:80")?; + /// let listener = Async::::bind("127.0.0.1:0")?; /// let mut incoming = listener.incoming(); /// /// while let Some(stream) = incoming.next().await { @@ -691,22 +736,13 @@ impl Async { })?; let stream = Async::new(socket.into_tcp_stream())?; - // Waits for connect to complete. - let wait_connect = |mut stream: &TcpStream| match stream.write(&[]) { - Err(err) if err.kind() == io::ErrorKind::NotConnected => match stream.take_error()? { - Some(err) => Err(err), - None => Err(io::ErrorKind::WouldBlock.into()), - }, - res => res.map(|_| ()), - }; - // The stream becomes writable when connected. - match stream.write_with(|io| wait_connect(io)).await { - Ok(()) => Ok(stream), - Err(err) => match stream.get_ref().take_error()? { - Some(err) => Err(err), - None => Err(err), - }, + stream.writable().await?; + + // Check if there was an error while connecting. + match stream.get_ref().take_error()? { + None => Ok(stream), + Some(err) => Err(err), } } @@ -964,7 +1000,7 @@ impl Async { /// use std::os::unix::net::UnixListener; /// /// # smol::run(async { - /// let listener = Async::::bind("127.0.0.1:80")?; + /// let listener = Async::::bind("127.0.0.1:0")?; /// let mut incoming = listener.incoming(); /// /// while let Some(stream) = incoming.next().await { @@ -1014,23 +1050,10 @@ impl Async { })?; let stream = Async::new(socket.into_unix_stream())?; - // Waits for connect to complete. - let wait_connect = |mut stream: &UnixStream| match stream.write(&[]) { - Err(err) if err.kind() == io::ErrorKind::NotConnected => match stream.take_error()? { - Some(err) => Err(err), - None => Err(io::ErrorKind::WouldBlock.into()), - }, - res => res.map(|_| ()), - }; - // The stream becomes writable when connected. - match stream.write_with(|io| wait_connect(io)).await { - Ok(()) => Ok(stream), - Err(err) => match stream.get_ref().take_error()? { - Some(err) => Err(err), - None => Err(err), - }, - } + stream.writable().await?; + + Ok(stream) } /// Creates an unnamed pair of connected UDS stream sockets. @@ -1196,3 +1219,9 @@ impl Async { self.write_with(|io| io.send(buf)).await } } + +/// Pins a future and then polls it. +fn poll_future(cx: &mut Context<'_>, fut: impl Future) -> Poll { + futures_util::pin_mut!(fut); + fut.poll(cx) +} diff --git a/src/reactor.rs b/src/reactor.rs index 7ab7b9b..e30ccf3 100644 --- a/src/reactor.rs +++ b/src/reactor.rs @@ -25,7 +25,7 @@ use std::mem; use std::os::unix::io::RawFd; #[cfg(windows)] use std::os::windows::io::{FromRawSocket, RawSocket}; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use std::task::{Poll, Waker}; use std::time::{Duration, Instant}; @@ -52,6 +52,9 @@ pub(crate) struct Reactor { /// Raw bindings to epoll/kqueue/wepoll. sys: sys::Reactor, + /// Ticker bumped before polling. + ticker: AtomicU64, + /// Registered sources. sources: piper::Mutex>>, @@ -83,6 +86,7 @@ impl Reactor { pub fn get() -> &'static Reactor { static REACTOR: Lazy = Lazy::new(|| Reactor { sys: sys::Reactor::new().expect("cannot initialize I/O event notification"), + ticker: AtomicU64::new(0), sources: piper::Mutex::new(Slab::new()), events: piper::Mutex::new(sys::Events::new()), timers: piper::Mutex::new(BTreeMap::new()), @@ -122,6 +126,8 @@ impl Reactor { raw, key, wakers: piper::Mutex::new(Wakers { + tick_readable: 0, + tick_writable: 0, readers: Vec::new(), writers: Vec::new(), }), @@ -249,6 +255,13 @@ impl ReactorLock<'_> { (Some(a), Some(b)) => Some(a.min(b)), }; + // Bump the ticker before polling I/O. + let tick = self + .reactor + .ticker + .fetch_add(1, Ordering::SeqCst) + .wrapping_add(1); + // Block on I/O events. match self.reactor.sys.wait(&mut self.events, timeout) { // No I/O events occurred. @@ -273,11 +286,13 @@ impl ReactorLock<'_> { // Wake readers if a readability event was emitted. if ev.readable { + wakers.tick_readable = tick; ready.append(&mut wakers.readers); } // Wake writers if a writability event was emitted. if ev.writable { + wakers.tick_writable = tick; ready.append(&mut wakers.writers); } @@ -343,6 +358,12 @@ pub(crate) struct Source { /// Tasks interested in events on a source. #[derive(Debug)] struct Wakers { + /// Last reactor tick that delivered a readability event. + tick_readable: u64, + + /// Last reactor tick that delivered a writability event. + tick_writable: u64, + /// Tasks waiting for the next readability event. readers: Vec, @@ -361,69 +382,77 @@ impl Source { } /// Waits until the I/O source is readable. - /// - /// This function may occasionally complete even if the I/O source is not readable. pub(crate) async fn readable(&self) -> io::Result<()> { - let mut polled = false; + let mut tick = None; future::poll_fn(|cx| { - if polled { - Poll::Ready(Ok(())) - } else { - let mut wakers = self.wakers.lock(); + let mut wakers = self.wakers.lock(); - // If there are no other readers, re-register in the reactor. - if wakers.readers.is_empty() { - Reactor::get().sys.reregister( - self.raw, - self.key, - true, - !wakers.writers.is_empty(), - )?; + if let Some(tick) = tick { + if wakers.tick_readable > tick { + return Poll::Ready(Ok(())); } - - // Register the current task's waker if not present already. - if wakers.readers.iter().all(|w| !w.will_wake(cx.waker())) { - wakers.readers.push(cx.waker().clone()); - } - - polled = true; - Poll::Pending } + + // If there are no other readers, re-register in the reactor. + if wakers.readers.is_empty() { + Reactor::get().sys.reregister( + self.raw, + self.key, + true, + !wakers.writers.is_empty(), + )?; + } + + // Register the current task's waker if not present already. + if wakers.readers.iter().all(|w| !w.will_wake(cx.waker())) { + wakers.readers.push(cx.waker().clone()); + } + + // Remember the current tick. + if tick.is_none() { + tick = Some(Reactor::get().ticker.load(Ordering::SeqCst)); + } + + Poll::Pending }) .await } /// Waits until the I/O source is writable. - /// - /// This function may occasionally complete even if the I/O source is not writable. pub(crate) async fn writable(&self) -> io::Result<()> { - let mut polled = false; + let mut tick = None; future::poll_fn(|cx| { - if polled { - Poll::Ready(Ok(())) - } else { - let mut wakers = self.wakers.lock(); + let mut wakers = self.wakers.lock(); - // If there are no other writers, re-register in the reactor. - if wakers.writers.is_empty() { - Reactor::get().sys.reregister( - self.raw, - self.key, - !wakers.readers.is_empty(), - true, - )?; + if let Some(tick) = tick { + if wakers.tick_writable > tick { + return Poll::Ready(Ok(())); } - - // Register the current task's waker if not present already. - if wakers.writers.iter().all(|w| !w.will_wake(cx.waker())) { - wakers.writers.push(cx.waker().clone()); - } - - polled = true; - Poll::Pending } + + // If there are no other writers, re-register in the reactor. + if wakers.writers.is_empty() { + Reactor::get().sys.reregister( + self.raw, + self.key, + !wakers.readers.is_empty(), + true, + )?; + } + + // Register the current task's waker if not present already. + if wakers.writers.iter().all(|w| !w.will_wake(cx.waker())) { + wakers.writers.push(cx.waker().clone()); + } + + // Remember the current tick. + if tick.is_none() { + tick = Some(Reactor::get().ticker.load(Ordering::SeqCst)); + } + + Poll::Pending }) .await }