From baf8968624b7aaa3abeda430991196938593d08b Mon Sep 17 00:00:00 2001 From: Stjepan Glavina Date: Sun, 28 Jun 2020 22:15:55 +0200 Subject: [PATCH] Initial commit --- .gitignore | 2 + Cargo.toml | 26 + src/lib.rs | 1295 ++++++++++++++++++++++++++++++++++++++++++++++++ src/parking.rs | 1182 +++++++++++++++++++++++++++++++++++++++++++ src/sys.rs | 277 +++++++++++ tests/async.rs | 339 +++++++++++++ tests/timer.rs | 27 + 7 files changed, 3148 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 src/lib.rs create mode 100644 src/parking.rs create mode 100644 src/sys.rs create mode 100644 tests/async.rs create mode 100644 tests/timer.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..96ef6c0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..1bd0d09 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "async-io" +version = "0.1.0" +authors = ["Stjepan Glavina "] +edition = "2018" + +[dependencies] +blocking = "0.4.6" +concurrent-queue = "1.1.1" +futures-io = { version = "0.3.5", default-features = false, features = ["std"] } +futures-util = { version = "0.3.5", default-features = false, features = ["std", "io"] } +libc = "0.2.71" +once_cell = "1.4.0" +parking = "1.0.3" +slab = "0.4.2" +socket2 = { version = "0.3.12", features = ["pair", "unix"] } + +[target.'cfg(windows)'.dependencies] +wepoll-sys-stjepang = "1.0.6" +winapi = { version = "0.3.8", features = ["ioapiset"] } + +[dev-dependencies] +async-channel = "1.1.1" +async-dup = "1.1.0" +futures = { version = "0.3.5", default-features = false, features = ["std"] } +tempfile = "3.1.0" diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..0b8a8f7 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,1295 @@ +use std::fmt::Debug; +use std::future::Future; +use std::io::{self, IoSlice, IoSliceMut, Read, Write}; +use std::net::{SocketAddr, 
TcpListener, TcpStream, ToSocketAddrs, UdpSocket}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, RawSocket}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::{Duration, Instant}; +#[cfg(unix)] +use std::{ + os::unix::io::{AsRawFd, RawFd}, + os::unix::net::{SocketAddr as UnixSocketAddr, UnixDatagram, UnixListener, UnixStream}, + path::Path, +}; + +use futures_io::{AsyncRead, AsyncWrite}; +use futures_util::future; +use futures_util::stream::{self, Stream}; +use socket2::{Domain, Protocol, Socket, Type}; + +use crate::parking::{Reactor, Source}; + +pub mod parking; +mod sys; + +/// Fires at the chosen point in time. +/// +/// Timers are futures that output the [`Instant`] at which they fired. +/// +/// # Examples +/// +/// Sleep for 1 second: +/// +/// ``` +/// use async_io::Timer; +/// use std::time::Duration; +/// +/// async fn sleep(dur: Duration) { +/// Timer::after(dur).await; +/// } +/// +/// # blocking::block_on(async { +/// sleep(Duration::from_secs(1)).await; +/// # }); +/// ``` +/// +/// Set a timeout on an I/O operation: +/// +/// ``` +/// use async_io::Timer; +/// use blocking::Unblock; +/// use futures::future::Either; +/// use futures::io::{self, BufReader}; +/// use futures::prelude::*; +/// use std::time::Duration; +/// +/// async fn timeout( +/// dur: Duration, +/// f: impl Future>, +/// ) -> io::Result { +/// futures::pin_mut!(f); +/// match future::select(f, Timer::after(dur)).await { +/// Either::Left((out, _)) => out, +/// Either::Right(_) => Err(io::ErrorKind::TimedOut.into()), +/// } +/// } +/// +/// # blocking::block_on(async { +/// // Create a buffered stdin reader. +/// let mut stdin = BufReader::new(Unblock::new(std::io::stdin())); +/// +/// // Read a line within 5 seconds. +/// let mut line = String::new(); +/// timeout(Duration::from_secs(5), stdin.read_line(&mut line)).await?; +/// # io::Result::Ok(()) }); +/// ``` +#[derive(Debug)] +pub struct Timer { + /// This timer's ID. 
+ /// + /// When this field is set to `None`, this timer is not registered in the reactor. + id: Option, + + /// When this timer fires. + when: Instant, +} + +impl Timer { + /// Fires after the specified duration of time. + /// + /// # Examples + /// + /// ``` + /// use async_io::Timer; + /// use std::time::Duration; + /// + /// # blocking::block_on(async { + /// Timer::after(Duration::from_secs(1)).await; + /// # }); + /// ``` + pub fn after(dur: Duration) -> Timer { + Timer::at(Instant::now() + dur) + } + + /// Fires at the specified instant in time. + /// + /// # Examples + /// + /// ``` + /// use async_io::Timer; + /// use std::time::{Duration, Instant}; + /// + /// # blocking::block_on(async { + /// let now = Instant::now(); + /// let when = now + Duration::from_secs(1); + /// Timer::at(when).await; + /// # }); + /// ``` + pub fn at(when: Instant) -> Timer { + let id = None; + Timer { id, when } + } +} + +impl Drop for Timer { + fn drop(&mut self) { + if let Some(id) = self.id.take() { + // Deregister the timer from the reactor. + Reactor::get().remove_timer(self.when, id); + } + } +} + +impl Future for Timer { + type Output = Instant; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // Check if the timer has already fired. + if Instant::now() >= self.when { + if let Some(id) = self.id.take() { + // Deregister the timer from the reactor. + Reactor::get().remove_timer(self.when, id); + } + Poll::Ready(self.when) + } else { + if self.id.is_none() { + // Register the timer in the reactor. + self.id = Some(Reactor::get().insert_timer(self.when, cx.waker())); + } + Poll::Pending + } + } +} + +/// Async I/O. +/// +/// This type converts a blocking I/O type into an async type, provided it is supported by +/// [epoll]/[kqueue]/[wepoll]. +/// +/// I/O operations can then be *asyncified* by methods [`Async::with()`] and [`Async::with_mut()`], +/// or you can use the predefined async methods on the standard networking types. 
+/// +/// **NOTE**: Do not use this type with [`File`][`std::fs::File`], [`Stdin`][`std::io::Stdin`], +/// [`Stdout`][`std::io::Stdout`], or [`Stderr`][`std::io::Stderr`] because they're not +/// supported. Use [`reader()`][`crate::reader()`] and [`writer()`][`crate::writer()`] functions +/// instead to read/write on a thread. +/// +/// # Examples +/// +/// To make an async I/O handle cloneable, wrap it in [async-dup]'s `Arc`: +/// +/// ```no_run +/// use async_dup::Arc; +/// use async_io::Async; +/// use std::net::TcpStream; +/// +/// # blocking::block_on(async { +/// // Connect to a local server. +/// let stream = Async::::connect("127.0.0.1:8000").await?; +/// +/// // Create two handles to the stream. +/// let reader = Arc::new(stream); +/// let mut writer = reader.clone(); +/// +/// // Echo all messages from the read side of the stream into the write side. +/// futures::io::copy(reader, &mut writer).await?; +/// # std::io::Result::Ok(()) }); +/// ``` +/// +/// If a type does but its reference doesn't implement [`AsyncRead`] and [`AsyncWrite`], wrap it in +/// [async-dup]'s `Mutex`: +/// +/// ```no_run +/// use async_dup::{Arc, Mutex}; +/// use async_io::Async; +/// use futures::prelude::*; +/// use std::net::TcpStream; +/// +/// # blocking::block_on(async { +/// // Reads data from a stream and echoes it back. +/// async fn echo(stream: impl AsyncRead + AsyncWrite + Unpin) -> std::io::Result { +/// let stream = Mutex::new(stream); +/// +/// // Create two handles to the stream. +/// let reader = Arc::new(stream); +/// let mut writer = reader.clone(); +/// +/// // Echo all messages from the read side of the stream into the write side. +/// futures::io::copy(reader, &mut writer).await +/// } +/// +/// // Connect to a local server and echo its messages back. 
+/// let stream = Async::::connect("127.0.0.1:8000").await?; +/// echo(stream).await?; +/// # std::io::Result::Ok(()) }); +/// ``` +/// +/// [async-dup]: https://docs.rs/async-dup +/// [epoll]: https://en.wikipedia.org/wiki/Epoll +/// [kqueue]: https://en.wikipedia.org/wiki/Kqueue +/// [wepoll]: https://github.com/piscisaureus/wepoll +#[derive(Debug)] +pub struct Async { + /// A source registered in the reactor. + source: Arc, + + /// The inner I/O handle. + io: Option>, +} + +#[cfg(unix)] +impl Async { + /// Creates an async I/O handle. + /// + /// This function will put the handle in non-blocking mode and register it in [epoll] on + /// Linux/Android, [kqueue] on macOS/iOS/BSD, or [wepoll] on Windows. + /// On Unix systems, the handle must implement `AsRawFd`, while on Windows it must implement + /// `AsRawSocket`. + /// + /// If the handle implements [`Read`] and [`Write`], then `Async` automatically + /// implements [`AsyncRead`] and [`AsyncWrite`]. + /// Other I/O operations can be *asyncified* by methods [`Async::with()`] and + /// [`Async::with_mut()`]. + /// + /// **NOTE**: Do not use this type with [`File`][`std::fs::File`], [`Stdin`][`std::io::Stdin`], + /// [`Stdout`][`std::io::Stdout`], or [`Stderr`][`std::io::Stderr`] because they're not + /// supported by [epoll]/[kqueue]/[wepoll]. + /// Use [`reader()`][`crate::reader()`] and [`writer()`][`crate::writer()`] functions instead + /// to read/write on a thread. 
+ /// + /// [epoll]: https://en.wikipedia.org/wiki/Epoll + /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue + /// [wepoll]: https://github.com/piscisaureus/wepoll + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # blocking::block_on(async { + /// let listener = TcpListener::bind("127.0.0.1:0")?; + /// let listener = Async::new(listener)?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn new(io: T) -> io::Result> { + Ok(Async { + source: Reactor::get().insert_io(io.as_raw_fd())?, + io: Some(Box::new(io)), + }) + } +} + +#[cfg(unix)] +impl AsRawFd for Async { + fn as_raw_fd(&self) -> RawFd { + self.source.raw + } +} + +#[cfg(windows)] +impl Async { + /// Creates an async I/O handle. + /// + /// This function will put the handle in non-blocking mode and register it in [epoll] on + /// Linux/Android, [kqueue] on macOS/iOS/BSD, or [wepoll] on Windows. + /// On Unix systems, the handle must implement `AsRawFd`, while on Windows it must implement + /// `AsRawSocket`. + /// + /// If the handle implements [`Read`] and [`Write`], then `Async` automatically + /// implements [`AsyncRead`] and [`AsyncWrite`]. + /// Other I/O operations can be *asyncified* by methods [`Async::with()`] and + /// [`Async::with_mut()`]. + /// + /// **NOTE**: Do not use this type with [`File`][`std::fs::File`], [`Stdin`][`std::io::Stdin`], + /// [`Stdout`][`std::io::Stdout`], or [`Stderr`][`std::io::Stderr`] because they're not + /// supported by epoll/kqueue/wepoll. + /// Use [`reader()`][`crate::reader()`] and [`writer()`][`crate::writer()`] functions instead + /// to read/write on a thread. 
+ /// + /// [epoll]: https://en.wikipedia.org/wiki/Epoll + /// [kqueue]: https://en.wikipedia.org/wiki/Kqueue + /// [wepoll]: https://github.com/piscisaureus/wepoll + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # blocking::block_on(async { + /// let listener = TcpListener::bind("127.0.0.1:0")?; + /// let listener = Async::new(listener)?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn new(io: T) -> io::Result> { + Ok(Async { + source: Reactor::get().insert_io(io.as_raw_socket())?, + io: Some(Box::new(io)), + }) + } +} + +#[cfg(windows)] +impl AsRawSocket for Async { + fn as_raw_socket(&self) -> RawSocket { + self.source.raw + } +} + +impl Async { + /// Gets a reference to the inner I/O handle. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # blocking::block_on(async { + /// let listener = Async::::bind("127.0.0.1:0")?; + /// let inner = listener.get_ref(); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn get_ref(&self) -> &T { + self.io.as_ref().unwrap() + } + + /// Gets a mutable reference to the inner I/O handle. + /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # blocking::block_on(async { + /// let mut listener = Async::::bind("127.0.0.1:0")?; + /// let inner = listener.get_mut(); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn get_mut(&mut self) -> &mut T { + self.io.as_mut().unwrap() + } + + /// Unwraps the inner non-blocking I/O handle. 
+ /// + /// # Examples + /// + /// ``` + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # blocking::block_on(async { + /// let listener = Async::::bind("127.0.0.1:0")?; + /// let inner = listener.into_inner()?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn into_inner(mut self) -> io::Result { + let io = *self.io.take().unwrap(); + Reactor::get().remove_io(&self.source)?; + Ok(io) + } + + /// Waits until the I/O handle is readable. + /// + /// This function completes when a read operation on this I/O handle wouldn't block. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # blocking::block_on(async { + /// let mut listener = Async::::bind("127.0.0.1:0")?; + /// + /// // Wait until a client can be accepted. + /// listener.readable().await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn readable(&self) -> io::Result<()> { + self.source.readable().await + } + + /// Waits until the I/O handle is writable. + /// + /// This function completes when a write operation on this I/O handle wouldn't block. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpStream; + /// + /// # blocking::block_on(async { + /// let stream = Async::::connect("example.com:80").await?; + /// + /// // Wait until the stream is writable. + /// stream.writable().await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn writable(&self) -> io::Result<()> { + self.source.writable().await + } + + /// Performs a read operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This function + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is readable. 
+ /// + /// The closure receives a shared reference to the I/O handle. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # blocking::block_on(async { + /// let listener = Async::::bind("127.0.0.1:0")?; + /// + /// // Accept a new client asynchronously. + /// let (stream, addr) = listener.read_with(|l| l.accept()).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn read_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { + let mut op = op; + future::poll_fn(|cx| { + match op(self.get_ref()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + futures_util::ready!(poll_once(cx, self.readable()))?; + Poll::Pending + }) + .await + } + + /// Performs a read operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This function + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is readable. + /// + /// The closure receives a mutable reference to the I/O handle. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # blocking::block_on(async { + /// let mut listener = Async::::bind("127.0.0.1:0")?; + /// + /// // Accept a new client asynchronously. 
+ /// let (stream, addr) = listener.read_with_mut(|l| l.accept()).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn read_with_mut( + &mut self, + op: impl FnMut(&mut T) -> io::Result, + ) -> io::Result { + let mut op = op; + future::poll_fn(|cx| { + match op(self.get_mut()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + futures_util::ready!(poll_once(cx, self.readable()))?; + Poll::Pending + }) + .await + } + + /// Performs a write operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This function + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is writable. + /// + /// The closure receives a shared reference to the I/O handle. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # blocking::block_on(async { + /// let socket = Async::::bind("127.0.0.1:9000")?; + /// socket.get_ref().connect("127.0.0.1:8000")?; + /// + /// let msg = b"hello"; + /// let len = socket.write_with(|s| s.send(msg)).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn write_with(&self, op: impl FnMut(&T) -> io::Result) -> io::Result { + let mut op = op; + future::poll_fn(|cx| { + match op(self.get_ref()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + futures_util::ready!(poll_once(cx, self.writable()))?; + Poll::Pending + }) + .await + } + + /// Performs a write operation asynchronously. + /// + /// The I/O handle is registered in the reactor and put in non-blocking mode. This function + /// invokes the `op` closure in a loop until it succeeds or returns an error other than + /// [`io::ErrorKind::WouldBlock`]. 
In between iterations of the loop, it waits until the OS + /// sends a notification that the I/O handle is writable. + /// + /// The closure receives a mutable reference to the I/O handle. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # blocking::block_on(async { + /// let mut socket = Async::::bind("127.0.0.1:9000")?; + /// socket.get_ref().connect("127.0.0.1:8000")?; + /// + /// let msg = b"hello"; + /// let len = socket.write_with_mut(|s| s.send(msg)).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn write_with_mut( + &mut self, + op: impl FnMut(&mut T) -> io::Result, + ) -> io::Result { + let mut op = op; + future::poll_fn(|cx| { + match op(self.get_mut()) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), + } + futures_util::ready!(poll_once(cx, self.writable()))?; + Poll::Pending + }) + .await + } +} + +impl Drop for Async { + fn drop(&mut self) { + if self.io.is_some() { + // Deregister and ignore errors because destructors should not panic. + let _ = Reactor::get().remove_io(&self.source); + + // Drop the I/O handle to close it. 
+ self.io.take(); + } + } +} + +impl AsyncRead for Async { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + poll_once(cx, self.read_with_mut(|io| io.read(buf))) + } + + fn poll_read_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { + poll_once(cx, self.read_with_mut(|io| io.read_vectored(bufs))) + } +} + +impl AsyncRead for &Async +where + for<'a> &'a T: Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + poll_once(cx, self.read_with(|io| (&*io).read(buf))) + } + + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { + poll_once(cx, self.read_with(|io| (&*io).read_vectored(bufs))) + } +} + +impl AsyncWrite for Async { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + poll_once(cx, self.write_with_mut(|io| io.write(buf))) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + poll_once(cx, self.write_with_mut(|io| io.write_vectored(bufs))) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + poll_once(cx, self.write_with_mut(|io| io.flush())) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } +} + +impl AsyncWrite for &Async +where + for<'a> &'a T: Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + poll_once(cx, self.write_with(|io| (&*io).write(buf))) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + poll_once(cx, self.write_with(|io| (&*io).write_vectored(bufs))) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + poll_once(cx, self.write_with(|io| (&*io).flush())) + } + + fn 
poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } +} + +impl Async { + /// Creates a TCP listener bound to the specified address. + /// + /// Binding with port number 0 will request an available port from the OS. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # blocking::block_on(async { + /// let listener = Async::::bind("127.0.0.1:0")?; + /// println!("Listening on {}", listener.get_ref().local_addr()?); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn bind(addr: A) -> io::Result> { + let addr = addr + .to_string() + .parse::() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + Ok(Async::new(TcpListener::bind(addr)?)?) + } + + /// Accepts a new incoming TCP connection. + /// + /// When a connection is established, it will be returned as a TCP stream together with its + /// remote address. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpListener; + /// + /// # blocking::block_on(async { + /// let listener = Async::::bind("127.0.0.1:0")?; + /// let (stream, addr) = listener.accept().await?; + /// println!("Accepted client: {}", addr); + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn accept(&self) -> io::Result<(Async, SocketAddr)> { + let (stream, addr) = self.read_with(|io| io.accept()).await?; + Ok((Async::new(stream)?, addr)) + } + + /// Returns a stream of incoming TCP connections. + /// + /// The stream is infinite, i.e. it never stops with a [`None`] item. 
+ /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use futures::prelude::*; + /// use std::net::TcpListener; + /// + /// # blocking::block_on(async { + /// let listener = Async::::bind("127.0.0.1:0")?; + /// let mut incoming = listener.incoming(); + /// + /// while let Some(stream) = incoming.next().await { + /// let stream = stream?; + /// println!("Accepted client: {}", stream.get_ref().peer_addr()?); + /// } + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn incoming(&self) -> impl Stream>> + Send + Unpin + '_ { + Box::pin(stream::unfold(self, |listener| async move { + let res = listener.accept().await.map(|(stream, _)| stream); + Some((res, listener)) + })) + } +} + +impl Async { + /// Creates a TCP connection to the specified address. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpStream; + /// + /// # blocking::block_on(async { + /// let stream = Async::::connect("example.com:80").await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn connect(addr: A) -> io::Result> { + let addr = addr.to_string(); + let addr = blocking::unblock(move || { + addr.to_socket_addrs()?.next().ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "could not resolve the address") + }) + }) + .await?; + + // Create a socket. + let domain = if addr.is_ipv6() { + Domain::ipv6() + } else { + Domain::ipv4() + }; + let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))?; + + // Begin async connect and ignore the inevitable "in progress" error. + socket.set_nonblocking(true)?; + socket.connect(&addr.into()).or_else(|err| { + // Check for EINPROGRESS on Unix and WSAEWOULDBLOCK on Windows. + #[cfg(unix)] + let in_progress = err.raw_os_error() == Some(libc::EINPROGRESS); + #[cfg(windows)] + let in_progress = err.kind() == io::ErrorKind::WouldBlock; + + // If connect results with an "in progress" error, that's not an error. 
+ if in_progress { + Ok(()) + } else { + Err(err) + } + })?; + let stream = Async::new(socket.into_tcp_stream())?; + + // The stream becomes writable when connected. + stream.writable().await?; + + // Check if there was an error while connecting. + match stream.get_ref().take_error()? { + None => Ok(stream), + Some(err) => Err(err), + } + } + + /// Reads data from the stream without removing it from the buffer. + /// + /// Returns the number of bytes read. Successive calls of this method read the same data. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::TcpStream; + /// + /// # blocking::block_on(async { + /// let stream = Async::::connect("127.0.0.1:8080").await?; + /// + /// let mut buf = [0u8; 1024]; + /// let len = stream.peek(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn peek(&self, buf: &mut [u8]) -> io::Result { + self.read_with(|io| io.peek(buf)).await + } +} + +impl Async { + /// Creates a UDP socket bound to the specified address. + /// + /// Binding with port number 0 will request an available port from the OS. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # blocking::block_on(async { + /// let socket = Async::::bind("127.0.0.1:9000")?; + /// println!("Bound to {}", socket.get_ref().local_addr()?); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn bind(addr: A) -> io::Result> { + let addr = addr + .to_string() + .parse::() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + Ok(Async::new(UdpSocket::bind(addr)?)?) + } + + /// Receives a single datagram message. + /// + /// Returns the number of bytes read and the address the message came from. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. 
+ /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # blocking::block_on(async { + /// let socket = Async::::bind("127.0.0.1:9000")?; + /// + /// let mut buf = [0u8; 1024]; + /// let (len, addr) = socket.recv_from(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.read_with(|io| io.recv_from(buf)).await + } + + /// Receives a single datagram message without removing it from the queue. + /// + /// Returns the number of bytes read and the address the message came from. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # blocking::block_on(async { + /// let socket = Async::::bind("127.0.0.1:9000")?; + /// + /// let mut buf = [0u8; 1024]; + /// let (len, addr) = socket.peek_from(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.read_with(|io| io.peek_from(buf)).await + } + + /// Sends data to the specified address. + /// + /// Returns the number of bytes written. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # blocking::block_on(async { + /// let socket = Async::::bind("127.0.0.1:9000")?; + /// + /// let msg = b"hello"; + /// let addr = ([127, 0, 0, 1], 8000); + /// let len = socket.send_to(msg, addr).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn send_to>(&self, buf: &[u8], addr: A) -> io::Result { + let addr = addr.into(); + self.write_with(|io| io.send_to(buf, addr)).await + } + + /// Receives a single datagram message from the connected peer. 
+ /// + /// Returns the number of bytes read. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. + /// + /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # blocking::block_on(async { + /// let socket = Async::::bind("127.0.0.1:9000")?; + /// socket.get_ref().connect("127.0.0.1:8000")?; + /// + /// let mut buf = [0u8; 1024]; + /// let len = socket.recv(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn recv(&self, buf: &mut [u8]) -> io::Result { + self.read_with(|io| io.recv(buf)).await + } + + /// Receives a single datagram message from the connected peer without removing it from the + /// queue. + /// + /// Returns the number of bytes read. + /// + /// This method must be called with a valid byte slice of sufficient size to hold the message. + /// If the message is too long to fit, excess bytes may get discarded. + /// + /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # blocking::block_on(async { + /// let socket = Async::::bind("127.0.0.1:9000")?; + /// socket.get_ref().connect("127.0.0.1:8000")?; + /// + /// let mut buf = [0u8; 1024]; + /// let len = socket.peek(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn peek(&self, buf: &mut [u8]) -> io::Result { + self.read_with(|io| io.peek(buf)).await + } + + /// Sends data to the connected peer. + /// + /// Returns the number of bytes written. 
+ /// + /// The [`connect`][`UdpSocket::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::net::UdpSocket; + /// + /// # blocking::block_on(async { + /// let socket = Async::::bind("127.0.0.1:9000")?; + /// socket.get_ref().connect("127.0.0.1:8000")?; + /// + /// let msg = b"hello"; + /// let len = socket.send(msg).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn send(&self, buf: &[u8]) -> io::Result { + self.write_with(|io| io.send(buf)).await + } +} + +#[cfg(unix)] +impl Async { + /// Creates a UDS listener bound to the specified path. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixListener; + /// + /// # blocking::block_on(async { + /// let listener = Async::::bind("/tmp/socket")?; + /// println!("Listening on {:?}", listener.get_ref().local_addr()?); + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn bind>(path: P) -> io::Result> { + let path = path.as_ref().to_owned(); + Ok(Async::new(UnixListener::bind(path)?)?) + } + + /// Accepts a new incoming UDS stream connection. + /// + /// When a connection is established, it will be returned as a stream together with its remote + /// address. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixListener; + /// + /// # blocking::block_on(async { + /// let listener = Async::::bind("/tmp/socket")?; + /// let (stream, addr) = listener.accept().await?; + /// println!("Accepted client: {:?}", addr); + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn accept(&self) -> io::Result<(Async, UnixSocketAddr)> { + let (stream, addr) = self.read_with(|io| io.accept()).await?; + Ok((Async::new(stream)?, addr)) + } + + /// Returns a stream of incoming UDS connections. + /// + /// The stream is infinite, i.e. 
it never stops with a [`None`] item. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use futures::prelude::*; + /// use std::os::unix::net::UnixListener; + /// + /// # blocking::block_on(async { + /// let listener = Async::::bind("/tmp/socket")?; + /// let mut incoming = listener.incoming(); + /// + /// while let Some(stream) = incoming.next().await { + /// let stream = stream?; + /// println!("Accepted client: {:?}", stream.get_ref().peer_addr()?); + /// } + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn incoming( + &self, + ) -> impl Stream>> + Send + Unpin + '_ { + Box::pin(stream::unfold(self, |listener| async move { + let res = listener.accept().await.map(|(stream, _)| stream); + Some((res, listener)) + })) + } +} + +#[cfg(unix)] +impl Async { + /// Creates a UDS stream connected to the specified path. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixStream; + /// + /// # blocking::block_on(async { + /// let stream = Async::::connect("/tmp/socket").await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn connect>(path: P) -> io::Result> { + // Create a socket. + let socket = Socket::new(Domain::unix(), Type::stream(), None)?; + + // Begin async connect and ignore the inevitable "in progress" error. + socket.set_nonblocking(true)?; + socket + .connect(&socket2::SockAddr::unix(path)?) + .or_else(|err| { + if err.raw_os_error() == Some(libc::EINPROGRESS) { + Ok(()) + } else { + Err(err) + } + })?; + let stream = Async::new(socket.into_unix_stream())?; + + // The stream becomes writable when connected. + stream.writable().await?; + + Ok(stream) + } + + /// Creates an unnamed pair of connected UDS stream sockets.
+ /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixStream; + /// + /// # blocking::block_on(async { + /// let (stream1, stream2) = Async::::pair()?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn pair() -> io::Result<(Async, Async)> { + let (stream1, stream2) = UnixStream::pair()?; + Ok((Async::new(stream1)?, Async::new(stream2)?)) + } +} + +#[cfg(unix)] +impl Async { + /// Creates a UDS datagram socket bound to the specified path. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # blocking::block_on(async { + /// let socket = Async::::bind("/tmp/socket")?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn bind>(path: P) -> io::Result> { + let path = path.as_ref().to_owned(); + Ok(Async::new(UnixDatagram::bind(path)?)?) + } + + /// Creates a UDS datagram socket not bound to any address. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # blocking::block_on(async { + /// let socket = Async::::unbound()?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn unbound() -> io::Result> { + Ok(Async::new(UnixDatagram::unbound()?)?) + } + + /// Creates an unnamed pair of connected Unix datagram sockets. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # blocking::block_on(async { + /// let (socket1, socket2) = Async::::pair()?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub fn pair() -> io::Result<(Async, Async)> { + let (socket1, socket2) = UnixDatagram::pair()?; + Ok((Async::new(socket1)?, Async::new(socket2)?)) + } + + /// Receives data from the socket. + /// + /// Returns the number of bytes read and the address the message came from. 
+ /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # blocking::block_on(async { + /// let socket = Async::::bind("/tmp/socket")?; + /// + /// let mut buf = [0u8; 1024]; + /// let (len, addr) = socket.recv_from(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, UnixSocketAddr)> { + self.read_with(|io| io.recv_from(buf)).await + } + + /// Sends data to the specified address. + /// + /// Returns the number of bytes written. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # blocking::block_on(async { + /// let socket = Async::::unbound()?; + /// + /// let msg = b"hello"; + /// let addr = "/tmp/socket"; + /// let len = socket.send_to(msg, addr).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn send_to>(&self, buf: &[u8], path: P) -> io::Result { + self.write_with(|io| io.send_to(buf, &path)).await + } + + /// Receives data from the connected peer. + /// + /// Returns the number of bytes read. + /// + /// The [`connect`][`UnixDatagram::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # blocking::block_on(async { + /// let socket = Async::::bind("/tmp/socket1")?; + /// socket.get_ref().connect("/tmp/socket2")?; + /// + /// let mut buf = [0u8; 1024]; + /// let len = socket.recv(&mut buf).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn recv(&self, buf: &mut [u8]) -> io::Result { + self.read_with(|io| io.recv(buf)).await + } + + /// Sends data to the connected peer. + /// + /// Returns the number of bytes written.
+ /// + /// The [`connect`][`UnixDatagram::connect()`] method connects this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// # Examples + /// + /// ```no_run + /// use async_io::Async; + /// use std::os::unix::net::UnixDatagram; + /// + /// # blocking::block_on(async { + /// let socket = Async::::bind("/tmp/socket1")?; + /// socket.get_ref().connect("/tmp/socket2")?; + /// + /// let msg = b"hello"; + /// let len = socket.send(msg).await?; + /// # std::io::Result::Ok(()) }); + /// ``` + pub async fn send(&self, buf: &[u8]) -> io::Result { + self.write_with(|io| io.send(buf)).await + } +} + +/// Pins a future and then polls it. +fn poll_once(cx: &mut Context<'_>, fut: impl Future) -> Poll { + futures_util::pin_mut!(fut); + fut.poll(cx) +} diff --git a/src/parking.rs b/src/parking.rs new file mode 100644 index 0000000..cc609ac --- /dev/null +++ b/src/parking.rs @@ -0,0 +1,1182 @@ +//! Thread parking and unparking. +//! +//! This module exposes the same API as [`parking`](https://docs.rs/parking). The only difference +//! is that [`Parker`] in this module will wait on epoll/kqueue/wepoll and wake tasks blocked on +//! I/O or timers. 
+ +#[cfg(not(any( + target_os = "linux", // epoll + target_os = "android", // epoll + target_os = "illumos", // epoll + target_os = "macos", // kqueue + target_os = "ios", // kqueue + target_os = "freebsd", // kqueue + target_os = "netbsd", // kqueue + target_os = "openbsd", // kqueue + target_os = "dragonfly", // kqueue + target_os = "windows", // wepoll +)))] +compile_error!("async-io does not support this target OS"); + +use std::collections::BTreeMap; +use std::fmt; +use std::io; +use std::mem; +#[cfg(unix)] +use std::os::unix::io::RawFd; +#[cfg(windows)] +use std::os::windows::io::RawSocket; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Condvar, Mutex, MutexGuard}; +use std::task::{Poll, Waker}; +use std::thread; +use std::time::{Duration, Instant}; + +use concurrent_queue::ConcurrentQueue; +use futures_util::future; +use once_cell::sync::Lazy; +use slab::Slab; + +static PARKER_COUNT: AtomicUsize = AtomicUsize::new(0); + +/// Parks a thread. +pub struct Parker { + unparker: Unparker, +} + +impl Parker { + /// Creates a new [`Parker`]. + pub fn new() -> Parker { + let parker = Parker { + unparker: Unparker { + inner: Arc::new(Inner { + state: AtomicUsize::new(EMPTY), + lock: Mutex::new(()), + cvar: Condvar::new(), + }), + }, + }; + PARKER_COUNT.fetch_add(1, Ordering::SeqCst); + parker + } + + /// Blocks the current thread until the token is made available. + pub fn park(&self) { + self.unparker.inner.park(None); + } + + /// Blocks the current thread until the token is made available or the timeout is reached. + pub fn park_timeout(&self, timeout: Duration) -> bool { + self.unparker.inner.park(Some(timeout)) + } + + /// Blocks the current thread until the token is made available or the deadline is reached. + pub fn park_deadline(&self, deadline: Instant) -> bool { + self.unparker + .inner + .park(Some(deadline.saturating_duration_since(Instant::now()))) + } + + /// Atomically makes the token available if it is not already. 
+ pub fn unpark(&self) { + self.unparker.unpark() + } + + /// Returns a handle for unparking. + pub fn unparker(&self) -> Unparker { + self.unparker.clone() + } +} + +impl Drop for Parker { + fn drop(&mut self) { + PARKER_COUNT.fetch_sub(1, Ordering::SeqCst); + Reactor::get().thread_unparker.unpark(); + } +} + +impl fmt::Debug for Parker { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("Parker { .. }") + } +} + +/// Unparks a thread. +pub struct Unparker { + inner: Arc, +} + +impl Unparker { + /// Atomically makes the token available if it is not already. + pub fn unpark(&self) { + self.inner.unpark() + } +} + +impl fmt::Debug for Unparker { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("Unparker { .. }") + } +} + +impl Clone for Unparker { + fn clone(&self) -> Unparker { + Unparker { + inner: self.inner.clone(), + } + } +} + +const EMPTY: usize = 0; +const PARKED: usize = 1; +const POLLING: usize = 2; +const NOTIFIED: usize = 3; + +struct Inner { + state: AtomicUsize, + lock: Mutex<()>, + cvar: Condvar, +} + +impl Inner { + fn park(&self, timeout: Option) -> bool { + // If we were previously notified then we consume this notification and return quickly. + if self + .state + .compare_exchange(NOTIFIED, EMPTY, Ordering::SeqCst, Ordering::SeqCst) + .is_ok() + { + // Process available I/O events. + if let Some(reactor_lock) = Reactor::get().try_lock() { + let _ = reactor_lock.react(Some(Duration::from_secs(0))); + } + return true; + } + + // If the timeout is zero, then there is no need to actually block. + if let Some(dur) = timeout { + if dur == Duration::from_millis(0) { + // Process available I/O events. + if let Some(reactor_lock) = Reactor::get().try_lock() { + let _ = reactor_lock.react(Some(Duration::from_secs(0))); + } + return false; + } + } + + // Otherwise, we need to coordinate going to sleep. 
+ let deadline = timeout.map(|t| Instant::now() + t); + loop { + let reactor_lock = Reactor::get().try_lock(); + + let state = match reactor_lock { + None => PARKED, + Some(_) => POLLING, + }; + let mut m = self.lock.lock().unwrap(); + + match self + .state + .compare_exchange(EMPTY, state, Ordering::SeqCst, Ordering::SeqCst) + { + Ok(_) => {} + // Consume this notification to avoid spurious wakeups in the next park. + Err(NOTIFIED) => { + // We must read `state` here, even though we know it will be `NOTIFIED`. This is + // because `unpark` may have been called again since we read `NOTIFIED` in the + // `compare_exchange` above. We must perform an acquire operation that synchronizes + // with that `unpark` to observe any writes it made before the call to `unpark`. To + // do that we must read from the write it made to `state`. + let old = self.state.swap(EMPTY, Ordering::SeqCst); + assert_eq!(old, NOTIFIED, "park state changed unexpectedly"); + return true; + } + Err(n) => panic!("inconsistent park_timeout state: {}", n), + } + + match deadline { + None => { + // Block the current thread on the conditional variable. + match reactor_lock { + None => m = self.cvar.wait(m).unwrap(), + Some(reactor_lock) => { + drop(m); + + let _ = reactor_lock.react(None); + + m = self.lock.lock().unwrap(); + } + } + + match self.state.swap(EMPTY, Ordering::SeqCst) { + NOTIFIED => return true, // got a notification + PARKED | POLLING => {} // spurious wakeup + n => panic!("inconsistent state: {}", n), + } + } + Some(deadline) => { + // Wait with a timeout, and if we spuriously wake up or otherwise wake up from a + // notification we just want to unconditionally set `state` back to `EMPTY`, either + // consuming a notification or un-flagging ourselves as parked. 
+ let timeout = deadline.saturating_duration_since(Instant::now()); + + m = match reactor_lock { + None => self.cvar.wait_timeout(m, timeout).unwrap().0, + Some(reactor_lock) => { + drop(m); + let _ = reactor_lock.react(Some(timeout)); + self.lock.lock().unwrap() + } + }; + + match self.state.swap(EMPTY, Ordering::SeqCst) { + NOTIFIED => return true, // got a notification + PARKED | POLLING => {} // no notification + n => panic!("inconsistent state: {}", n), + } + + if Instant::now() >= deadline { + return false; + } + } + } + + drop(m); + } + } + + pub fn unpark(&self) { + // To ensure the unparked thread will observe any writes we made before this call, we must + // perform a release operation that `park` can synchronize with. To do that we must write + // `NOTIFIED` even if `state` is already `NOTIFIED`. That is why this must be a swap rather + // than a compare-and-swap that returns if it reads `NOTIFIED` on failure. + let state = match self.state.swap(NOTIFIED, Ordering::SeqCst) { + EMPTY => return, // no one was waiting + NOTIFIED => return, // already unparked + state => state, // gotta go wake someone up + }; + + // There is a period between when the parked thread sets `state` to `PARKED` (or last + // checked `state` in the case of a spurious wakeup) and when it actually waits on `cvar`. + // If we were to notify during this period it would be ignored and then when the parked + // thread went to sleep it would never wake up. Fortunately, it has `lock` locked at this + // stage so we can acquire `lock` to wait until it is ready to receive the notification. + // + // Releasing `lock` before the call to `notify_one` means that when the parked thread wakes + // it doesn't get woken only to have to wait for us to release `lock`. + drop(self.lock.lock().unwrap()); + + if state == PARKED { + self.cvar.notify_one(); + } else { + Reactor::get().notify(); + } + } +} + +/// The reactor. +/// +/// Every async I/O handle and every timer is registered here. 
Invocations of +/// [`run()`][`crate::run()`] poll the reactor to check for new events every now and then. +/// +/// There is only one global instance of this type, accessible by [`Reactor::get()`]. +pub(crate) struct Reactor { + thread_unparker: parking::Unparker, + + /// Raw bindings to epoll/kqueue/wepoll. + sys: sys::Reactor, + + /// Ticker bumped before polling. + ticker: AtomicUsize, + + /// Registered sources. + sources: Mutex>>, + + /// Temporary storage for I/O events when polling the reactor. + events: Mutex, + + /// An ordered map of registered timers. + /// + /// Timers are in the order in which they fire. The `usize` in this type is a timer ID used to + /// distinguish timers that fire at the same time. The `Waker` represents the task awaiting the + /// timer. + timers: Mutex>, + + /// A queue of timer operations (insert and remove). + /// + /// When inserting or removing a timer, we don't process it immediately - we just push it into + /// this queue. Timers actually get processed when the queue fills up or the reactor is polled. + timer_ops: ConcurrentQueue, +} + +impl Reactor { + /// Returns a reference to the reactor. 
+ pub(crate) fn get() -> &'static Reactor { + static REACTOR: Lazy = Lazy::new(|| { + let (parker, unparker) = parking::pair(); + + thread::Builder::new() + .name("async-io".to_string()) + .spawn(move || { + let reactor = Reactor::get(); + let mut sleeps = 0u64; + let mut last_tick = 0; + + loop { + let tick = reactor.ticker.load(Ordering::SeqCst); + + if last_tick == tick { + let reactor_lock = if sleeps >= 60 { + Some(reactor.lock()) + } else { + reactor.try_lock() + }; + + if let Some(reactor_lock) = reactor_lock { + let _ = reactor_lock.react(None); + last_tick = reactor.ticker.load(Ordering::SeqCst); + } + + sleeps = 0; + } else { + last_tick = tick; + sleeps += 1; + } + + if PARKER_COUNT.load(Ordering::SeqCst) == 0 { + sleeps = 0; + } else { + let delay_us = if sleeps < 50 { + 20 + } else { + 20 << (sleeps - 50).min(9) + }; + + if parker.park_timeout(Duration::from_micros(delay_us)) { + sleeps = 0; + } + } + } + }) + .expect("cannot spawn async-io thread"); + + Reactor { + thread_unparker: unparker, + sys: sys::Reactor::new().expect("cannot initialize I/O event notification"), + ticker: AtomicUsize::new(0), + sources: Mutex::new(Slab::new()), + events: Mutex::new(sys::Events::new()), + timers: Mutex::new(BTreeMap::new()), + timer_ops: ConcurrentQueue::bounded(1000), + } + }); + &REACTOR + } + + /// Notifies the thread blocked on the reactor. + pub(crate) fn notify(&self) { + self.sys.notify().expect("failed to notify reactor"); + } + + /// Registers an I/O source in the reactor. + pub(crate) fn insert_io( + &self, + #[cfg(unix)] raw: RawFd, + #[cfg(windows)] raw: RawSocket, + ) -> io::Result> { + let mut sources = self.sources.lock().unwrap(); + let vacant = sources.vacant_entry(); + + // Create a source and register it. 
+ let key = vacant.key(); + self.sys.register(raw, key)?; + + let source = Arc::new(Source { + raw, + key, + wakers: Mutex::new(Wakers { + tick_readable: 0, + tick_writable: 0, + readers: Vec::new(), + writers: Vec::new(), + }), + }); + Ok(vacant.insert(source).clone()) + } + + /// Deregisters an I/O source from the reactor. + pub(crate) fn remove_io(&self, source: &Source) -> io::Result<()> { + let mut sources = self.sources.lock().unwrap(); + sources.remove(source.key); + self.sys.deregister(source.raw) + } + + /// Registers a timer in the reactor. + /// + /// Returns the inserted timer's ID. + pub(crate) fn insert_timer(&self, when: Instant, waker: &Waker) -> usize { + // Generate a new timer ID. + static ID_GENERATOR: AtomicUsize = AtomicUsize::new(1); + let id = ID_GENERATOR.fetch_add(1, Ordering::Relaxed); + + // Push an insert operation. + while self + .timer_ops + .push(TimerOp::Insert(when, id, waker.clone())) + .is_err() + { + // If the queue is full, drain it and try again. + let mut timers = self.timers.lock().unwrap(); + self.process_timer_ops(&mut timers); + } + + // Notify that a timer has been inserted. + self.notify(); + + id + } + + /// Deregisters a timer from the reactor. + pub(crate) fn remove_timer(&self, when: Instant, id: usize) { + // Push a remove operation. + while self.timer_ops.push(TimerOp::Remove(when, id)).is_err() { + // If the queue is full, drain it and try again. + let mut timers = self.timers.lock().unwrap(); + self.process_timer_ops(&mut timers); + } + } + + /// Locks the reactor, potentially blocking if the lock is held by another thread. + fn lock(&self) -> ReactorLock<'_> { + let reactor = self; + let events = self.events.lock().unwrap(); + ReactorLock { reactor, events } + } + + /// Attempts to lock the reactor. + fn try_lock(&self) -> Option> { + self.events.try_lock().ok().map(|events| { + let reactor = self; + ReactorLock { reactor, events } + }) + } + + /// Processes ready timers and extends the list of wakers to wake. 
+ /// + /// Returns the duration until the next timer before this method was called. + fn process_timers(&self, wakers: &mut Vec) -> Option { + let mut timers = self.timers.lock().unwrap(); + self.process_timer_ops(&mut timers); + + let now = Instant::now(); + + // Split timers into ready and pending timers. + let pending = timers.split_off(&(now, 0)); + let ready = mem::replace(&mut *timers, pending); + + // Calculate the duration until the next event. + let dur = if ready.is_empty() { + // Duration until the next timer. + timers + .keys() + .next() + .map(|(when, _)| when.saturating_duration_since(now)) + } else { + // Timers are about to fire right now. + Some(Duration::from_secs(0)) + }; + + // Drop the lock before waking. + drop(timers); + + // Add wakers to the list. + for (_, waker) in ready { + wakers.push(waker); + } + + dur + } + + /// Processes queued timer operations. + fn process_timer_ops(&self, timers: &mut MutexGuard<'_, BTreeMap<(Instant, usize), Waker>>) { + // Process only as much as fits into the queue, or else this loop could in theory run + // forever. + for _ in 0..self.timer_ops.capacity().unwrap() { + match self.timer_ops.pop() { + Ok(TimerOp::Insert(when, id, waker)) => { + timers.insert((when, id), waker); + } + Ok(TimerOp::Remove(when, id)) => { + timers.remove(&(when, id)); + } + Err(_) => break, + } + } + } +} + +/// A lock on the reactor. +struct ReactorLock<'a> { + reactor: &'a Reactor, + events: MutexGuard<'a, sys::Events>, +} + +impl ReactorLock<'_> { + /// Processes new events, blocking until the first event or the timeout. + fn react(mut self, timeout: Option) -> io::Result<()> { + let mut wakers = Vec::new(); + + // Process ready timers. + let next_timer = self.reactor.process_timers(&mut wakers); + + // compute the timeout for blocking on I/O events. 
+ let timeout = match (next_timer, timeout) { + (None, None) => None, + (Some(t), None) | (None, Some(t)) => Some(t), + (Some(a), Some(b)) => Some(a.min(b)), + }; + + // Bump the ticker before polling I/O. + let tick = self + .reactor + .ticker + .fetch_add(1, Ordering::SeqCst) + .wrapping_add(1); + + // Block on I/O events. + let res = match self.reactor.sys.wait(&mut self.events, timeout) { + // No I/O events occurred. + Ok(0) => { + if timeout != Some(Duration::from_secs(0)) { + // The non-zero timeout was hit so fire ready timers. + self.reactor.process_timers(&mut wakers); + } + Ok(()) + } + + // At least one I/O event occurred. + Ok(_) => { + // Iterate over sources in the event list. + let sources = self.reactor.sources.lock().unwrap(); + + // First re-registration failure, if any. It is returned only after every event + // has been handled, so the wakers collected so far are still woken below + // instead of being dropped by an early return. + let mut outcome: io::Result<()> = Ok(()); + + for ev in self.events.iter() { + // Check if there is a source in the table with this key. + if let Some(source) = sources.get(ev.key) { + let mut w = source.wakers.lock().unwrap(); + + // Wake readers if a readability event was emitted. + if ev.readable { + w.tick_readable = tick; + wakers.append(&mut w.readers); + } + + // Wake writers if a writability event was emitted. + if ev.writable { + w.tick_writable = tick; + wakers.append(&mut w.writers); + } + + // Re-register if there are still writers or + // readers. This can happen if e.g. we were + // previously interested in both readability and + // writability, but only one of them was emitted. + if !(w.writers.is_empty() && w.readers.is_empty()) { + let rereg = self.reactor.sys.reregister( + source.raw, + source.key, + !w.readers.is_empty(), + !w.writers.is_empty(), + ); + // Keep the earliest error; later successes must not overwrite it. + outcome = outcome.and(rereg); + } + } + } + + outcome + } + + // The syscall was interrupted. + Err(err) if err.kind() == io::ErrorKind::Interrupted => Ok(()), + + // An actual error occurred. + Err(err) => Err(err), + }; + + // Drop the lock before waking. + drop(self); + + // Wake up ready tasks. + for waker in wakers { + waker.wake(); + } + + res + } +} + +/// A single timer operation.
+enum TimerOp { + Insert(Instant, usize, Waker), + Remove(Instant, usize), +} + +/// A registered source of I/O events. +#[derive(Debug)] +pub(crate) struct Source { + /// Raw file descriptor on Unix platforms. + #[cfg(unix)] + pub(crate) raw: RawFd, + + /// Raw socket handle on Windows. + #[cfg(windows)] + pub(crate) raw: RawSocket, + + /// The key of this source obtained during registration. + key: usize, + + /// Tasks interested in events on this source. + wakers: Mutex, +} + +/// Tasks interested in events on a source. +#[derive(Debug)] +struct Wakers { + /// Last reactor tick that delivered a readability event. + tick_readable: usize, + + /// Last reactor tick that delivered a writability event. + tick_writable: usize, + + /// Tasks waiting for the next readability event. + readers: Vec, + + /// Tasks waiting for the next writability event. + writers: Vec, +} + +impl Source { + /// Waits until the I/O source is readable. + pub(crate) async fn readable(&self) -> io::Result<()> { + let mut ticks = None; + + future::poll_fn(|cx| { + let mut w = self.wakers.lock().unwrap(); + + // Check if the reactor has delivered a readability event. + if let Some((a, b)) = ticks { + // If `tick_readable` has changed to a value other than the old reactor tick, that + // means a newer reactor tick has delivered a readability event. + if w.tick_readable != a && w.tick_readable != b { + return Poll::Ready(Ok(())); + } + } + + // If there are no other readers, re-register in the reactor. + if w.readers.is_empty() { + Reactor::get() + .sys + .reregister(self.raw, self.key, true, !w.writers.is_empty())?; + } + + // Register the current task's waker if not present already. + if w.readers.iter().all(|w| !w.will_wake(cx.waker())) { + w.readers.push(cx.waker().clone()); + } + + // Remember the current ticks. 
+ if ticks.is_none() { + ticks = Some(( + Reactor::get().ticker.load(Ordering::SeqCst), + w.tick_readable, + )); + } + + Poll::Pending + }) + .await + } + + /// Waits until the I/O source is writable. + pub(crate) async fn writable(&self) -> io::Result<()> { + let mut ticks = None; + + future::poll_fn(|cx| { + let mut w = self.wakers.lock().unwrap(); + + // Check if the reactor has delivered a writability event. + if let Some((a, b)) = ticks { + // If `tick_writable` has changed to a value other than the old reactor tick, that + // means a newer reactor tick has delivered a writability event. + if w.tick_writable != a && w.tick_writable != b { + return Poll::Ready(Ok(())); + } + } + + // If there are no other writers, re-register in the reactor. + if w.writers.is_empty() { + Reactor::get() + .sys + .reregister(self.raw, self.key, !w.readers.is_empty(), true)?; + } + + // Register the current task's waker if not present already. + if w.writers.iter().all(|w| !w.will_wake(cx.waker())) { + w.writers.push(cx.waker().clone()); + } + + // Remember the current ticks. + if ticks.is_none() { + ticks = Some(( + Reactor::get().ticker.load(Ordering::SeqCst), + w.tick_writable, + )); + } + + Poll::Pending + }) + .await + } +} + +/// Raw bindings to epoll (Linux, Android, illumos). +#[cfg(any(target_os = "linux", target_os = "android", target_os = "illumos"))] +mod sys { + use std::convert::TryInto; + use std::io; + use std::os::unix::io::RawFd; + use std::time::Duration; + + use crate::sys::epoll::{ + epoll_create1, epoll_ctl, epoll_wait, EpollEvent, EpollFlags, EpollOp, + }; + + macro_rules! 
syscall { + ($fn:ident $args:tt) => {{ + let res = unsafe { libc::$fn $args }; + if res == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(res) + } + }}; + } + + pub struct Reactor { + epoll_fd: RawFd, + event_fd: RawFd, + } + impl Reactor { + pub fn new() -> io::Result { + let epoll_fd = epoll_create1()?; + let event_fd = syscall!(eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?; + let reactor = Reactor { epoll_fd, event_fd }; + reactor.register(event_fd, !0)?; + reactor.reregister(event_fd, !0, true, false)?; + Ok(reactor) + } + pub fn register(&self, fd: RawFd, key: usize) -> io::Result<()> { + let flags = syscall!(fcntl(fd, libc::F_GETFL))?; + syscall!(fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK))?; + let ev = &mut EpollEvent::new(0, key as u64); + epoll_ctl(self.epoll_fd, EpollOp::EpollCtlAdd, fd, Some(ev)) + } + pub fn reregister(&self, fd: RawFd, key: usize, read: bool, write: bool) -> io::Result<()> { + let mut flags = libc::EPOLLONESHOT; + if read { + flags |= read_flags(); + } + if write { + flags |= write_flags(); + } + let ev = &mut EpollEvent::new(flags, key as u64); + epoll_ctl(self.epoll_fd, EpollOp::EpollCtlMod, fd, Some(ev)) + } + pub fn deregister(&self, fd: RawFd) -> io::Result<()> { + epoll_ctl(self.epoll_fd, EpollOp::EpollCtlDel, fd, None) + } + pub fn wait(&self, events: &mut Events, timeout: Option) -> io::Result { + let timeout_ms = timeout + .map(|t| { + if t == Duration::from_millis(0) { + t + } else { + t.max(Duration::from_millis(1)) + } + }) + .and_then(|t| t.as_millis().try_into().ok()) + .unwrap_or(-1); + events.len = epoll_wait(self.epoll_fd, &mut events.list, timeout_ms)?; + + let mut buf = [0u8; 8]; + let _ = syscall!(read( + self.event_fd, + &mut buf[0] as *mut u8 as *mut libc::c_void, + buf.len() + )); + self.reregister(self.event_fd, !0, true, false)?; + + Ok(events.len) + } + pub fn notify(&self) -> io::Result<()> { + let buf: [u8; 8] = 1u64.to_ne_bytes(); + let _ = syscall!(write( + self.event_fd, + 
&buf[0] as *const u8 as *const libc::c_void, + buf.len() + )); + Ok(()) + } + } + fn read_flags() -> EpollFlags { + libc::EPOLLIN | libc::EPOLLRDHUP | libc::EPOLLHUP | libc::EPOLLERR | libc::EPOLLPRI + } + fn write_flags() -> EpollFlags { + libc::EPOLLOUT | libc::EPOLLHUP | libc::EPOLLERR + } + + pub struct Events { + list: Box<[EpollEvent]>, + len: usize, + } + impl Events { + pub fn new() -> Events { + let list = vec![EpollEvent::empty(); 1000].into_boxed_slice(); + let len = 0; + Events { list, len } + } + pub fn iter(&self) -> impl Iterator + '_ { + self.list[..self.len].iter().map(|ev| Event { + readable: (ev.events() & read_flags()) != 0, + writable: (ev.events() & write_flags()) != 0, + key: ev.data() as usize, + }) + } + } + pub struct Event { + pub readable: bool, + pub writable: bool, + pub key: usize, + } +} + +/// Raw bindings to kqueue (macOS, iOS, FreeBSD, NetBSD, OpenBSD, DragonFly BSD). +#[cfg(any( + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "netbsd", + target_os = "openbsd", + target_os = "dragonfly", +))] +mod sys { + use std::io::{self, Read, Write}; + use std::os::unix::io::{AsRawFd, RawFd}; + use std::os::unix::net::UnixStream; + use std::time::Duration; + + use crate::sys::event::{kevent_ts, kqueue, KEvent}; + + macro_rules! 
syscall { + ($fn:ident $args:tt) => {{ + let res = unsafe { libc::$fn $args }; + if res == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(res) + } + }}; + } + + pub struct Reactor { + kqueue_fd: RawFd, + read_stream: UnixStream, + write_stream: UnixStream, + } + impl Reactor { + pub fn new() -> io::Result { + let kqueue_fd = kqueue()?; + syscall!(fcntl(kqueue_fd, libc::F_SETFD, libc::FD_CLOEXEC))?; + let (read_stream, write_stream) = UnixStream::pair()?; + read_stream.set_nonblocking(true)?; + write_stream.set_nonblocking(true)?; + let reactor = Reactor { + kqueue_fd, + read_stream, + write_stream, + }; + reactor.reregister(reactor.read_stream.as_raw_fd(), !0, true, false)?; + Ok(reactor) + } + pub fn register(&self, fd: RawFd, _key: usize) -> io::Result<()> { + let flags = syscall!(fcntl(fd, libc::F_GETFL))?; + syscall!(fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK))?; + Ok(()) + } + pub fn reregister(&self, fd: RawFd, key: usize, read: bool, write: bool) -> io::Result<()> { + let mut read_flags = libc::EV_ONESHOT | libc::EV_RECEIPT; + let mut write_flags = libc::EV_ONESHOT | libc::EV_RECEIPT; + if read { + read_flags |= libc::EV_ADD; + } else { + read_flags |= libc::EV_DELETE; + } + if write { + write_flags |= libc::EV_ADD; + } else { + write_flags |= libc::EV_DELETE; + } + let udata = key as _; + let changelist = [ + KEvent::new(fd as _, libc::EVFILT_READ, read_flags, 0, 0, udata), + KEvent::new(fd as _, libc::EVFILT_WRITE, write_flags, 0, 0, udata), + ]; + let mut eventlist = changelist; + kevent_ts(self.kqueue_fd, &changelist, &mut eventlist, None)?; + for ev in &eventlist { + // Explanation for ignoring EPIPE: https://github.com/tokio-rs/mio/issues/582 + let (flags, data) = (ev.flags(), ev.data()); + if (flags & libc::EV_ERROR) == 1 + && data != 0 + && data != libc::ENOENT as _ + && data != libc::EPIPE as _ + { + return Err(io::Error::from_raw_os_error(data as _)); + } + } + Ok(()) + } + pub fn deregister(&self, fd: RawFd) -> io::Result<()> { 
+ let flags = libc::EV_DELETE | libc::EV_RECEIPT; + let changelist = [ + KEvent::new(fd as _, libc::EVFILT_WRITE, flags, 0, 0, 0), + KEvent::new(fd as _, libc::EVFILT_READ, flags, 0, 0, 0), + ]; + let mut eventlist = changelist; + kevent_ts(self.kqueue_fd, &changelist, &mut eventlist, None)?; + for ev in &eventlist { + let (flags, data) = (ev.flags(), ev.data()); + if (flags & libc::EV_ERROR == 1) && data != 0 && data != libc::ENOENT as _ { + return Err(io::Error::from_raw_os_error(data as _)); + } + } + Ok(()) + } + pub fn wait(&self, events: &mut Events, timeout: Option) -> io::Result { + let timeout = timeout.map(|t| libc::timespec { + tv_sec: t.as_secs() as libc::time_t, + tv_nsec: t.subsec_nanos() as libc::c_long, + }); + events.len = kevent_ts(self.kqueue_fd, &[], &mut events.list, timeout)?; + + while (&self.read_stream).read(&mut [0; 64]).is_ok() {} + self.reregister(self.read_stream.as_raw_fd(), !0, true, false)?; + + Ok(events.len) + } + pub fn notify(&self) -> io::Result<()> { + let _ = (&self.write_stream).write(&[1]); + Ok(()) + } + } + + pub struct Events { + list: Box<[KEvent]>, + len: usize, + } + impl Events { + pub fn new() -> Events { + let flags = 0; + let event = KEvent::new(0, 0, flags, 0, 0, 0); + let list = vec![event; 1000].into_boxed_slice(); + let len = 0; + Events { list, len } + } + pub fn iter(&self) -> impl Iterator + '_ { + // On some platforms, closing the read end of a pipe wakes up writers, but the + // event is reported as EVFILT_READ with the EV_EOF flag. + // + // https://github.com/golang/go/commit/23aad448b1e3f7c3b4ba2af90120bde91ac865b4 + self.list[..self.len].iter().map(|ev| Event { + readable: ev.filter() == libc::EVFILT_READ, + writable: ev.filter() == libc::EVFILT_WRITE + || (ev.filter() == libc::EVFILT_READ && (ev.flags() & libc::EV_EOF) != 0), + key: ev.udata() as usize, + }) + } + } + pub struct Event { + pub readable: bool, + pub writable: bool, + pub key: usize, + } +} + +/// Raw bindings to wepoll (Windows). 
+#[cfg(target_os = "windows")] +mod sys { + use std::convert::TryInto; + use std::io; + use std::os::windows::io::{AsRawSocket, RawSocket}; + use std::time::Duration; + + use wepoll_sys_stjepang as we; + use winapi::um::winsock2; + + macro_rules! syscall { + ($fn:ident $args:tt) => {{ + let res = unsafe { we::$fn $args }; + if res == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(res) + } + }}; + } + + pub struct Reactor { + handle: we::HANDLE, + } + unsafe impl Send for Reactor {} + unsafe impl Sync for Reactor {} + impl Reactor { + pub fn new() -> io::Result { + let handle = unsafe { we::epoll_create1(0) }; + if handle.is_null() { + return Err(io::Error::last_os_error()); + } + Ok(Reactor { handle }) + } + pub fn register(&self, sock: RawSocket, key: usize) -> io::Result<()> { + unsafe { + let mut nonblocking = true as libc::c_ulong; + let res = winsock2::ioctlsocket( + sock as winsock2::SOCKET, + winsock2::FIONBIO, + &mut nonblocking, + ); + if res != 0 { + return Err(io::Error::last_os_error()); + } + } + let mut ev = we::epoll_event { + events: 0, + data: we::epoll_data { u64: key as u64 }, + }; + syscall!(epoll_ctl( + self.handle, + we::EPOLL_CTL_ADD as libc::c_int, + sock as we::SOCKET, + &mut ev, + ))?; + Ok(()) + } + pub fn reregister( + &self, + sock: RawSocket, + key: usize, + read: bool, + write: bool, + ) -> io::Result<()> { + let mut flags = we::EPOLLONESHOT; + if read { + flags |= READ_FLAGS; + } + if write { + flags |= WRITE_FLAGS; + } + let mut ev = we::epoll_event { + events: flags as u32, + data: we::epoll_data { u64: key as u64 }, + }; + syscall!(epoll_ctl( + self.handle, + we::EPOLL_CTL_MOD as libc::c_int, + sock as we::SOCKET, + &mut ev, + ))?; + Ok(()) + } + pub fn deregister(&self, sock: RawSocket) -> io::Result<()> { + syscall!(epoll_ctl( + self.handle, + we::EPOLL_CTL_DEL as libc::c_int, + sock as we::SOCKET, + 0 as *mut we::epoll_event, + ))?; + Ok(()) + } + pub fn wait(&self, events: &mut Events, timeout: Option) -> 
io::Result { + let timeout_ms = match timeout { + None => -1, + Some(t) => { + if t == Duration::from_millis(0) { + 0 + } else { + t.max(Duration::from_millis(1)) + .as_millis() + .try_into() + .unwrap_or(libc::c_int::max_value()) + } + } + }; + events.len = syscall!(epoll_wait( + self.handle, + events.list.as_mut_ptr(), + events.list.len() as libc::c_int, + timeout_ms, + ))? as usize; + Ok(events.len) + } + pub fn notify(&self) -> io::Result<()> { + unsafe { + // This errors if a notification has already been posted, but that's okay. + winapi::um::ioapiset::PostQueuedCompletionStatus( + self.handle as winapi::um::winnt::HANDLE, + 0, + 0, + 0 as *mut _, + ); + } + Ok(()) + } + } + struct As(RawSocket); + impl AsRawSocket for As { + fn as_raw_socket(&self) -> RawSocket { + self.0 + } + } + const READ_FLAGS: u32 = + we::EPOLLIN | we::EPOLLRDHUP | we::EPOLLHUP | we::EPOLLERR | we::EPOLLPRI; + const WRITE_FLAGS: u32 = we::EPOLLOUT | we::EPOLLHUP | we::EPOLLERR; + + pub struct Events { + list: Box<[we::epoll_event]>, + len: usize, + } + unsafe impl Send for Events {} + unsafe impl Sync for Events {} + impl Events { + pub fn new() -> Events { + let ev = we::epoll_event { + events: 0, + data: we::epoll_data { u64: 0 }, + }; + Events { + list: vec![ev; 1000].into_boxed_slice(), + len: 0, + } + } + pub fn iter(&self) -> impl Iterator + '_ { + self.list[..self.len].iter().map(|ev| Event { + readable: (ev.events & READ_FLAGS) != 0, + writable: (ev.events & WRITE_FLAGS) != 0, + key: unsafe { ev.data.u64 } as usize, + }) + } + } + pub struct Event { + pub readable: bool, + pub writable: bool, + pub key: usize, + } +} diff --git a/src/sys.rs b/src/sys.rs new file mode 100644 index 0000000..336ae90 --- /dev/null +++ b/src/sys.rs @@ -0,0 +1,277 @@ +#[cfg(unix)] +fn check_err(res: libc::c_int) -> Result { + if res == -1 { + return Err(std::io::Error::last_os_error()); + } + + Ok(res) +} + +#[cfg(any( + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = 
"netbsd", + target_os = "openbsd", + target_os = "dragonfly", +))] +/// Kqueue. +pub mod event { + use super::check_err; + use std::os::unix::io::RawFd; + + #[cfg(any( + target_os = "macos", + target_os = "ios", + target_os = "freebsd", + target_os = "dragonfly", + target_os = "openbsd" + ))] + #[allow(non_camel_case_types)] + type type_of_nchanges = libc::c_int; + #[cfg(target_os = "netbsd")] + #[allow(non_camel_case_types)] + type type_of_nchanges = libc::size_t; + + #[cfg(target_os = "netbsd")] + #[allow(non_camel_case_types)] + type type_of_event_filter = u32; + #[cfg(not(target_os = "netbsd"))] + #[allow(non_camel_case_types)] + type type_of_event_filter = i16; + + #[cfg(any( + target_os = "dragonfly", + target_os = "freebsd", + target_os = "ios", + target_os = "macos", + target_os = "openbsd" + ))] + #[allow(non_camel_case_types)] + type type_of_udata = *mut libc::c_void; + #[cfg(any( + target_os = "dragonfly", + target_os = "freebsd", + target_os = "ios", + target_os = "macos" + ))] + #[allow(non_camel_case_types)] + type type_of_data = libc::intptr_t; + #[cfg(any(target_os = "netbsd"))] + #[allow(non_camel_case_types)] + type type_of_udata = libc::intptr_t; + #[cfg(any(target_os = "netbsd", target_os = "openbsd"))] + #[allow(non_camel_case_types)] + type type_of_data = libc::int64_t; + + #[derive(Clone, Copy)] + #[repr(C)] + pub struct KEvent(libc::kevent); + + unsafe impl Send for KEvent {} + + impl KEvent { + pub fn new( + ident: libc::uintptr_t, + filter: EventFilter, + flags: EventFlag, + fflags: FilterFlag, + data: libc::intptr_t, + udata: libc::intptr_t, + ) -> KEvent { + KEvent(libc::kevent { + ident, + filter: filter as type_of_event_filter, + flags, + fflags, + data: data as type_of_data, + udata: udata as type_of_udata, + }) + } + + pub fn filter(&self) -> EventFilter { + unsafe { std::mem::transmute(self.0.filter as type_of_event_filter) } + } + + pub fn flags(&self) -> EventFlag { + self.0.flags + } + + pub fn data(&self) -> libc::intptr_t { + 
self.0.data as libc::intptr_t + } + + pub fn udata(&self) -> libc::intptr_t { + self.0.udata as libc::intptr_t + } + } + + #[cfg(any( + target_os = "dragonfly", + target_os = "freebsd", + target_os = "ios", + target_os = "macos", + target_os = "openbsd" + ))] + pub type EventFlag = u16; + #[cfg(any(target_os = "netbsd"))] + pub type EventFlag = u32; + + pub type FilterFlag = u32; + + #[cfg(target_os = "netbsd")] + pub type EventFilter = u32; + #[cfg(not(target_os = "netbsd"))] + pub type EventFilter = i16; + + pub fn kqueue() -> Result { + let res = unsafe { libc::kqueue() }; + + check_err(res) + } + + pub fn kevent_ts( + kq: RawFd, + changelist: &[KEvent], + eventlist: &mut [KEvent], + timeout_opt: Option, + ) -> Result { + let res = unsafe { + libc::kevent( + kq, + changelist.as_ptr() as *const libc::kevent, + changelist.len() as type_of_nchanges, + eventlist.as_mut_ptr() as *mut libc::kevent, + eventlist.len() as type_of_nchanges, + if let Some(ref timeout) = timeout_opt { + timeout as *const libc::timespec + } else { + std::ptr::null() + }, + ) + }; + + check_err(res).map(|r| r as usize) + } +} + +#[cfg(any(target_os = "linux", target_os = "android", target_os = "illumos"))] +/// Epoll. +pub mod epoll { + use super::check_err; + use std::os::unix::io::RawFd; + + #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] + #[repr(i32)] + pub enum EpollOp { + EpollCtlAdd = libc::EPOLL_CTL_ADD, + EpollCtlDel = libc::EPOLL_CTL_DEL, + EpollCtlMod = libc::EPOLL_CTL_MOD, + } + + pub type EpollFlags = libc::c_int; + + pub fn epoll_create1() -> Result { + // According to libuv, `EPOLL_CLOEXEC` is not defined on Android API < 21. + // But `EPOLL_CLOEXEC` is an alias for `O_CLOEXEC` on that platform, so we use it instead. + #[cfg(target_os = "android")] + const CLOEXEC: libc::c_int = libc::O_CLOEXEC; + #[cfg(not(target_os = "android"))] + const CLOEXEC: libc::c_int = libc::EPOLL_CLOEXEC; + + let fd = unsafe { + // Check if the `epoll_create1` symbol is available on this platform. 
+ let ptr = libc::dlsym( + libc::RTLD_DEFAULT, + "epoll_create1\0".as_ptr() as *const libc::c_char, + ); + + if ptr.is_null() { + // If not, use `epoll_create` and manually set `CLOEXEC`. + let fd = check_err(libc::epoll_create(1024))?; + let flags = libc::fcntl(fd, libc::F_GETFD); + libc::fcntl(fd, libc::F_SETFD, flags | libc::FD_CLOEXEC); + fd + } else { + // Use `epoll_create1` with `CLOEXEC`. + let epoll_create1 = std::mem::transmute::< + *mut libc::c_void, + unsafe extern "C" fn(libc::c_int) -> libc::c_int, + >(ptr); + check_err(epoll_create1(CLOEXEC))? + } + }; + + Ok(fd) + } + + pub fn epoll_ctl<'a, T>( + epfd: RawFd, + op: EpollOp, + fd: RawFd, + event: T, + ) -> Result<(), std::io::Error> + where + T: Into>, + { + let mut event: Option<&mut EpollEvent> = event.into(); + if event.is_none() && op != EpollOp::EpollCtlDel { + Err(std::io::Error::from_raw_os_error(libc::EINVAL)) + } else { + let res = unsafe { + if let Some(ref mut event) = event { + libc::epoll_ctl(epfd, op as libc::c_int, fd, &mut event.event) + } else { + libc::epoll_ctl(epfd, op as libc::c_int, fd, std::ptr::null_mut()) + } + }; + check_err(res).map(drop) + } + } + + pub fn epoll_wait( + epfd: RawFd, + events: &mut [EpollEvent], + timeout_ms: isize, + ) -> Result { + let res = unsafe { + libc::epoll_wait( + epfd, + events.as_mut_ptr() as *mut libc::epoll_event, + events.len() as libc::c_int, + timeout_ms as libc::c_int, + ) + }; + + check_err(res).map(|r| r as usize) + } + + #[derive(Clone, Copy)] + #[repr(transparent)] + pub struct EpollEvent { + event: libc::epoll_event, + } + + impl EpollEvent { + pub fn new(events: EpollFlags, data: u64) -> Self { + EpollEvent { + event: libc::epoll_event { + events: events as u32, + u64: data, + }, + } + } + + pub fn empty() -> Self { + unsafe { std::mem::zeroed::() } + } + + pub fn events(&self) -> EpollFlags { + self.event.events as libc::c_int + } + + pub fn data(&self) -> u64 { + self.event.u64 + } + } +} diff --git a/tests/async.rs 
b/tests/async.rs new file mode 100644 index 0000000..ef60713 --- /dev/null +++ b/tests/async.rs @@ -0,0 +1,339 @@ +use std::future::Future; +use std::io; +use std::net::{Shutdown, TcpListener, TcpStream, UdpSocket}; +#[cfg(unix)] +use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream}; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +use async_io::{Async, Timer}; +use blocking::block_on; +use futures::{AsyncReadExt, AsyncWriteExt, StreamExt}; +#[cfg(unix)] +use tempfile::tempdir; + +const LOREM_IPSUM: &[u8] = b" +Lorem ipsum dolor sit amet, consectetur adipiscing elit. +Donec pretium ante erat, vitae sodales mi varius quis. +Etiam vestibulum lorem vel urna tempor, eu fermentum odio aliquam. +Aliquam consequat urna vitae ipsum pulvinar, in blandit purus eleifend. +"; + +fn spawn( + f: impl Future + Send + 'static, +) -> impl Future + Send + 'static { + let (s, r) = async_channel::bounded(1); + + thread::spawn(move || { + block_on(async { + let _ = s.send(f.await).await; + }) + }); + + Box::pin(async move { r.recv().await.unwrap() }) +} + +#[test] +fn tcp_connect() -> io::Result<()> { + block_on(async { + let listener = Async::::bind("127.0.0.1:0")?; + let addr = listener.get_ref().local_addr()?; + let task = spawn(async move { listener.accept().await }); + + let stream2 = Async::::connect(&addr).await?; + let stream1 = task.await?.0; + + assert_eq!( + stream1.get_ref().peer_addr()?, + stream2.get_ref().local_addr()?, + ); + assert_eq!( + stream2.get_ref().peer_addr()?, + stream1.get_ref().local_addr()?, + ); + + // Now that the listener is closed, connect should fail. 
+ let err = Async::::connect(&addr).await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::ConnectionRefused); + + Ok(()) + }) +} + +#[test] +fn tcp_peek_read() -> io::Result<()> { + block_on(async { + let listener = Async::::bind("127.0.0.1:0")?; + let addr = listener.get_ref().local_addr()?; + + let mut stream = Async::::connect(addr).await?; + stream.write_all(LOREM_IPSUM).await?; + + let mut buf = [0; 1024]; + let mut incoming = listener.incoming(); + let mut stream = incoming.next().await.unwrap()?; + + let n = stream.peek(&mut buf).await?; + assert_eq!(&buf[..n], LOREM_IPSUM); + let n = stream.read(&mut buf).await?; + assert_eq!(&buf[..n], LOREM_IPSUM); + + Ok(()) + }) +} + +#[test] +fn tcp_reader_hangup() -> io::Result<()> { + block_on(async { + let listener = Async::::bind("127.0.0.1:0")?; + let addr = listener.get_ref().local_addr()?; + let task = spawn(async move { listener.accept().await }); + + let mut stream2 = Async::::connect(&addr).await?; + let stream1 = task.await?.0; + + let task = spawn(async move { + Timer::after(Duration::from_secs(1)).await; + drop(stream1); + }); + + while stream2.write_all(LOREM_IPSUM).await.is_ok() {} + task.await; + + Ok(()) + }) +} + +#[test] +fn tcp_writer_hangup() -> io::Result<()> { + block_on(async { + let listener = Async::::bind("127.0.0.1:0")?; + let addr = listener.get_ref().local_addr()?; + let task = spawn(async move { listener.accept().await }); + + let mut stream2 = Async::::connect(&addr).await?; + let stream1 = task.await?.0; + + let task = spawn(async move { + Timer::after(Duration::from_secs(1)).await; + drop(stream1); + }); + + let mut v = vec![]; + stream2.read_to_end(&mut v).await?; + assert!(v.is_empty()); + + task.await; + Ok(()) + }) +} + +#[test] +fn udp_send_recv() -> io::Result<()> { + block_on(async { + let socket1 = Async::::bind("127.0.0.1:0")?; + let socket2 = Async::::bind("127.0.0.1:0")?; + socket1.get_ref().connect(socket2.get_ref().local_addr()?)?; + + let mut buf = [0u8; 1024]; + + 
socket1.send(LOREM_IPSUM).await?; + let n = socket2.peek(&mut buf).await?; + assert_eq!(&buf[..n], LOREM_IPSUM); + let n = socket2.recv(&mut buf).await?; + assert_eq!(&buf[..n], LOREM_IPSUM); + + socket2 + .send_to(LOREM_IPSUM, socket1.get_ref().local_addr()?) + .await?; + let n = socket1.peek_from(&mut buf).await?.0; + assert_eq!(&buf[..n], LOREM_IPSUM); + let n = socket1.recv_from(&mut buf).await?.0; + assert_eq!(&buf[..n], LOREM_IPSUM); + + Ok(()) + }) +} + +#[cfg(unix)] +#[test] +fn udp_connect() -> io::Result<()> { + block_on(async { + let dir = tempdir()?; + let path = dir.path().join("socket"); + + let listener = Async::::bind(&path)?; + + let mut stream = Async::::connect(&path).await?; + stream.write_all(LOREM_IPSUM).await?; + + let mut buf = [0; 1024]; + let mut incoming = listener.incoming(); + let mut stream = incoming.next().await.unwrap()?; + + let n = stream.read(&mut buf).await?; + assert_eq!(&buf[..n], LOREM_IPSUM); + + Ok(()) + }) +} + +#[cfg(unix)] +#[test] +fn uds_connect() -> io::Result<()> { + block_on(async { + let dir = tempdir()?; + let path = dir.path().join("socket"); + let listener = Async::::bind(&path)?; + + let addr = listener.get_ref().local_addr()?; + let task = spawn(async move { listener.accept().await }); + + let stream2 = Async::::connect(addr.as_pathname().unwrap()).await?; + let stream1 = task.await?.0; + + assert_eq!( + stream1.get_ref().peer_addr()?.as_pathname(), + stream2.get_ref().local_addr()?.as_pathname(), + ); + assert_eq!( + stream2.get_ref().peer_addr()?.as_pathname(), + stream1.get_ref().local_addr()?.as_pathname(), + ); + + // Now that the listener is closed, connect should fail. 
+ let err = Async::::connect(addr.as_pathname().unwrap()) + .await + .unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::ConnectionRefused); + + Ok(()) + }) +} + +#[cfg(unix)] +#[test] +fn uds_send_recv() -> io::Result<()> { + block_on(async { + let (socket1, socket2) = Async::::pair()?; + + socket1.send(LOREM_IPSUM).await?; + let mut buf = [0; 1024]; + let n = socket2.recv(&mut buf).await?; + assert_eq!(&buf[..n], LOREM_IPSUM); + + Ok(()) + }) +} + +#[cfg(unix)] +#[test] +fn uds_send_to_recv_from() -> io::Result<()> { + block_on(async { + let dir = tempdir()?; + let path = dir.path().join("socket"); + let socket1 = Async::::bind(&path)?; + let socket2 = Async::::unbound()?; + + socket2.send_to(LOREM_IPSUM, &path).await?; + let mut buf = [0; 1024]; + let n = socket1.recv_from(&mut buf).await?.0; + assert_eq!(&buf[..n], LOREM_IPSUM); + + Ok(()) + }) +} + +#[cfg(unix)] +#[test] +fn uds_reader_hangup() -> io::Result<()> { + block_on(async { + let (socket1, mut socket2) = Async::::pair()?; + + let task = spawn(async move { + Timer::after(Duration::from_secs(1)).await; + drop(socket1); + }); + + while socket2.write_all(LOREM_IPSUM).await.is_ok() {} + task.await; + + Ok(()) + }) +} + +#[cfg(unix)] +#[test] +fn uds_writer_hangup() -> io::Result<()> { + block_on(async { + let (socket1, mut socket2) = Async::::pair()?; + + let task = spawn(async move { + Timer::after(Duration::from_secs(1)).await; + drop(socket1); + }); + + let mut v = vec![]; + socket2.read_to_end(&mut v).await?; + assert!(v.is_empty()); + + task.await; + Ok(()) + }) +} + +// Test that we correctly re-register interests when we are previously +// interested in both readable and writable events and then we get only one of +// them. (we need to re-register interest on the other.) 
+#[test] +fn tcp_duplex() -> io::Result<()> { + block_on(async { + let listener = Async::::bind("127.0.0.1:0")?; + let stream0 = + Arc::new(Async::::connect(listener.get_ref().local_addr()?).await?); + let stream1 = Arc::new(listener.accept().await?.0); + + async fn do_read(s: Arc>) -> io::Result<()> { + let mut buf = vec![0u8; 4096]; + loop { + let len = (&*s).read(&mut buf).await?; + if len == 0 { + return Ok(()); + } + } + } + + async fn do_write(s: Arc>) -> io::Result<()> { + let buf = vec![0u8; 4096]; + for _ in 0..4096 { + (&*s).write_all(&buf).await?; + } + s.get_ref().shutdown(Shutdown::Write)?; + Ok(()) + } + + // Read from and write to stream0. + let r0 = spawn(do_read(stream0.clone())); + let w0 = spawn(do_write(stream0)); + + // Sleep a bit, so that reading and writing are both blocked. + Timer::after(Duration::from_millis(5)).await; + + // Start reading stream1, make stream0 writable. + let r1 = spawn(do_read(stream1.clone())); + + // Finish writing to stream0. + w0.await?; + r1.await?; + + // Start writing to stream1, make stream0 readable. + let w1 = spawn(do_write(stream1)); + + // Will r0 be correctly woken? + r0.await?; + w1.await?; + + Ok(()) + }) +} diff --git a/tests/timer.rs b/tests/timer.rs new file mode 100644 index 0000000..9fb04d9 --- /dev/null +++ b/tests/timer.rs @@ -0,0 +1,27 @@ +use std::time::{Duration, Instant}; + +use async_io::Timer; +use blocking::block_on; + +#[test] +fn timer_at() { + let before = block_on(async { + let now = Instant::now(); + let when = now + Duration::from_secs(1); + Timer::at(when).await; + now + }); + + assert!(before.elapsed() >= Duration::from_secs(1)); +} + +#[test] +fn timer_after() { + let before = block_on(async { + let now = Instant::now(); + Timer::after(Duration::from_secs(1)).await; + now + }); + + assert!(before.elapsed() >= Duration::from_secs(1)); +}