// Mirror of https://github.com/smol-rs/async-dup (Rust source, 505 lines, 13 KiB).
//! Duplicate an async I/O handle.
//!
//! This crate provides two tools, [`Arc`] and [`Mutex`]:
//!
//! * [`Arc`] implements [`AsyncRead`], [`AsyncWrite`], and [`AsyncSeek`] if a reference to the
//!   inner type does.
//! * A reference to [`Mutex`] implements [`AsyncRead`], [`AsyncWrite`], and [`AsyncSeek`] if the
//!   inner type does.
//!
//! Wrap an async I/O handle in [`Arc`] or [`Mutex`] to clone it or share it among tasks.
//!
//! # Examples
//!
//! Clone an async I/O handle:
//!
//! ```no_run
//! use async_dup::Arc;
//! use futures::io;
//! use smol::Async;
//! use std::net::TcpStream;
//!
//! # fn main() -> std::io::Result<()> { smol::block_on(async {
//! // A client that echoes messages back to the server.
//! let stream = Async::<TcpStream>::connect(([127, 0, 0, 1], 8000)).await?;
//!
//! // Create two handles to the stream.
//! let reader = Arc::new(stream);
//! let mut writer = reader.clone();
//!
//! // Echo data received from the reader back into the writer.
//! io::copy(reader, &mut writer).await?;
//! # Ok(()) }) }
//! ```
//!
//! Share an async I/O handle:
//!
//! ```
//! use async_dup::Mutex;
//! use futures::io;
//! use futures::prelude::*;
//!
//! // Reads data from a stream and echoes it back.
//! async fn echo(stream: impl AsyncRead + AsyncWrite + Unpin) -> io::Result<u64> {
//!     let stream = Mutex::new(stream);
//!     io::copy(&stream, &mut &stream).await
//! }
//! ```

// Crate-wide lints: `unsafe` is rejected outright; undocumented public items, public types
// without a `Debug` impl, and pre-2018 idioms produce warnings.
#![forbid(unsafe_code)]
#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
// Favicon and logo used by rustdoc for the generated documentation.
#![doc(
    html_favicon_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
)]
#![doc(
    html_logo_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
)]

use std::fmt;
|
|
use std::hash::{Hash, Hasher};
|
|
use std::io::{self, IoSlice, IoSliceMut, SeekFrom};
|
|
use std::ops::{Deref, DerefMut};
|
|
use std::pin::Pin;
|
|
use std::task::{Context, Poll};
|
|
|
|
use futures_io::{AsyncRead, AsyncSeek, AsyncWrite};
|
|
|
|
/// A reference-counted pointer that implements async I/O traits.
///
/// This is just a wrapper around [`std::sync::Arc`] that adds the following impls:
///
/// - `impl<T> AsyncRead for Arc<T> where &T: AsyncRead {}`
/// - `impl<T> AsyncWrite for Arc<T> where &T: AsyncWrite {}`
/// - `impl<T> AsyncSeek for Arc<T> where &T: AsyncSeek {}`
///
/// The wrapped [`std::sync::Arc`] is exposed as the public tuple field `0`.
pub struct Arc<T>(pub std::sync::Arc<T>);

impl<T> Unpin for Arc<T> {}
|
|
|
|
impl<T> Arc<T> {
|
|
/// Constructs a new `Arc<T>`.
|
|
///
|
|
/// # Examples
|
|
///
|
|
/// ```
|
|
/// use async_dup::Arc;
|
|
///
|
|
/// let a = Arc::new(7);
|
|
/// ```
|
|
pub fn new(data: T) -> Arc<T> {
|
|
Arc(std::sync::Arc::new(data))
|
|
}
|
|
}
|
|
|
|
impl<T> Clone for Arc<T> {
|
|
fn clone(&self) -> Arc<T> {
|
|
Arc(self.0.clone())
|
|
}
|
|
}
|
|
|
|
impl<T> Deref for Arc<T> {
|
|
type Target = T;
|
|
|
|
#[inline]
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.0
|
|
}
|
|
}
|
|
|
|
impl<T: fmt::Debug> fmt::Debug for Arc<T> {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
fmt::Debug::fmt(&**self, f)
|
|
}
|
|
}
|
|
|
|
impl<T: fmt::Display> fmt::Display for Arc<T> {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
fmt::Display::fmt(&**self, f)
|
|
}
|
|
}
|
|
|
|
impl<T: Hash> Hash for Arc<T> {
|
|
fn hash<H: Hasher>(&self, state: &mut H) {
|
|
(**self).hash(state)
|
|
}
|
|
}
|
|
|
|
impl<T> fmt::Pointer for Arc<T> {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
fmt::Pointer::fmt(&(&**self as *const T), f)
|
|
}
|
|
}
|
|
|
|
impl<T: Default> Default for Arc<T> {
|
|
fn default() -> Arc<T> {
|
|
Arc::new(Default::default())
|
|
}
|
|
}
|
|
|
|
impl<T> From<T> for Arc<T> {
|
|
fn from(t: T) -> Arc<T> {
|
|
Arc::new(t)
|
|
}
|
|
}
|
|
|
|
// NOTE(stjepang): It would also make sense to have the following impls:
//
// - `impl<T> AsyncRead for &Arc<T> where &T: AsyncRead {}`
// - `impl<T> AsyncWrite for &Arc<T> where &T: AsyncWrite {}`
// - `impl<T> AsyncSeek for &Arc<T> where &T: AsyncSeek {}`
//
// However, those impls sometimes make Rust's type inference try too hard when types cannot be
// inferred. In the end, instead of complaining with a nice error message, the Rust compiler ends
// up overflowing and dumping a very long error message spanning multiple screens.
//
// Since those impls are not essential, I decided to err on the safe side and not include them.

impl<T> AsyncRead for Arc<T>
|
|
where
|
|
for<'a> &'a T: AsyncRead,
|
|
{
|
|
fn poll_read(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut [u8],
|
|
) -> Poll<io::Result<usize>> {
|
|
Pin::new(&mut &*self.0).poll_read(cx, buf)
|
|
}
|
|
|
|
fn poll_read_vectored(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
bufs: &mut [IoSliceMut<'_>],
|
|
) -> Poll<io::Result<usize>> {
|
|
Pin::new(&mut &*self.0).poll_read_vectored(cx, bufs)
|
|
}
|
|
}
|
|
|
|
impl<T> AsyncWrite for Arc<T>
|
|
where
|
|
for<'a> &'a T: AsyncWrite,
|
|
{
|
|
fn poll_write(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<io::Result<usize>> {
|
|
Pin::new(&mut &*self.0).poll_write(cx, buf)
|
|
}
|
|
|
|
fn poll_write_vectored(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
bufs: &[IoSlice<'_>],
|
|
) -> Poll<io::Result<usize>> {
|
|
Pin::new(&mut &*self.0).poll_write_vectored(cx, bufs)
|
|
}
|
|
|
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
Pin::new(&mut &*self.0).poll_flush(cx)
|
|
}
|
|
|
|
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
Pin::new(&mut &*self.0).poll_close(cx)
|
|
}
|
|
}
|
|
|
|
impl<T> AsyncSeek for Arc<T>
|
|
where
|
|
for<'a> &'a T: AsyncSeek,
|
|
{
|
|
fn poll_seek(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
pos: SeekFrom,
|
|
) -> Poll<io::Result<u64>> {
|
|
Pin::new(&mut &*self.0).poll_seek(cx, pos)
|
|
}
|
|
}
|
|
|
|
/// A mutex that implements async I/O traits.
|
|
///
|
|
/// This is a blocking mutex that adds the following impls:
|
|
///
|
|
/// - `impl<T> AsyncRead for Mutex<T> where T: AsyncRead + Unpin {}`
|
|
/// - `impl<T> AsyncRead for &Mutex<T> where T: AsyncRead + Unpin {}`
|
|
/// - `impl<T> AsyncWrite for Mutex<T> where T: AsyncWrite + Unpin {}`
|
|
/// - `impl<T> AsyncWrite for &Mutex<T> where T: AsyncWrite + Unpin {}`
|
|
/// - `impl<T> AsyncSeek for Mutex<T> where T: AsyncSeek + Unpin {}`
|
|
/// - `impl<T> AsyncSeek for &Mutex<T> where T: AsyncSeek + Unpin {}`
|
|
pub struct Mutex<T>(async_lock::Mutex<T>);
|
|
|
|
impl<T> Mutex<T> {
|
|
/// Creates a new mutex.
|
|
///
|
|
/// # Examples
|
|
///
|
|
/// ```
|
|
/// use async_dup::Mutex;
|
|
///
|
|
/// let mutex = Mutex::new(10);
|
|
/// ```
|
|
pub fn new(data: T) -> Mutex<T> {
|
|
Mutex(data.into())
|
|
}
|
|
|
|
/// Acquires the mutex, blocking the current thread until it is able to do so.
|
|
///
|
|
/// Returns a guard that releases the mutex when dropped.
|
|
///
|
|
/// # Examples
|
|
///
|
|
/// ```
|
|
/// use async_dup::Mutex;
|
|
///
|
|
/// let mutex = Mutex::new(10);
|
|
/// let guard = mutex.lock();
|
|
/// assert_eq!(*guard, 10);
|
|
/// ```
|
|
pub fn lock(&self) -> MutexGuard<'_, T> {
|
|
MutexGuard(self.0.lock_blocking())
|
|
}
|
|
|
|
/// Attempts to acquire the mutex.
|
|
///
|
|
/// If the mutex could not be acquired at this time, then [`None`] is returned. Otherwise, a
|
|
/// guard is returned that releases the mutex when dropped.
|
|
///
|
|
/// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None
|
|
///
|
|
/// # Examples
|
|
///
|
|
/// ```
|
|
/// use async_dup::Mutex;
|
|
///
|
|
/// let mutex = Mutex::new(10);
|
|
/// if let Some(guard) = mutex.try_lock() {
|
|
/// assert_eq!(*guard, 10);
|
|
/// }
|
|
/// # ;
|
|
/// ```
|
|
pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
|
|
self.0.try_lock().map(MutexGuard)
|
|
}
|
|
|
|
/// Consumes the mutex, returning the underlying data.
|
|
///
|
|
/// # Examples
|
|
///
|
|
/// ```
|
|
/// use async_dup::Mutex;
|
|
///
|
|
/// let mutex = Mutex::new(10);
|
|
/// assert_eq!(mutex.into_inner(), 10);
|
|
/// ```
|
|
pub fn into_inner(self) -> T {
|
|
self.0.into_inner()
|
|
}
|
|
|
|
/// Returns a mutable reference to the underlying data.
|
|
///
|
|
/// Since this call borrows the mutex mutably, no actual locking takes place -- the mutable
|
|
/// borrow statically guarantees the mutex is not already acquired.
|
|
///
|
|
/// # Examples
|
|
///
|
|
/// ```
|
|
/// use async_dup::Mutex;
|
|
///
|
|
/// let mut mutex = Mutex::new(0);
|
|
/// *mutex.get_mut() = 10;
|
|
/// assert_eq!(*mutex.lock(), 10);
|
|
/// ```
|
|
pub fn get_mut(&mut self) -> &mut T {
|
|
self.0.get_mut()
|
|
}
|
|
}
|
|
|
|
impl<T: fmt::Debug> fmt::Debug for Mutex<T> {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
struct Locked;
|
|
impl fmt::Debug for Locked {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
f.write_str("<locked>")
|
|
}
|
|
}
|
|
|
|
match self.try_lock() {
|
|
None => f.debug_struct("Mutex").field("data", &Locked).finish(),
|
|
Some(guard) => f.debug_struct("Mutex").field("data", &&*guard).finish(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T> From<T> for Mutex<T> {
|
|
fn from(val: T) -> Mutex<T> {
|
|
Mutex::new(val)
|
|
}
|
|
}
|
|
|
|
impl<T: Default> Default for Mutex<T> {
|
|
fn default() -> Mutex<T> {
|
|
Mutex::new(Default::default())
|
|
}
|
|
}
|
|
|
|
impl<T: AsyncRead + Unpin> AsyncRead for Mutex<T> {
|
|
fn poll_read(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut [u8],
|
|
) -> Poll<io::Result<usize>> {
|
|
Pin::new(&mut *self.lock()).poll_read(cx, buf)
|
|
}
|
|
|
|
fn poll_read_vectored(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
bufs: &mut [IoSliceMut<'_>],
|
|
) -> Poll<io::Result<usize>> {
|
|
Pin::new(&mut *self.lock()).poll_read_vectored(cx, bufs)
|
|
}
|
|
}
|
|
|
|
impl<T: AsyncRead + Unpin> AsyncRead for &Mutex<T> {
|
|
fn poll_read(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut [u8],
|
|
) -> Poll<io::Result<usize>> {
|
|
Pin::new(&mut *self.lock()).poll_read(cx, buf)
|
|
}
|
|
|
|
fn poll_read_vectored(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
bufs: &mut [IoSliceMut<'_>],
|
|
) -> Poll<io::Result<usize>> {
|
|
Pin::new(&mut *self.lock()).poll_read_vectored(cx, bufs)
|
|
}
|
|
}
|
|
|
|
impl<T: AsyncWrite + Unpin> AsyncWrite for Mutex<T> {
|
|
fn poll_write(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<io::Result<usize>> {
|
|
Pin::new(&mut *self.lock()).poll_write(cx, buf)
|
|
}
|
|
|
|
fn poll_write_vectored(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
bufs: &[IoSlice<'_>],
|
|
) -> Poll<io::Result<usize>> {
|
|
Pin::new(&mut *self.lock()).poll_write_vectored(cx, bufs)
|
|
}
|
|
|
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
Pin::new(&mut *self.lock()).poll_flush(cx)
|
|
}
|
|
|
|
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
Pin::new(&mut *self.lock()).poll_close(cx)
|
|
}
|
|
}
|
|
|
|
impl<T: AsyncWrite + Unpin> AsyncWrite for &Mutex<T> {
|
|
fn poll_write(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<io::Result<usize>> {
|
|
Pin::new(&mut *self.lock()).poll_write(cx, buf)
|
|
}
|
|
|
|
fn poll_write_vectored(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
bufs: &[IoSlice<'_>],
|
|
) -> Poll<io::Result<usize>> {
|
|
Pin::new(&mut *self.lock()).poll_write_vectored(cx, bufs)
|
|
}
|
|
|
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
Pin::new(&mut *self.lock()).poll_flush(cx)
|
|
}
|
|
|
|
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
Pin::new(&mut *self.lock()).poll_close(cx)
|
|
}
|
|
}
|
|
|
|
impl<T: AsyncSeek + Unpin> AsyncSeek for Mutex<T> {
|
|
fn poll_seek(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
pos: SeekFrom,
|
|
) -> Poll<io::Result<u64>> {
|
|
Pin::new(&mut *self.lock()).poll_seek(cx, pos)
|
|
}
|
|
}
|
|
|
|
impl<T: AsyncSeek + Unpin> AsyncSeek for &Mutex<T> {
|
|
fn poll_seek(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
pos: SeekFrom,
|
|
) -> Poll<io::Result<u64>> {
|
|
Pin::new(&mut *self.lock()).poll_seek(cx, pos)
|
|
}
|
|
}
|
|
|
|
/// A guard that releases the mutex when dropped.
|
|
pub struct MutexGuard<'a, T>(async_lock::MutexGuard<'a, T>);
|
|
|
|
impl<T: fmt::Debug> fmt::Debug for MutexGuard<'_, T> {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
fmt::Debug::fmt(&**self, f)
|
|
}
|
|
}
|
|
|
|
impl<T: fmt::Display> fmt::Display for MutexGuard<'_, T> {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
(**self).fmt(f)
|
|
}
|
|
}
|
|
|
|
impl<T> Deref for MutexGuard<'_, T> {
|
|
type Target = T;
|
|
|
|
fn deref(&self) -> &T {
|
|
&self.0
|
|
}
|
|
}
|
|
|
|
impl<T> DerefMut for MutexGuard<'_, T> {
|
|
fn deref_mut(&mut self) -> &mut T {
|
|
&mut self.0
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time probes: each call only type-checks if `T` has the named auto trait.
    fn is_send<T: Send>(_: &T) {}
    fn is_sync<T: Sync>(_: &T) {}

    /// `Arc`, `Mutex`, and `MutexGuard` must all be `Send + Sync` for a `Send + Sync` payload.
    #[test]
    fn is_send_sync() {
        let arc = Arc::new(());
        let mutex = Mutex::new(());

        is_send(&arc);
        is_sync(&arc);

        is_send(&mutex);
        is_sync(&mutex);

        let guard = mutex.lock();
        is_send(&guard);
        is_sync(&guard);
    }
}