Add readable/writable

This commit is contained in:
Stjepan Glavina 2020-06-20 17:33:45 +02:00
parent 69c04a9e42
commit 34f4774ce0
2 changed files with 169 additions and 111 deletions

View File

@ -134,7 +134,7 @@ impl<T: AsRawFd> Async<T> {
/// use std::net::TcpListener;
///
/// # smol::run(async {
/// let listener = TcpListener::bind("127.0.0.1:80")?;
/// let listener = TcpListener::bind("127.0.0.1:0")?;
/// let listener = Async::new(listener)?;
/// # std::io::Result::Ok(()) });
/// ```
@ -159,6 +159,7 @@ impl<T: IntoRawFd> IntoRawFd for Async<T> {
self.into_inner().unwrap().into_raw_fd()
}
}
#[cfg(windows)]
impl<T: AsRawSocket> Async<T> {
/// Creates an async I/O handle.
@ -190,7 +191,7 @@ impl<T: AsRawSocket> Async<T> {
/// use std::net::TcpListener;
///
/// # smol::run(async {
/// let listener = TcpListener::bind("127.0.0.1:80")?;
/// let listener = TcpListener::bind("127.0.0.1:0")?;
/// let listener = Async::new(listener)?;
/// # std::io::Result::Ok(()) });
/// ```
@ -231,7 +232,7 @@ impl<T> Async<T> {
/// use std::net::TcpListener;
///
/// # smol::run(async {
/// let listener = Async::<TcpListener>::bind("127.0.0.1:80")?;
/// let listener = Async::<TcpListener>::bind("127.0.0.1:0")?;
/// let inner = listener.get_ref();
/// # std::io::Result::Ok(()) });
/// ```
@ -248,7 +249,7 @@ impl<T> Async<T> {
/// use std::net::TcpListener;
///
/// # smol::run(async {
/// let mut listener = Async::<TcpListener>::bind("127.0.0.1:80")?;
/// let mut listener = Async::<TcpListener>::bind("127.0.0.1:0")?;
/// let inner = listener.get_mut();
/// # std::io::Result::Ok(()) });
/// ```
@ -265,7 +266,7 @@ impl<T> Async<T> {
/// use std::net::TcpListener;
///
/// # smol::run(async {
/// let listener = Async::<TcpListener>::bind("127.0.0.1:80")?;
/// let listener = Async::<TcpListener>::bind("127.0.0.1:0")?;
/// let inner = listener.into_inner()?;
/// # std::io::Result::Ok(()) });
/// ```
@ -315,6 +316,48 @@ impl<T> Async<T> {
}
}
/// Waits until the I/O handle is readable.
///
/// This function completes when a read operation on this I/O handle wouldn't block.
///
/// # Examples
///
/// ```no_run
/// use smol::Async;
/// use std::net::TcpListener;
///
/// # smol::run(async {
/// let mut listener = Async::<TcpListener>::bind("127.0.0.1:0")?;
///
/// // Wait until a client can be accepted.
/// listener.readable().await?;
/// # std::io::Result::Ok(()) });
/// ```
pub async fn readable(&self) -> io::Result<()> {
self.source.readable().await
}
/// Waits until the I/O handle is writable.
///
/// This function completes when a write operation on this I/O handle wouldn't block.
///
/// # Examples
///
/// ```no_run
/// use smol::Async;
/// use std::net::TcpStream;
///
/// # smol::run(async {
/// let stream = Async::<TcpStream>::connect("example.com:80").await?;
///
/// // Wait until the stream is writable.
/// stream.writable().await?;
/// # std::io::Result::Ok(()) });
/// ```
pub async fn writable(&self) -> io::Result<()> {
self.source.writable().await
}
/// Performs a read operation asynchronously.
///
/// The I/O handle is registered in the reactor and put in non-blocking mode. This function
@ -331,7 +374,7 @@ impl<T> Async<T> {
/// use std::net::TcpListener;
///
/// # smol::run(async {
/// let listener = Async::<TcpListener>::bind("127.0.0.1:80")?;
/// let listener = Async::<TcpListener>::bind("127.0.0.1:0")?;
///
/// // Accept a new client asynchronously.
/// let (stream, addr) = listener.read_with(|l| l.accept()).await?;
@ -339,13 +382,15 @@ impl<T> Async<T> {
/// ```
pub async fn read_with<R>(&self, op: impl FnMut(&T) -> io::Result<R>) -> io::Result<R> {
let mut op = op;
loop {
future::poll_fn(|cx| {
match op(self.get_ref()) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
res => return res,
res => return Poll::Ready(res),
}
self.source.readable().await?;
}
futures_util::ready!(poll_future(cx, self.readable()))?;
Poll::Pending
})
.await
}
/// Performs a read operation asynchronously.
@ -364,7 +409,7 @@ impl<T> Async<T> {
/// use std::net::TcpListener;
///
/// # smol::run(async {
/// let mut listener = Async::<TcpListener>::bind("127.0.0.1:80")?;
/// let mut listener = Async::<TcpListener>::bind("127.0.0.1:0")?;
///
/// // Accept a new client asynchronously.
/// let (stream, addr) = listener.read_with_mut(|l| l.accept()).await?;
@ -375,13 +420,15 @@ impl<T> Async<T> {
op: impl FnMut(&mut T) -> io::Result<R>,
) -> io::Result<R> {
let mut op = op;
loop {
future::poll_fn(|cx| {
match op(self.get_mut()) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
res => return res,
res => return Poll::Ready(res),
}
self.source.readable().await?;
}
futures_util::ready!(poll_future(cx, self.readable()))?;
Poll::Pending
})
.await
}
/// Performs a write operation asynchronously.
@ -409,13 +456,15 @@ impl<T> Async<T> {
/// ```
pub async fn write_with<R>(&self, op: impl FnMut(&T) -> io::Result<R>) -> io::Result<R> {
let mut op = op;
loop {
future::poll_fn(|cx| {
match op(self.get_ref()) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
res => return res,
res => return Poll::Ready(res),
}
self.source.writable().await?;
}
futures_util::ready!(poll_future(cx, self.writable()))?;
Poll::Pending
})
.await
}
/// Performs a write operation asynchronously.
@ -446,13 +495,15 @@ impl<T> Async<T> {
op: impl FnMut(&mut T) -> io::Result<R>,
) -> io::Result<R> {
let mut op = op;
loop {
future::poll_fn(|cx| {
match op(self.get_mut()) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
res => return res,
res => return Poll::Ready(res),
}
self.source.writable().await?;
}
futures_util::ready!(poll_future(cx, self.writable()))?;
Poll::Pending
})
.await
}
}
@ -468,12 +519,6 @@ impl<T> Drop for Async<T> {
}
}
/// Pins a future and then polls it.
fn poll_future<T>(cx: &mut Context<'_>, fut: impl Future<Output = T>) -> Poll<T> {
futures_util::pin_mut!(fut);
fut.poll(cx)
}
impl<T: Read> AsyncRead for Async<T> {
fn poll_read(
mut self: Pin<&mut Self>,
@ -580,7 +625,7 @@ impl Async<TcpListener> {
/// use std::net::TcpListener;
///
/// # smol::run(async {
/// let listener = Async::<TcpListener>::bind("127.0.0.1:80")?;
/// let listener = Async::<TcpListener>::bind("127.0.0.1:0")?;
/// println!("Listening on {}", listener.get_ref().local_addr()?);
/// # std::io::Result::Ok(()) });
/// ```
@ -604,7 +649,7 @@ impl Async<TcpListener> {
/// use std::net::TcpListener;
///
/// # smol::run(async {
/// let listener = Async::<TcpListener>::bind("127.0.0.1:80")?;
/// let listener = Async::<TcpListener>::bind("127.0.0.1:0")?;
/// let (stream, addr) = listener.accept().await?;
/// println!("Accepted client: {}", addr);
/// # std::io::Result::Ok(()) });
@ -626,7 +671,7 @@ impl Async<TcpListener> {
/// use std::net::TcpListener;
///
/// # smol::run(async {
/// let listener = Async::<TcpListener>::bind("127.0.0.1:80")?;
/// let listener = Async::<TcpListener>::bind("127.0.0.1:0")?;
/// let mut incoming = listener.incoming();
///
/// while let Some(stream) = incoming.next().await {
@ -691,22 +736,13 @@ impl Async<TcpStream> {
})?;
let stream = Async::new(socket.into_tcp_stream())?;
// Waits for connect to complete.
let wait_connect = |mut stream: &TcpStream| match stream.write(&[]) {
Err(err) if err.kind() == io::ErrorKind::NotConnected => match stream.take_error()? {
Some(err) => Err(err),
None => Err(io::ErrorKind::WouldBlock.into()),
},
res => res.map(|_| ()),
};
// The stream becomes writable when connected.
match stream.write_with(|io| wait_connect(io)).await {
Ok(()) => Ok(stream),
Err(err) => match stream.get_ref().take_error()? {
Some(err) => Err(err),
None => Err(err),
},
stream.writable().await?;
// Check if there was an error while connecting.
match stream.get_ref().take_error()? {
None => Ok(stream),
Some(err) => Err(err),
}
}
@ -964,7 +1000,7 @@ impl Async<UnixListener> {
/// use std::os::unix::net::UnixListener;
///
/// # smol::run(async {
/// let listener = Async::<UnixListener>::bind("127.0.0.1:80")?;
/// let listener = Async::<UnixListener>::bind("127.0.0.1:0")?;
/// let mut incoming = listener.incoming();
///
/// while let Some(stream) = incoming.next().await {
@ -1014,23 +1050,10 @@ impl Async<UnixStream> {
})?;
let stream = Async::new(socket.into_unix_stream())?;
// Waits for connect to complete.
let wait_connect = |mut stream: &UnixStream| match stream.write(&[]) {
Err(err) if err.kind() == io::ErrorKind::NotConnected => match stream.take_error()? {
Some(err) => Err(err),
None => Err(io::ErrorKind::WouldBlock.into()),
},
res => res.map(|_| ()),
};
// The stream becomes writable when connected.
match stream.write_with(|io| wait_connect(io)).await {
Ok(()) => Ok(stream),
Err(err) => match stream.get_ref().take_error()? {
Some(err) => Err(err),
None => Err(err),
},
}
stream.writable().await?;
Ok(stream)
}
/// Creates an unnamed pair of connected UDS stream sockets.
@ -1196,3 +1219,9 @@ impl Async<UnixDatagram> {
self.write_with(|io| io.send(buf)).await
}
}
/// Pins a future and then polls it.
fn poll_future<T>(cx: &mut Context<'_>, fut: impl Future<Output = T>) -> Poll<T> {
futures_util::pin_mut!(fut);
fut.poll(cx)
}

View File

@ -25,7 +25,7 @@ use std::mem;
use std::os::unix::io::RawFd;
#[cfg(windows)]
use std::os::windows::io::{FromRawSocket, RawSocket};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Poll, Waker};
use std::time::{Duration, Instant};
@ -52,6 +52,9 @@ pub(crate) struct Reactor {
/// Raw bindings to epoll/kqueue/wepoll.
sys: sys::Reactor,
/// Ticker bumped before polling.
ticker: AtomicU64,
/// Registered sources.
sources: piper::Mutex<Slab<Arc<Source>>>,
@ -83,6 +86,7 @@ impl Reactor {
pub fn get() -> &'static Reactor {
static REACTOR: Lazy<Reactor> = Lazy::new(|| Reactor {
sys: sys::Reactor::new().expect("cannot initialize I/O event notification"),
ticker: AtomicU64::new(0),
sources: piper::Mutex::new(Slab::new()),
events: piper::Mutex::new(sys::Events::new()),
timers: piper::Mutex::new(BTreeMap::new()),
@ -122,6 +126,8 @@ impl Reactor {
raw,
key,
wakers: piper::Mutex::new(Wakers {
tick_readable: 0,
tick_writable: 0,
readers: Vec::new(),
writers: Vec::new(),
}),
@ -249,6 +255,13 @@ impl ReactorLock<'_> {
(Some(a), Some(b)) => Some(a.min(b)),
};
// Bump the ticker before polling I/O.
let tick = self
.reactor
.ticker
.fetch_add(1, Ordering::SeqCst)
.wrapping_add(1);
// Block on I/O events.
match self.reactor.sys.wait(&mut self.events, timeout) {
// No I/O events occurred.
@ -273,11 +286,13 @@ impl ReactorLock<'_> {
// Wake readers if a readability event was emitted.
if ev.readable {
wakers.tick_readable = tick;
ready.append(&mut wakers.readers);
}
// Wake writers if a writability event was emitted.
if ev.writable {
wakers.tick_writable = tick;
ready.append(&mut wakers.writers);
}
@ -343,6 +358,12 @@ pub(crate) struct Source {
/// Tasks interested in events on a source.
#[derive(Debug)]
struct Wakers {
/// Last reactor tick that delivered a readability event.
tick_readable: u64,
/// Last reactor tick that delivered a writability event.
tick_writable: u64,
/// Tasks waiting for the next readability event.
readers: Vec<Waker>,
@ -361,69 +382,77 @@ impl Source {
}
/// Waits until the I/O source is readable.
///
/// This function may occasionally complete even if the I/O source is not readable.
pub(crate) async fn readable(&self) -> io::Result<()> {
let mut polled = false;
let mut tick = None;
future::poll_fn(|cx| {
if polled {
Poll::Ready(Ok(()))
} else {
let mut wakers = self.wakers.lock();
let mut wakers = self.wakers.lock();
// If there are no other readers, re-register in the reactor.
if wakers.readers.is_empty() {
Reactor::get().sys.reregister(
self.raw,
self.key,
true,
!wakers.writers.is_empty(),
)?;
if let Some(tick) = tick {
if wakers.tick_readable > tick {
return Poll::Ready(Ok(()));
}
// Register the current task's waker if not present already.
if wakers.readers.iter().all(|w| !w.will_wake(cx.waker())) {
wakers.readers.push(cx.waker().clone());
}
polled = true;
Poll::Pending
}
// If there are no other readers, re-register in the reactor.
if wakers.readers.is_empty() {
Reactor::get().sys.reregister(
self.raw,
self.key,
true,
!wakers.writers.is_empty(),
)?;
}
// Register the current task's waker if not present already.
if wakers.readers.iter().all(|w| !w.will_wake(cx.waker())) {
wakers.readers.push(cx.waker().clone());
}
// Remember the current tick.
if tick.is_none() {
tick = Some(Reactor::get().ticker.load(Ordering::SeqCst));
}
Poll::Pending
})
.await
}
/// Waits until the I/O source is writable.
///
/// This function may occasionally complete even if the I/O source is not writable.
pub(crate) async fn writable(&self) -> io::Result<()> {
let mut polled = false;
let mut tick = None;
future::poll_fn(|cx| {
if polled {
Poll::Ready(Ok(()))
} else {
let mut wakers = self.wakers.lock();
let mut wakers = self.wakers.lock();
// If there are no other writers, re-register in the reactor.
if wakers.writers.is_empty() {
Reactor::get().sys.reregister(
self.raw,
self.key,
!wakers.readers.is_empty(),
true,
)?;
if let Some(tick) = tick {
if wakers.tick_writable > tick {
return Poll::Ready(Ok(()));
}
// Register the current task's waker if not present already.
if wakers.writers.iter().all(|w| !w.will_wake(cx.waker())) {
wakers.writers.push(cx.waker().clone());
}
polled = true;
Poll::Pending
}
// If there are no other writers, re-register in the reactor.
if wakers.writers.is_empty() {
Reactor::get().sys.reregister(
self.raw,
self.key,
!wakers.readers.is_empty(),
true,
)?;
}
// Register the current task's waker if not present already.
if wakers.writers.iter().all(|w| !w.will_wake(cx.waker())) {
wakers.writers.push(cx.waker().clone());
}
// Remember the current tick.
if tick.is_none() {
tick = Some(Reactor::get().ticker.load(Ordering::SeqCst));
}
Poll::Pending
})
.await
}