Replace some async blocks with manual futures (#34)

* Replace some async blocks with manual futures

* Fix WASM build

* Use a slightly more idiomatic way to access event listeners

* Fix MIRI failure

* Code review
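
Every file below follows the same transformation: an `async fn` that returned an opaque `impl Future` becomes a public, named type with a hand-written `Future::poll`. The sketch below shows the pattern in isolation; it is illustration only, `double`/`Double` are hypothetical names rather than crate items, and `futures_lite` (already a dependency of this crate, used for `ready!` below) provides `block_on`.

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

// An `async fn` returns an opaque `impl Future` that callers cannot name,
// store in a struct field without boxing, or implement extra traits for.
async fn double(x: u32) -> u32 {
    x * 2
}

// The manual equivalent: a named type with a hand-written `Future` impl.
// This is the shape the commit gives to `lock()`, `read()`, `wait()` etc.
struct Double {
    x: u32,
}

impl Future for Double {
    type Output = u32;

    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Always ready; the futures in this commit instead return
        // `Poll::Pending` and park on an `EventListener` until notified.
        Poll::Ready(self.x * 2)
    }
}

fn main() {
    futures_lite::future::block_on(async {
        // Both are awaited the same way; only `Double` has a nameable type.
        assert_eq!(double(21).await, 42);
        assert_eq!((Double { x: 21 }).await, 42);
    });
}

The benefit is that callers can name, store and manually poll the future; the cost is spelling out the state machine by hand, which is what the `WaitState`, `AcquireSlow` and `WriteState` types below do.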
John Nunley 2022-12-28 20:30:48 -08:00 committed by GitHub
parent 61922eb271
commit 15049aa1c9
5 changed files with 825 additions and 266 deletions

src/barrier.rs

@ -1,5 +1,12 @@
- use event_listener::Event;
use event_listener::{Event, EventListener};
use futures_lite::ready;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::futures::Lock;
use crate::Mutex;
/// A counter to synchronize multiple tasks at the same time.
@ -72,24 +79,103 @@ impl Barrier {
/// });
/// }
/// ```
- pub async fn wait(&self) -> BarrierWaitResult {
- let mut state = self.state.lock().await;
- let local_gen = state.generation_id;
- state.count += 1;
pub fn wait(&self) -> BarrierWait<'_> {
BarrierWait {
barrier: self,
lock: Some(self.state.lock()),
state: WaitState::Initial,
}
}
}
- if state.count < self.n {
- while local_gen == state.generation_id && state.count < self.n {
- let listener = self.event.listen();
- drop(state);
- listener.await;
- state = self.state.lock().await;
/// The future returned by [`Barrier::wait()`].
pub struct BarrierWait<'a> {
/// The barrier to wait on.
barrier: &'a Barrier,
/// The ongoing mutex lock operation we are blocking on.
lock: Option<Lock<'a, State>>,
/// The current state of the future.
state: WaitState,
}
impl fmt::Debug for BarrierWait<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("BarrierWait { .. }")
}
}
enum WaitState {
/// We are getting the original values of the state.
Initial,
/// We are waiting for the listener to complete.
Waiting { evl: EventListener, local_gen: u64 },
/// Waiting to re-acquire the lock to check the state again.
Reacquiring(u64),
}
impl Future for BarrierWait<'_> {
type Output = BarrierWaitResult;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
loop {
match this.state {
WaitState::Initial => {
// See if the lock is ready yet.
let mut state = ready!(Pin::new(this.lock.as_mut().unwrap()).poll(cx));
this.lock = None;
let local_gen = state.generation_id;
state.count += 1;
if state.count < this.barrier.n {
// We need to wait for the event.
this.state = WaitState::Waiting {
evl: this.barrier.event.listen(),
local_gen,
};
} else {
// We are the last one.
state.count = 0;
state.generation_id = state.generation_id.wrapping_add(1);
this.barrier.event.notify(std::usize::MAX);
return Poll::Ready(BarrierWaitResult { is_leader: true });
}
}
WaitState::Waiting {
ref mut evl,
local_gen,
} => {
ready!(Pin::new(evl).poll(cx));
// We are now re-acquiring the mutex.
this.lock = Some(this.barrier.state.lock());
this.state = WaitState::Reacquiring(local_gen);
}
WaitState::Reacquiring(local_gen) => {
// Acquire the local state again.
let state = ready!(Pin::new(this.lock.as_mut().unwrap()).poll(cx));
this.lock = None;
if local_gen == state.generation_id && state.count < this.barrier.n {
// We need to wait for the event again.
this.state = WaitState::Waiting {
evl: this.barrier.event.listen(),
local_gen,
};
} else {
// We are ready, but not the leader.
return Poll::Ready(BarrierWaitResult { is_leader: false });
}
}
}
- BarrierWaitResult { is_leader: false }
- } else {
- state.count = 0;
- state.generation_id = state.generation_id.wrapping_add(1);
- self.event.notify(std::usize::MAX);
- BarrierWaitResult { is_leader: true }
- }
}
}
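
For reference, a small usage sketch of the new barrier API (not part of the diff; `block_on` is borrowed from `futures_lite`):

use async_lock::Barrier;

fn main() {
    futures_lite::future::block_on(async {
        // A one-party barrier completes on the first `wait()`, which keeps
        // the example single-threaded.
        let barrier = Barrier::new(1);

        // `wait()` now returns the named `BarrierWait` future instead of an
        // opaque `impl Future`, so it can be bound with an explicit type
        // before being awaited.
        let wait: async_lock::futures::BarrierWait<'_> = barrier.wait();
        assert!(wait.await.is_leader());
    });
}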

src/lib.rs

@ -20,3 +20,12 @@ pub use mutex::{Mutex, MutexGuard, MutexGuardArc};
pub use once_cell::OnceCell;
pub use rwlock::{RwLock, RwLockReadGuard, RwLockUpgradableReadGuard, RwLockWriteGuard};
pub use semaphore::{Semaphore, SemaphoreGuard, SemaphoreGuardArc};
pub mod futures {
//! Named futures for use with `async_lock` primitives.
pub use crate::barrier::BarrierWait;
pub use crate::mutex::{Lock, LockArc};
pub use crate::rwlock::{Read, UpgradableRead, Upgrade, Write};
pub use crate::semaphore::{Acquire, AcquireArc};
}
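
A short sketch of what the named futures enable (illustration only; `CountingLock` is a hypothetical wrapper): because `futures::Lock` is a concrete, `Unpin` type, it can sit in a struct field and be polled manually without boxing.

use async_lock::futures::Lock;
use async_lock::{Mutex, MutexGuard};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

// Hypothetical wrapper: counts how many times the inner lock future was
// polled. With an opaque `impl Future` this field would need boxing.
struct CountingLock<'a> {
    inner: Lock<'a, i32>,
    polls: u32,
}

impl<'a> Future for CountingLock<'a> {
    type Output = (MutexGuard<'a, i32>, u32);

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        this.polls += 1;
        // `Lock` is `Unpin`, so `Pin::new` is enough to poll it.
        match Pin::new(&mut this.inner).poll(cx) {
            Poll::Ready(guard) => Poll::Ready((guard, this.polls)),
            Poll::Pending => Poll::Pending,
        }
    }
}

fn main() {
    futures_lite::future::block_on(async {
        let mutex = Mutex::new(7);
        let (guard, polls) = CountingLock { inner: mutex.lock(), polls: 0 }.await;
        assert_eq!(*guard, 7);
        assert_eq!(polls, 1); // uncontended, so the fast path resolves on the first poll
    });
}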

src/mutex.rs

@ -1,10 +1,15 @@
use std::borrow::Borrow;
use std::cell::UnsafeCell;
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::mem;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::process;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
// Note: we cannot use `target_family = "wasm"` here because it requires Rust 1.54.
#[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
@ -12,7 +17,8 @@ use std::time::{Duration, Instant};
use std::usize;
- use event_listener::Event;
use event_listener::{Event, EventListener};
use futures_lite::ready;
/// An async mutex.
///
@ -103,114 +109,10 @@ impl<T: ?Sized> Mutex<T> {
/// # })
/// ```
#[inline]
- pub async fn lock(&self) -> MutexGuard<'_, T> {
- if let Some(guard) = self.try_lock() {
- return guard;
- }
- self.acquire_slow().await;
- MutexGuard(self)
- }
- /// Slow path for acquiring the mutex.
- #[cold]
- async fn acquire_slow(&self) {
- // Get the current time.
- #[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
- let start = Instant::now();
- loop {
- // Start listening for events.
- let listener = self.lock_ops.listen();
- // Try locking if nobody is being starved.
- match self
- .state
- .compare_exchange(0, 1, Ordering::Acquire, Ordering::Acquire)
- .unwrap_or_else(|x| x)
- {
- // Lock acquired!
- 0 => return,
- // Lock is held and nobody is starved.
- 1 => {}
- // Somebody is starved.
- _ => break,
- }
- // Wait for a notification.
- listener.await;
- // Try locking if nobody is being starved.
- match self
- .state
- .compare_exchange(0, 1, Ordering::Acquire, Ordering::Acquire)
- .unwrap_or_else(|x| x)
- {
- // Lock acquired!
- 0 => return,
- // Lock is held and nobody is starved.
- 1 => {}
- // Somebody is starved.
- _ => {
- // Notify the first listener in line because we probably received a
- // notification that was meant for a starved task.
- self.lock_ops.notify(1);
- break;
- }
- }
- // If waiting for too long, fall back to a fairer locking strategy that will prevent
- // newer lock operations from starving us forever.
- #[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
- if start.elapsed() > Duration::from_micros(500) {
- break;
- }
- }
- // Increment the number of starved lock operations.
- if self.state.fetch_add(2, Ordering::Release) > usize::MAX / 2 {
- // In case of potential overflow, abort.
- process::abort();
- }
- // Decrement the counter when exiting this function.
- let _call = CallOnDrop(|| {
- self.state.fetch_sub(2, Ordering::Release);
- });
- loop {
- // Start listening for events.
- let listener = self.lock_ops.listen();
- // Try locking if nobody else is being starved.
- match self
- .state
- .compare_exchange(2, 2 | 1, Ordering::Acquire, Ordering::Acquire)
- .unwrap_or_else(|x| x)
- {
- // Lock acquired!
- 2 => return,
- // Lock is held by someone.
- s if s % 2 == 1 => {}
- // Lock is available.
- _ => {
- // Be fair: notify the first listener and then go wait in line.
- self.lock_ops.notify(1);
- }
- }
- // Wait for a notification.
- listener.await;
- // Try acquiring the lock without waiting for others.
- if self.state.fetch_or(1, Ordering::Acquire) % 2 == 0 {
- return;
- }
pub fn lock(&self) -> Lock<'_, T> {
Lock {
mutex: self,
acquire_slow: None,
}
}
@ -265,14 +167,6 @@ impl<T: ?Sized> Mutex<T> {
}
impl<T: ?Sized> Mutex<T> {
- async fn lock_arc_impl(self: Arc<Self>) -> MutexGuardArc<T> {
- if let Some(guard) = self.try_lock_arc() {
- return guard;
- }
- self.acquire_slow().await;
- MutexGuardArc(self)
- }
/// Acquires the mutex and clones a reference to it.
///
/// Returns an owned guard that releases the mutex when dropped.
@ -290,8 +184,8 @@ impl<T: ?Sized> Mutex<T> {
/// # })
/// ```
#[inline]
- pub fn lock_arc(self: &Arc<Self>) -> impl Future<Output = MutexGuardArc<T>> {
- self.clone().lock_arc_impl()
pub fn lock_arc(self: &Arc<Self>) -> LockArc<T> {
LockArc(LockArcInnards::Unpolled(self.clone()))
}
/// Attempts to acquire the mutex and clone a reference to it.
@ -353,6 +247,295 @@ impl<T: Default + ?Sized> Default for Mutex<T> {
}
}
/// The future returned by [`Mutex::lock`].
pub struct Lock<'a, T: ?Sized> {
/// Reference to the mutex.
mutex: &'a Mutex<T>,
/// The future that waits for the mutex to become available.
acquire_slow: Option<AcquireSlow<&'a Mutex<T>, T>>,
}
impl<'a, T: ?Sized> Unpin for Lock<'a, T> {}
impl<T: ?Sized> fmt::Debug for Lock<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("Lock { .. }")
}
}
impl<'a, T: ?Sized> Future for Lock<'a, T> {
type Output = MutexGuard<'a, T>;
#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
loop {
match this.acquire_slow.as_mut() {
None => {
// Try the fast path before trying to register slowly.
match this.mutex.try_lock() {
Some(guard) => return Poll::Ready(guard),
None => {
this.acquire_slow = Some(AcquireSlow::new(this.mutex));
}
}
}
Some(acquire_slow) => {
// Continue registering slowly.
let value = ready!(Pin::new(acquire_slow).poll(cx));
return Poll::Ready(MutexGuard(value));
}
}
}
}
}
/// The future returned by [`Mutex::lock_arc`].
pub struct LockArc<T: ?Sized>(LockArcInnards<T>);
enum LockArcInnards<T: ?Sized> {
/// We have not tried to poll the fast path yet.
Unpolled(Arc<Mutex<T>>),
/// We are acquiring the mutex through the slow path.
AcquireSlow(AcquireSlow<Arc<Mutex<T>>, T>),
/// Empty hole to make taking easier.
Empty,
}
impl<T: ?Sized> Unpin for LockArc<T> {}
impl<T: ?Sized> fmt::Debug for LockArc<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("LockArc { .. }")
}
}
impl<T: ?Sized> Future for LockArc<T> {
type Output = MutexGuardArc<T>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
loop {
match mem::replace(&mut this.0, LockArcInnards::Empty) {
LockArcInnards::Unpolled(mutex) => {
// Try the fast path before trying to register slowly.
match mutex.try_lock_arc() {
Some(guard) => return Poll::Ready(guard),
None => {
*this = LockArc(LockArcInnards::AcquireSlow(AcquireSlow::new(
mutex.clone(),
)));
}
}
}
LockArcInnards::AcquireSlow(mut acquire_slow) => {
// Continue registering slowly.
let value = match Pin::new(&mut acquire_slow).poll(cx) {
Poll::Pending => {
*this = LockArc(LockArcInnards::AcquireSlow(acquire_slow));
return Poll::Pending;
}
Poll::Ready(value) => value,
};
return Poll::Ready(MutexGuardArc(value));
}
LockArcInnards::Empty => panic!("future polled after completion"),
}
}
}
}
/// Future for acquiring the mutex slowly.
struct AcquireSlow<B: Borrow<Mutex<T>>, T: ?Sized> {
/// Reference to the mutex.
mutex: Option<B>,
/// The event listener waiting on the mutex.
listener: Option<EventListener>,
/// The point at which the mutex lock was started.
#[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
start: Option<Instant>,
/// This lock operation is starving.
starved: bool,
/// Capture the `T` lifetime.
_marker: PhantomData<T>,
}
impl<B: Borrow<Mutex<T>> + Unpin, T: ?Sized> Unpin for AcquireSlow<B, T> {}
impl<T: ?Sized, B: Borrow<Mutex<T>>> AcquireSlow<B, T> {
/// Create a new `AcquireSlow` future.
#[cold]
fn new(mutex: B) -> Self {
AcquireSlow {
mutex: Some(mutex),
listener: None,
#[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
start: None,
starved: false,
_marker: PhantomData,
}
}
/// Take the mutex reference out, decrementing the counter if necessary.
fn take_mutex(&mut self) -> Option<B> {
let mutex = self.mutex.take();
if self.starved {
if let Some(mutex) = mutex.as_ref() {
// Decrement this counter before we exit.
mutex.borrow().state.fetch_sub(2, Ordering::Release);
}
}
mutex
}
}
impl<T: ?Sized, B: Unpin + Borrow<Mutex<T>>> Future for AcquireSlow<B, T> {
type Output = B;
#[cold]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = &mut *self;
#[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
let start = *this.start.get_or_insert_with(Instant::now);
let mutex = this
.mutex
.as_ref()
.expect("future polled after completion")
.borrow();
// Only use this hot loop if we aren't currently starved.
if !this.starved {
loop {
// Start listening for events.
match &mut this.listener {
listener @ None => {
// Start listening for events.
*listener = Some(mutex.lock_ops.listen());
// Try locking if nobody is being starved.
match mutex
.state
.compare_exchange(0, 1, Ordering::Acquire, Ordering::Acquire)
.unwrap_or_else(|x| x)
{
// Lock acquired!
0 => return Poll::Ready(this.take_mutex().unwrap()),
// Lock is held and nobody is starved.
1 => {}
// Somebody is starved.
_ => break,
}
}
Some(ref mut listener) => {
// Wait for a notification.
ready!(Pin::new(listener).poll(cx));
this.listener = None;
// Try locking if nobody is being starved.
match mutex
.state
.compare_exchange(0, 1, Ordering::Acquire, Ordering::Acquire)
.unwrap_or_else(|x| x)
{
// Lock acquired!
0 => return Poll::Ready(this.take_mutex().unwrap()),
// Lock is held and nobody is starved.
1 => {}
// Somebody is starved.
_ => {
// Notify the first listener in line because we probably received a
// notification that was meant for a starved task.
mutex.lock_ops.notify(1);
break;
}
}
// If waiting for too long, fall back to a fairer locking strategy that will prevent
// newer lock operations from starving us forever.
#[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
if start.elapsed() > Duration::from_micros(500) {
break;
}
}
}
}
// Increment the number of starved lock operations.
if mutex.state.fetch_add(2, Ordering::Release) > usize::MAX / 2 {
// In case of potential overflow, abort.
process::abort();
}
// Indicate that we are now starving and will use a fairer locking strategy.
this.starved = true;
}
// Fairer locking loop.
loop {
match &mut this.listener {
listener @ None => {
// Start listening for events.
*listener = Some(mutex.lock_ops.listen());
// Try locking if nobody else is being starved.
match mutex
.state
.compare_exchange(2, 2 | 1, Ordering::Acquire, Ordering::Acquire)
.unwrap_or_else(|x| x)
{
// Lock acquired!
2 => return Poll::Ready(this.take_mutex().unwrap()),
// Lock is held by someone.
s if s % 2 == 1 => {}
// Lock is available.
_ => {
// Be fair: notify the first listener and then go wait in line.
mutex.lock_ops.notify(1);
}
}
}
Some(ref mut listener) => {
// Wait for a notification.
ready!(Pin::new(listener).poll(cx));
this.listener = None;
// Try acquiring the lock without waiting for others.
if mutex.state.fetch_or(1, Ordering::Acquire) % 2 == 0 {
return Poll::Ready(this.take_mutex().unwrap());
}
}
}
}
}
}
impl<T: ?Sized, B: Borrow<Mutex<T>>> Drop for AcquireSlow<B, T> {
fn drop(&mut self) {
// Make sure the starvation counter is decremented.
self.take_mutex();
}
}
/// A guard that releases the mutex when dropped.
pub struct MutexGuard<'a, T: ?Sized>(&'a Mutex<T>);
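
A usage sketch for the reworked mutex (illustration only; `poll_once` and `block_on` come from `futures_lite`): the `Lock` future tries `try_lock` before falling back to `AcquireSlow`, so an uncontended lock resolves on its first poll, and `lock_arc()` now returns the named `LockArc` future.

use async_lock::{Mutex, MutexGuardArc};
use futures_lite::future;
use std::sync::Arc;

fn main() {
    future::block_on(async {
        let mutex = Arc::new(Mutex::new(String::from("hello")));

        // Uncontended case: `Lock` hits the `try_lock` fast path before ever
        // constructing an `AcquireSlow`, so one poll is enough.
        // `future::poll_once` comes from `futures_lite`.
        let guard = future::poll_once(mutex.lock())
            .await
            .expect("uncontended lock resolves on the first poll");
        assert_eq!(*guard, "hello");
        drop(guard);

        // `lock_arc()` now returns the named `LockArc` future; its guard is
        // reference-counted and not tied to a borrow of the mutex.
        let arc_guard: MutexGuardArc<String> = mutex.lock_arc().await;
        assert_eq!(*arc_guard, "hello");
    });
}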

src/rwlock.rs

@ -1,12 +1,17 @@
use std::cell::UnsafeCell;
use std::fmt;
use std::future::Future;
use std::mem;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::process;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
- use event_listener::Event;
use event_listener::{Event, EventListener};
use futures_lite::ready;
use crate::futures::Lock;
use crate::{Mutex, MutexGuard};
const WRITER_BIT: usize = 1;
@ -170,42 +175,11 @@ impl<T: ?Sized> RwLock<T> {
/// assert!(lock.try_read().is_some());
/// # })
/// ```
- pub async fn read(&self) -> RwLockReadGuard<'_, T> {
- let mut state = self.state.load(Ordering::Acquire);
- loop {
- if state & WRITER_BIT == 0 {
- // Make sure the number of readers doesn't overflow.
- if state > std::isize::MAX as usize {
- process::abort();
- }
- // If nobody is holding a write lock or attempting to acquire it, increment the
- // number of readers.
- match self.state.compare_exchange(
- state,
- state + ONE_READER,
- Ordering::AcqRel,
- Ordering::Acquire,
- ) {
- Ok(_) => return RwLockReadGuard(self),
- Err(s) => state = s,
- }
- } else {
- // Start listening for "no writer" events.
- let listener = self.no_writer.listen();
- // Check again if there's a writer.
- if self.state.load(Ordering::SeqCst) & WRITER_BIT != 0 {
- // Wait until the writer is dropped.
- listener.await;
- // Notify the next reader waiting in line.
- self.no_writer.notify(1);
- }
- // Reload the state.
- state = self.state.load(Ordering::Acquire);
- }
pub fn read(&self) -> Read<'_, T> {
Read {
lock: self,
state: self.state.load(Ordering::Acquire),
listener: None,
}
}
@ -289,33 +263,10 @@ impl<T: ?Sized> RwLock<T> {
/// *writer = 2;
/// # })
/// ```
- pub async fn upgradable_read(&self) -> RwLockUpgradableReadGuard<'_, T> {
- // First grab the mutex.
- let lock = self.mutex.lock().await;
- let mut state = self.state.load(Ordering::Acquire);
- // Make sure the number of readers doesn't overflow.
- if state > std::isize::MAX as usize {
- process::abort();
- }
- // Increment the number of readers.
- loop {
- match self.state.compare_exchange(
- state,
- state + ONE_READER,
- Ordering::AcqRel,
- Ordering::Acquire,
- ) {
- Ok(_) => {
- return RwLockUpgradableReadGuard {
- reader: RwLockReadGuard(self),
- reserved: lock,
- }
- }
- Err(s) => state = s,
- }
pub fn upgradable_read(&self) -> UpgradableRead<'_, T> {
UpgradableRead {
lock: self,
acquire: self.mutex.lock(),
}
}
@ -372,30 +323,11 @@ impl<T: ?Sized> RwLock<T> {
/// assert!(lock.try_read().is_none());
/// # })
/// ```
- pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
- // First grab the mutex.
- let lock = self.mutex.lock().await;
- // Set `WRITER_BIT` and create a guard that unsets it in case this future is canceled.
- self.state.fetch_or(WRITER_BIT, Ordering::SeqCst);
- let guard = RwLockWriteGuard {
- writer: RwLockWriteGuardInner(self),
- reserved: lock,
- };
- // If there are readers, we need to wait for them to finish.
- while self.state.load(Ordering::SeqCst) != WRITER_BIT {
- // Start listening for "no readers" events.
- let listener = self.no_readers.listen();
- // Check again if there are readers.
- if self.state.load(Ordering::Acquire) != WRITER_BIT {
- // Wait for the readers to finish.
- listener.await;
- }
pub fn write(&self) -> Write<'_, T> {
Write {
lock: self,
state: WriteState::Acquiring(self.mutex.lock()),
}
- guard
}
/// Returns a mutable reference to the inner value.
@ -448,6 +380,230 @@ impl<T: Default + ?Sized> Default for RwLock<T> {
}
}
/// The future returned by [`RwLock::read`].
pub struct Read<'a, T: ?Sized> {
/// The lock that is being acquired.
lock: &'a RwLock<T>,
/// The last-observed state of the lock.
state: usize,
/// The listener for the "no writers" event.
listener: Option<EventListener>,
}
impl<T: ?Sized> fmt::Debug for Read<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("Read { .. }")
}
}
impl<T: ?Sized> Unpin for Read<'_, T> {}
impl<'a, T: ?Sized> Future for Read<'a, T> {
type Output = RwLockReadGuard<'a, T>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
loop {
if this.state & WRITER_BIT == 0 {
// Make sure the number of readers doesn't overflow.
if this.state > std::isize::MAX as usize {
process::abort();
}
// If nobody is holding a write lock or attempting to acquire it, increment the
// number of readers.
match this.lock.state.compare_exchange(
this.state,
this.state + ONE_READER,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => return Poll::Ready(RwLockReadGuard(this.lock)),
Err(s) => this.state = s,
}
} else {
// Start listening for "no writer" events.
let load_ordering = match &mut this.listener {
listener @ None => {
*listener = Some(this.lock.no_writer.listen());
// Make sure there really is no writer.
Ordering::SeqCst
}
Some(ref mut listener) => {
// Wait for the writer to finish.
ready!(Pin::new(listener).poll(cx));
this.listener = None;
// Notify the next reader waiting in line.
this.lock.no_writer.notify(1);
// Check the state again.
Ordering::Acquire
}
};
// Reload the state.
this.state = this.lock.state.load(load_ordering);
}
}
}
}
/// The future returned by [`RwLock::upgradable_read`].
pub struct UpgradableRead<'a, T: ?Sized> {
/// The lock that is being acquired.
lock: &'a RwLock<T>,
/// The mutex we are trying to acquire.
acquire: Lock<'a, ()>,
}
impl<T: ?Sized> fmt::Debug for UpgradableRead<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("UpgradableRead { .. }")
}
}
impl<T: ?Sized> Unpin for UpgradableRead<'_, T> {}
impl<'a, T: ?Sized> Future for UpgradableRead<'a, T> {
type Output = RwLockUpgradableReadGuard<'a, T>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
// Acquire the mutex.
let mutex_guard = ready!(Pin::new(&mut this.acquire).poll(cx));
let mut state = this.lock.state.load(Ordering::Acquire);
// Make sure the number of readers doesn't overflow.
if state > std::isize::MAX as usize {
process::abort();
}
// Increment the number of readers.
loop {
match this.lock.state.compare_exchange(
state,
state + ONE_READER,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => {
return Poll::Ready(RwLockUpgradableReadGuard {
reader: RwLockReadGuard(this.lock),
reserved: mutex_guard,
});
}
Err(s) => state = s,
}
}
}
}
/// The future returned by [`RwLock::write`].
pub struct Write<'a, T: ?Sized> {
/// The lock that is being acquired.
lock: &'a RwLock<T>,
/// Current state of this future.
state: WriteState<'a, T>,
}
enum WriteState<'a, T: ?Sized> {
/// We are currently acquiring the inner mutex.
Acquiring(Lock<'a, ()>),
/// We are currently waiting for readers to finish.
WaitingReaders {
/// Our current write guard.
guard: Option<RwLockWriteGuard<'a, T>>,
/// The listener for the "no readers" event.
listener: Option<EventListener>,
},
}
impl<T: ?Sized> fmt::Debug for Write<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("Write { .. }")
}
}
impl<T: ?Sized> Unpin for Write<'_, T> {}
impl<'a, T: ?Sized> Future for Write<'a, T> {
type Output = RwLockWriteGuard<'a, T>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
loop {
match &mut this.state {
WriteState::Acquiring(lock) => {
// First grab the mutex.
let mutex_guard = ready!(Pin::new(lock).poll(cx));
// Set `WRITER_BIT` and create a guard that unsets it in case this future is canceled.
let new_state = this.lock.state.fetch_or(WRITER_BIT, Ordering::SeqCst);
let guard = RwLockWriteGuard {
writer: RwLockWriteGuardInner(this.lock),
reserved: mutex_guard,
};
// If we just acquired the writer lock, return it.
if new_state == WRITER_BIT {
return Poll::Ready(guard);
}
// Start waiting for the readers to finish.
this.state = WriteState::WaitingReaders {
guard: Some(guard),
listener: Some(this.lock.no_readers.listen()),
};
}
WriteState::WaitingReaders {
guard,
ref mut listener,
} => {
let load_ordering = if listener.is_some() {
Ordering::Acquire
} else {
Ordering::SeqCst
};
// Check the state again.
if this.lock.state.load(load_ordering) == WRITER_BIT {
// We are the only ones holding the lock, return it.
return Poll::Ready(guard.take().unwrap());
}
// Wait for the readers to finish.
match listener {
None => {
// Register a listener.
*listener = Some(this.lock.no_readers.listen());
}
Some(ref mut evl) => {
// Wait for the readers to finish.
ready!(Pin::new(evl).poll(cx));
*listener = None;
}
};
}
}
}
}
}
/// A guard that releases the read lock when dropped.
pub struct RwLockReadGuard<'a, T: ?Sized>(&'a RwLock<T>);
@ -585,7 +741,7 @@ impl<'a, T: ?Sized> RwLockUpgradableReadGuard<'a, T> {
/// *writer = 2;
/// # })
/// ```
- pub async fn upgrade(guard: Self) -> RwLockWriteGuard<'a, T> {
pub fn upgrade(guard: Self) -> Upgrade<'a, T> {
// Set `WRITER_BIT` and decrement the number of readers at the same time.
guard
.reader
@ -596,19 +752,10 @@ impl<'a, T: ?Sized> RwLockUpgradableReadGuard<'a, T> {
// Convert into a write guard that unsets `WRITER_BIT` in case this future is canceled.
let guard = guard.into_writer();
- // If there are readers, we need to wait for them to finish.
- while guard.writer.0.state.load(Ordering::SeqCst) != WRITER_BIT {
- // Start listening for "no readers" events.
- let listener = guard.writer.0.no_readers.listen();
- // Check again if there are readers.
- if guard.writer.0.state.load(Ordering::Acquire) != WRITER_BIT {
- // Wait for the readers to finish.
- listener.await;
- }
Upgrade {
guard: Some(guard),
listener: None,
}
- guard
}
}
@ -632,6 +779,67 @@ impl<T: ?Sized> Deref for RwLockUpgradableReadGuard<'_, T> {
}
}
/// The future returned by [`RwLockUpgradableReadGuard::upgrade`].
pub struct Upgrade<'a, T: ?Sized> {
/// The guard that we are upgrading to.
guard: Option<RwLockWriteGuard<'a, T>>,
/// The event listener we are waiting on.
listener: Option<EventListener>,
}
impl<T: ?Sized> fmt::Debug for Upgrade<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Upgrade").finish()
}
}
impl<T: ?Sized> Unpin for Upgrade<'_, T> {}
impl<'a, T: ?Sized> Future for Upgrade<'a, T> {
type Output = RwLockWriteGuard<'a, T>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
let guard = this
.guard
.as_mut()
.expect("cannot poll future after completion");
// If there are readers, we need to wait for them to finish.
loop {
let load_ordering = if this.listener.is_some() {
Ordering::Acquire
} else {
Ordering::SeqCst
};
// See if the number of readers is zero.
let state = guard.writer.0.state.load(load_ordering);
if state == WRITER_BIT {
break;
}
// If there are readers, wait for them to finish.
match &mut this.listener {
listener @ None => {
// Start listening for "no readers" events.
*listener = Some(guard.writer.0.no_readers.listen());
}
Some(ref mut listener) => {
// Wait for the readers to finish.
ready!(Pin::new(listener).poll(cx));
this.listener = None;
}
}
}
// We are done.
Poll::Ready(this.guard.take().unwrap())
}
}
struct RwLockWriteGuardInner<'a, T: ?Sized>(&'a RwLock<T>);
impl<T: ?Sized> Drop for RwLockWriteGuardInner<'_, T> {
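
A usage sketch for the reworked `RwLock` API (illustration only, using `futures_lite::future::block_on`), mirroring the doc examples above:

use async_lock::{RwLock, RwLockUpgradableReadGuard};

fn main() {
    futures_lite::future::block_on(async {
        let lock = RwLock::new(1);

        // `read()`, `upgradable_read()`, `write()` and `upgrade()` now return
        // the named futures `Read`, `UpgradableRead`, `Write` and `Upgrade`.
        let reader = lock.upgradable_read().await;
        assert_eq!(*reader, 1);

        // `upgrade` is still an associated function taking the guard by value;
        // only its return type changed from an anonymous future to `Upgrade`.
        let mut writer = RwLockUpgradableReadGuard::upgrade(reader).await;
        *writer = 2;
        drop(writer);

        assert_eq!(*lock.read().await, 2);
    });
}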

src/semaphore.rs

@ -1,8 +1,12 @@
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
- use event_listener::Event;
use event_listener::{Event, EventListener};
use futures_lite::ready;
/// A counter for limiting the number of concurrent operations.
#[derive(Debug)]
@ -80,18 +84,10 @@ impl Semaphore {
/// let guard = s.acquire().await;
/// # });
/// ```
- pub async fn acquire(&self) -> SemaphoreGuard<'_> {
- let mut listener = None;
- loop {
- if let Some(guard) = self.try_acquire() {
- return guard;
- }
- match listener.take() {
- None => listener = Some(self.event.listen()),
- Some(l) => l.await,
- }
pub fn acquire(&self) -> Acquire<'_> {
Acquire {
semaphore: self,
listener: None,
}
}
}
@ -136,21 +132,6 @@ impl Semaphore {
}
}
- async fn acquire_arc_impl(self: Arc<Self>) -> SemaphoreGuardArc {
- let mut listener = None;
- loop {
- if let Some(guard) = self.try_acquire_arc() {
- return guard;
- }
- match listener.take() {
- None => listener = Some(self.event.listen()),
- Some(l) => l.await,
- }
- }
- }
/// Waits for an owned permit for a concurrent operation.
///
/// Returns a guard that releases the permit when dropped.
@ -166,8 +147,100 @@ impl Semaphore {
/// let guard = s.acquire_arc().await;
/// # });
/// ```
- pub fn acquire_arc(self: &Arc<Self>) -> impl Future<Output = SemaphoreGuardArc> {
- self.clone().acquire_arc_impl()
pub fn acquire_arc(self: &Arc<Self>) -> AcquireArc {
AcquireArc {
semaphore: self.clone(),
listener: None,
}
}
}
/// The future returned by [`Semaphore::acquire`].
pub struct Acquire<'a> {
/// The semaphore being acquired.
semaphore: &'a Semaphore,
/// The listener waiting on the semaphore.
listener: Option<EventListener>,
}
impl fmt::Debug for Acquire<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("Acquire { .. }")
}
}
impl Unpin for Acquire<'_> {}
impl<'a> Future for Acquire<'a> {
type Output = SemaphoreGuard<'a>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
loop {
match this.semaphore.try_acquire() {
Some(guard) => return Poll::Ready(guard),
None => {
// Wait on the listener.
match &mut this.listener {
listener @ None => {
*listener = Some(this.semaphore.event.listen());
}
Some(ref mut listener) => {
ready!(Pin::new(listener).poll(cx));
this.listener = None;
}
}
}
}
}
}
}
/// The future returned by [`Semaphore::acquire_arc`].
pub struct AcquireArc {
/// The semaphore being acquired.
semaphore: Arc<Semaphore>,
/// The listener waiting on the semaphore.
listener: Option<EventListener>,
}
impl fmt::Debug for AcquireArc {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("AcquireArc { .. }")
}
}
impl Unpin for AcquireArc {}
impl Future for AcquireArc {
type Output = SemaphoreGuardArc;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
loop {
match this.semaphore.try_acquire_arc() {
Some(guard) => {
this.listener = None;
return Poll::Ready(guard);
}
None => {
// Wait on the listener.
match &mut this.listener {
listener @ None => {
*listener = Some(this.semaphore.event.listen());
}
Some(ref mut listener) => {
ready!(Pin::new(listener).poll(cx));
this.listener = None;
}
}
}
}
}
}
}
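
A usage sketch for the reworked semaphore API (illustration only, using `futures_lite::future::block_on`):

use async_lock::{Semaphore, SemaphoreGuard, SemaphoreGuardArc};
use std::sync::Arc;

fn main() {
    futures_lite::future::block_on(async {
        let s = Arc::new(Semaphore::new(2));

        // `acquire()` returns the named `Acquire` future and `acquire_arc()`
        // the named `AcquireArc`; both resolve to guards that release their
        // permit when dropped.
        let g1: SemaphoreGuard<'_> = s.acquire().await;
        let g2: SemaphoreGuardArc = s.acquire_arc().await;

        // Both permits are taken, so a third attempt fails immediately.
        assert!(s.try_acquire().is_none());

        drop(g1);
        assert!(s.try_acquire().is_some());
        drop(g2);
    });
}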