mirror of https://github.com/smol-rs/async-lock
Replace some async blocks with manual futures (#34)
* Replace some async blocks with manual futures * Fix WASM build * Use a slightly more idiomatic way to access event listeners * Fix MIRI failure * Code review
This commit is contained in:
parent
61922eb271
commit
15049aa1c9
120
src/barrier.rs
120
src/barrier.rs
|
@ -1,5 +1,12 @@
|
|||
use event_listener::Event;
|
||||
use event_listener::{Event, EventListener};
|
||||
use futures_lite::ready;
|
||||
|
||||
use std::fmt;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use crate::futures::Lock;
|
||||
use crate::Mutex;
|
||||
|
||||
/// A counter to synchronize multiple tasks at the same time.
|
||||
|
@ -72,24 +79,103 @@ impl Barrier {
|
|||
/// });
|
||||
/// }
|
||||
/// ```
|
||||
pub async fn wait(&self) -> BarrierWaitResult {
|
||||
let mut state = self.state.lock().await;
|
||||
let local_gen = state.generation_id;
|
||||
state.count += 1;
|
||||
pub fn wait(&self) -> BarrierWait<'_> {
|
||||
BarrierWait {
|
||||
barrier: self,
|
||||
lock: Some(self.state.lock()),
|
||||
state: WaitState::Initial,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if state.count < self.n {
|
||||
while local_gen == state.generation_id && state.count < self.n {
|
||||
let listener = self.event.listen();
|
||||
drop(state);
|
||||
listener.await;
|
||||
state = self.state.lock().await;
|
||||
/// The future returned by [`Barrier::wait()`].
|
||||
pub struct BarrierWait<'a> {
|
||||
/// The barrier to wait on.
|
||||
barrier: &'a Barrier,
|
||||
|
||||
/// The ongoing mutex lock operation we are blocking on.
|
||||
lock: Option<Lock<'a, State>>,
|
||||
|
||||
/// The current state of the future.
|
||||
state: WaitState,
|
||||
}
|
||||
|
||||
impl fmt::Debug for BarrierWait<'_> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("BarrierWait { .. }")
|
||||
}
|
||||
}
|
||||
|
||||
enum WaitState {
|
||||
/// We are getting the original values of the state.
|
||||
Initial,
|
||||
|
||||
/// We are waiting for the listener to complete.
|
||||
Waiting { evl: EventListener, local_gen: u64 },
|
||||
|
||||
/// Waiting to re-acquire the lock to check the state again.
|
||||
Reacquiring(u64),
|
||||
}
|
||||
|
||||
impl Future for BarrierWait<'_> {
|
||||
type Output = BarrierWaitResult;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
loop {
|
||||
match this.state {
|
||||
WaitState::Initial => {
|
||||
// See if the lock is ready yet.
|
||||
let mut state = ready!(Pin::new(this.lock.as_mut().unwrap()).poll(cx));
|
||||
this.lock = None;
|
||||
|
||||
let local_gen = state.generation_id;
|
||||
state.count += 1;
|
||||
|
||||
if state.count < this.barrier.n {
|
||||
// We need to wait for the event.
|
||||
this.state = WaitState::Waiting {
|
||||
evl: this.barrier.event.listen(),
|
||||
local_gen,
|
||||
};
|
||||
} else {
|
||||
// We are the last one.
|
||||
state.count = 0;
|
||||
state.generation_id = state.generation_id.wrapping_add(1);
|
||||
this.barrier.event.notify(std::usize::MAX);
|
||||
return Poll::Ready(BarrierWaitResult { is_leader: true });
|
||||
}
|
||||
}
|
||||
|
||||
WaitState::Waiting {
|
||||
ref mut evl,
|
||||
local_gen,
|
||||
} => {
|
||||
ready!(Pin::new(evl).poll(cx));
|
||||
|
||||
// We are now re-acquiring the mutex.
|
||||
this.lock = Some(this.barrier.state.lock());
|
||||
this.state = WaitState::Reacquiring(local_gen);
|
||||
}
|
||||
|
||||
WaitState::Reacquiring(local_gen) => {
|
||||
// Acquire the local state again.
|
||||
let state = ready!(Pin::new(this.lock.as_mut().unwrap()).poll(cx));
|
||||
this.lock = None;
|
||||
|
||||
if local_gen == state.generation_id && state.count < this.barrier.n {
|
||||
// We need to wait for the event again.
|
||||
this.state = WaitState::Waiting {
|
||||
evl: this.barrier.event.listen(),
|
||||
local_gen,
|
||||
};
|
||||
} else {
|
||||
// We are ready, but not the leader.
|
||||
return Poll::Ready(BarrierWaitResult { is_leader: false });
|
||||
}
|
||||
}
|
||||
}
|
||||
BarrierWaitResult { is_leader: false }
|
||||
} else {
|
||||
state.count = 0;
|
||||
state.generation_id = state.generation_id.wrapping_add(1);
|
||||
self.event.notify(std::usize::MAX);
|
||||
BarrierWaitResult { is_leader: true }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,3 +20,12 @@ pub use mutex::{Mutex, MutexGuard, MutexGuardArc};
|
|||
pub use once_cell::OnceCell;
|
||||
pub use rwlock::{RwLock, RwLockReadGuard, RwLockUpgradableReadGuard, RwLockWriteGuard};
|
||||
pub use semaphore::{Semaphore, SemaphoreGuard, SemaphoreGuardArc};
|
||||
|
||||
pub mod futures {
|
||||
//! Named futures for use with `async_lock` primitives.
|
||||
|
||||
pub use crate::barrier::BarrierWait;
|
||||
pub use crate::mutex::{Lock, LockArc};
|
||||
pub use crate::rwlock::{Read, UpgradableRead, Upgrade, Write};
|
||||
pub use crate::semaphore::{Acquire, AcquireArc};
|
||||
}
|
||||
|
|
421
src/mutex.rs
421
src/mutex.rs
|
@ -1,10 +1,15 @@
|
|||
use std::borrow::Borrow;
|
||||
use std::cell::UnsafeCell;
|
||||
use std::fmt;
|
||||
use std::future::Future;
|
||||
use std::marker::PhantomData;
|
||||
use std::mem;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::pin::Pin;
|
||||
use std::process;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
// Note: we cannot use `target_family = "wasm"` here because it requires Rust 1.54.
|
||||
#[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
|
||||
|
@ -12,7 +17,8 @@ use std::time::{Duration, Instant};
|
|||
|
||||
use std::usize;
|
||||
|
||||
use event_listener::Event;
|
||||
use event_listener::{Event, EventListener};
|
||||
use futures_lite::ready;
|
||||
|
||||
/// An async mutex.
|
||||
///
|
||||
|
@ -103,114 +109,10 @@ impl<T: ?Sized> Mutex<T> {
|
|||
/// # })
|
||||
/// ```
|
||||
#[inline]
|
||||
pub async fn lock(&self) -> MutexGuard<'_, T> {
|
||||
if let Some(guard) = self.try_lock() {
|
||||
return guard;
|
||||
}
|
||||
self.acquire_slow().await;
|
||||
MutexGuard(self)
|
||||
}
|
||||
|
||||
/// Slow path for acquiring the mutex.
|
||||
#[cold]
|
||||
async fn acquire_slow(&self) {
|
||||
// Get the current time.
|
||||
#[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
|
||||
let start = Instant::now();
|
||||
|
||||
loop {
|
||||
// Start listening for events.
|
||||
let listener = self.lock_ops.listen();
|
||||
|
||||
// Try locking if nobody is being starved.
|
||||
match self
|
||||
.state
|
||||
.compare_exchange(0, 1, Ordering::Acquire, Ordering::Acquire)
|
||||
.unwrap_or_else(|x| x)
|
||||
{
|
||||
// Lock acquired!
|
||||
0 => return,
|
||||
|
||||
// Lock is held and nobody is starved.
|
||||
1 => {}
|
||||
|
||||
// Somebody is starved.
|
||||
_ => break,
|
||||
}
|
||||
|
||||
// Wait for a notification.
|
||||
listener.await;
|
||||
|
||||
// Try locking if nobody is being starved.
|
||||
match self
|
||||
.state
|
||||
.compare_exchange(0, 1, Ordering::Acquire, Ordering::Acquire)
|
||||
.unwrap_or_else(|x| x)
|
||||
{
|
||||
// Lock acquired!
|
||||
0 => return,
|
||||
|
||||
// Lock is held and nobody is starved.
|
||||
1 => {}
|
||||
|
||||
// Somebody is starved.
|
||||
_ => {
|
||||
// Notify the first listener in line because we probably received a
|
||||
// notification that was meant for a starved task.
|
||||
self.lock_ops.notify(1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If waiting for too long, fall back to a fairer locking strategy that will prevent
|
||||
// newer lock operations from starving us forever.
|
||||
#[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
|
||||
if start.elapsed() > Duration::from_micros(500) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Increment the number of starved lock operations.
|
||||
if self.state.fetch_add(2, Ordering::Release) > usize::MAX / 2 {
|
||||
// In case of potential overflow, abort.
|
||||
process::abort();
|
||||
}
|
||||
|
||||
// Decrement the counter when exiting this function.
|
||||
let _call = CallOnDrop(|| {
|
||||
self.state.fetch_sub(2, Ordering::Release);
|
||||
});
|
||||
|
||||
loop {
|
||||
// Start listening for events.
|
||||
let listener = self.lock_ops.listen();
|
||||
|
||||
// Try locking if nobody else is being starved.
|
||||
match self
|
||||
.state
|
||||
.compare_exchange(2, 2 | 1, Ordering::Acquire, Ordering::Acquire)
|
||||
.unwrap_or_else(|x| x)
|
||||
{
|
||||
// Lock acquired!
|
||||
2 => return,
|
||||
|
||||
// Lock is held by someone.
|
||||
s if s % 2 == 1 => {}
|
||||
|
||||
// Lock is available.
|
||||
_ => {
|
||||
// Be fair: notify the first listener and then go wait in line.
|
||||
self.lock_ops.notify(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for a notification.
|
||||
listener.await;
|
||||
|
||||
// Try acquiring the lock without waiting for others.
|
||||
if self.state.fetch_or(1, Ordering::Acquire) % 2 == 0 {
|
||||
return;
|
||||
}
|
||||
pub fn lock(&self) -> Lock<'_, T> {
|
||||
Lock {
|
||||
mutex: self,
|
||||
acquire_slow: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -265,14 +167,6 @@ impl<T: ?Sized> Mutex<T> {
|
|||
}
|
||||
|
||||
impl<T: ?Sized> Mutex<T> {
|
||||
async fn lock_arc_impl(self: Arc<Self>) -> MutexGuardArc<T> {
|
||||
if let Some(guard) = self.try_lock_arc() {
|
||||
return guard;
|
||||
}
|
||||
self.acquire_slow().await;
|
||||
MutexGuardArc(self)
|
||||
}
|
||||
|
||||
/// Acquires the mutex and clones a reference to it.
|
||||
///
|
||||
/// Returns an owned guard that releases the mutex when dropped.
|
||||
|
@ -290,8 +184,8 @@ impl<T: ?Sized> Mutex<T> {
|
|||
/// # })
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn lock_arc(self: &Arc<Self>) -> impl Future<Output = MutexGuardArc<T>> {
|
||||
self.clone().lock_arc_impl()
|
||||
pub fn lock_arc(self: &Arc<Self>) -> LockArc<T> {
|
||||
LockArc(LockArcInnards::Unpolled(self.clone()))
|
||||
}
|
||||
|
||||
/// Attempts to acquire the mutex and clone a reference to it.
|
||||
|
@ -353,6 +247,295 @@ impl<T: Default + ?Sized> Default for Mutex<T> {
|
|||
}
|
||||
}
|
||||
|
||||
/// The future returned by [`Mutex::lock`].
|
||||
pub struct Lock<'a, T: ?Sized> {
|
||||
/// Reference to the mutex.
|
||||
mutex: &'a Mutex<T>,
|
||||
|
||||
/// The future that waits for the mutex to become available.
|
||||
acquire_slow: Option<AcquireSlow<&'a Mutex<T>, T>>,
|
||||
}
|
||||
|
||||
impl<'a, T: ?Sized> Unpin for Lock<'a, T> {}
|
||||
|
||||
impl<T: ?Sized> fmt::Debug for Lock<'_, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("Lock { .. }")
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: ?Sized> Future for Lock<'a, T> {
|
||||
type Output = MutexGuard<'a, T>;
|
||||
|
||||
#[inline]
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
loop {
|
||||
match this.acquire_slow.as_mut() {
|
||||
None => {
|
||||
// Try the fast path before trying to register slowly.
|
||||
match this.mutex.try_lock() {
|
||||
Some(guard) => return Poll::Ready(guard),
|
||||
None => {
|
||||
this.acquire_slow = Some(AcquireSlow::new(this.mutex));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(acquire_slow) => {
|
||||
// Continue registering slowly.
|
||||
let value = ready!(Pin::new(acquire_slow).poll(cx));
|
||||
return Poll::Ready(MutexGuard(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The future returned by [`Mutex::lock_arc`].
|
||||
pub struct LockArc<T: ?Sized>(LockArcInnards<T>);
|
||||
|
||||
enum LockArcInnards<T: ?Sized> {
|
||||
/// We have not tried to poll the fast path yet.
|
||||
Unpolled(Arc<Mutex<T>>),
|
||||
|
||||
/// We are acquiring the mutex through the slow path.
|
||||
AcquireSlow(AcquireSlow<Arc<Mutex<T>>, T>),
|
||||
|
||||
/// Empty hole to make taking easier.
|
||||
Empty,
|
||||
}
|
||||
|
||||
impl<T: ?Sized> Unpin for LockArc<T> {}
|
||||
|
||||
impl<T: ?Sized> fmt::Debug for LockArc<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("LockArc { .. }")
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ?Sized> Future for LockArc<T> {
|
||||
type Output = MutexGuardArc<T>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
loop {
|
||||
match mem::replace(&mut this.0, LockArcInnards::Empty) {
|
||||
LockArcInnards::Unpolled(mutex) => {
|
||||
// Try the fast path before trying to register slowly.
|
||||
match mutex.try_lock_arc() {
|
||||
Some(guard) => return Poll::Ready(guard),
|
||||
None => {
|
||||
*this = LockArc(LockArcInnards::AcquireSlow(AcquireSlow::new(
|
||||
mutex.clone(),
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LockArcInnards::AcquireSlow(mut acquire_slow) => {
|
||||
// Continue registering slowly.
|
||||
let value = match Pin::new(&mut acquire_slow).poll(cx) {
|
||||
Poll::Pending => {
|
||||
*this = LockArc(LockArcInnards::AcquireSlow(acquire_slow));
|
||||
return Poll::Pending;
|
||||
}
|
||||
Poll::Ready(value) => value,
|
||||
};
|
||||
return Poll::Ready(MutexGuardArc(value));
|
||||
}
|
||||
|
||||
LockArcInnards::Empty => panic!("future polled after completion"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Future for acquiring the mutex slowly.
|
||||
struct AcquireSlow<B: Borrow<Mutex<T>>, T: ?Sized> {
|
||||
/// Reference to the mutex.
|
||||
mutex: Option<B>,
|
||||
|
||||
/// The event listener waiting on the mutex.
|
||||
listener: Option<EventListener>,
|
||||
|
||||
/// The point at which the mutex lock was started.
|
||||
#[cfg(not(any(target_arch = "wasm32", target_os = "wasm64")))]
|
||||
start: Option<Instant>,
|
||||
|
||||
/// This lock operation is starving.
|
||||
starved: bool,
|
||||
|
||||
/// Capture the `T` lifetime.
|
||||
_marker: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<B: Borrow<Mutex<T>> + Unpin, T: ?Sized> Unpin for AcquireSlow<B, T> {}
|
||||
|
||||
impl<T: ?Sized, B: Borrow<Mutex<T>>> AcquireSlow<B, T> {
|
||||
/// Create a new `AcquireSlow` future.
|
||||
#[cold]
|
||||
fn new(mutex: B) -> Self {
|
||||
AcquireSlow {
|
||||
mutex: Some(mutex),
|
||||
listener: None,
|
||||
#[cfg(not(any(target_arch = "wasm32", target_os = "wasm64")))]
|
||||
start: None,
|
||||
starved: false,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Take the mutex reference out, decrementing the counter if necessary.
|
||||
fn take_mutex(&mut self) -> Option<B> {
|
||||
let mutex = self.mutex.take();
|
||||
|
||||
if self.starved {
|
||||
if let Some(mutex) = mutex.as_ref() {
|
||||
// Decrement this counter before we exit.
|
||||
mutex.borrow().state.fetch_sub(2, Ordering::Release);
|
||||
}
|
||||
}
|
||||
|
||||
mutex
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ?Sized, B: Unpin + Borrow<Mutex<T>>> Future for AcquireSlow<B, T> {
|
||||
type Output = B;
|
||||
|
||||
#[cold]
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = &mut *self;
|
||||
#[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
|
||||
let start = *this.start.get_or_insert_with(Instant::now);
|
||||
let mutex = this
|
||||
.mutex
|
||||
.as_ref()
|
||||
.expect("future polled after completion")
|
||||
.borrow();
|
||||
|
||||
// Only use this hot loop if we aren't currently starved.
|
||||
if !this.starved {
|
||||
loop {
|
||||
// Start listening for events.
|
||||
match &mut this.listener {
|
||||
listener @ None => {
|
||||
// Start listening for events.
|
||||
*listener = Some(mutex.lock_ops.listen());
|
||||
|
||||
// Try locking if nobody is being starved.
|
||||
match mutex
|
||||
.state
|
||||
.compare_exchange(0, 1, Ordering::Acquire, Ordering::Acquire)
|
||||
.unwrap_or_else(|x| x)
|
||||
{
|
||||
// Lock acquired!
|
||||
0 => return Poll::Ready(this.take_mutex().unwrap()),
|
||||
|
||||
// Lock is held and nobody is starved.
|
||||
1 => {}
|
||||
|
||||
// Somebody is starved.
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
Some(ref mut listener) => {
|
||||
// Wait for a notification.
|
||||
ready!(Pin::new(listener).poll(cx));
|
||||
this.listener = None;
|
||||
|
||||
// Try locking if nobody is being starved.
|
||||
match mutex
|
||||
.state
|
||||
.compare_exchange(0, 1, Ordering::Acquire, Ordering::Acquire)
|
||||
.unwrap_or_else(|x| x)
|
||||
{
|
||||
// Lock acquired!
|
||||
0 => return Poll::Ready(this.take_mutex().unwrap()),
|
||||
|
||||
// Lock is held and nobody is starved.
|
||||
1 => {}
|
||||
|
||||
// Somebody is starved.
|
||||
_ => {
|
||||
// Notify the first listener in line because we probably received a
|
||||
// notification that was meant for a starved task.
|
||||
mutex.lock_ops.notify(1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If waiting for too long, fall back to a fairer locking strategy that will prevent
|
||||
// newer lock operations from starving us forever.
|
||||
#[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
|
||||
if start.elapsed() > Duration::from_micros(500) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Increment the number of starved lock operations.
|
||||
if mutex.state.fetch_add(2, Ordering::Release) > usize::MAX / 2 {
|
||||
// In case of potential overflow, abort.
|
||||
process::abort();
|
||||
}
|
||||
|
||||
// Indicate that we are now starving and will use a fairer locking strategy.
|
||||
this.starved = true;
|
||||
}
|
||||
|
||||
// Fairer locking loop.
|
||||
loop {
|
||||
match &mut this.listener {
|
||||
listener @ None => {
|
||||
// Start listening for events.
|
||||
*listener = Some(mutex.lock_ops.listen());
|
||||
|
||||
// Try locking if nobody else is being starved.
|
||||
match mutex
|
||||
.state
|
||||
.compare_exchange(2, 2 | 1, Ordering::Acquire, Ordering::Acquire)
|
||||
.unwrap_or_else(|x| x)
|
||||
{
|
||||
// Lock acquired!
|
||||
2 => return Poll::Ready(this.take_mutex().unwrap()),
|
||||
|
||||
// Lock is held by someone.
|
||||
s if s % 2 == 1 => {}
|
||||
|
||||
// Lock is available.
|
||||
_ => {
|
||||
// Be fair: notify the first listener and then go wait in line.
|
||||
mutex.lock_ops.notify(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(ref mut listener) => {
|
||||
// Wait for a notification.
|
||||
ready!(Pin::new(listener).poll(cx));
|
||||
this.listener = None;
|
||||
|
||||
// Try acquiring the lock without waiting for others.
|
||||
if mutex.state.fetch_or(1, Ordering::Acquire) % 2 == 0 {
|
||||
return Poll::Ready(this.take_mutex().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ?Sized, B: Borrow<Mutex<T>>> Drop for AcquireSlow<B, T> {
|
||||
fn drop(&mut self) {
|
||||
// Make sure the starvation counter is decremented.
|
||||
self.take_mutex();
|
||||
}
|
||||
}
|
||||
|
||||
/// A guard that releases the mutex when dropped.
|
||||
pub struct MutexGuard<'a, T: ?Sized>(&'a Mutex<T>);
|
||||
|
||||
|
|
408
src/rwlock.rs
408
src/rwlock.rs
|
@ -1,12 +1,17 @@
|
|||
use std::cell::UnsafeCell;
|
||||
use std::fmt;
|
||||
use std::future::Future;
|
||||
use std::mem;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::pin::Pin;
|
||||
use std::process;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use event_listener::Event;
|
||||
use event_listener::{Event, EventListener};
|
||||
use futures_lite::ready;
|
||||
|
||||
use crate::futures::Lock;
|
||||
use crate::{Mutex, MutexGuard};
|
||||
|
||||
const WRITER_BIT: usize = 1;
|
||||
|
@ -170,42 +175,11 @@ impl<T: ?Sized> RwLock<T> {
|
|||
/// assert!(lock.try_read().is_some());
|
||||
/// # })
|
||||
/// ```
|
||||
pub async fn read(&self) -> RwLockReadGuard<'_, T> {
|
||||
let mut state = self.state.load(Ordering::Acquire);
|
||||
|
||||
loop {
|
||||
if state & WRITER_BIT == 0 {
|
||||
// Make sure the number of readers doesn't overflow.
|
||||
if state > std::isize::MAX as usize {
|
||||
process::abort();
|
||||
}
|
||||
|
||||
// If nobody is holding a write lock or attempting to acquire it, increment the
|
||||
// number of readers.
|
||||
match self.state.compare_exchange(
|
||||
state,
|
||||
state + ONE_READER,
|
||||
Ordering::AcqRel,
|
||||
Ordering::Acquire,
|
||||
) {
|
||||
Ok(_) => return RwLockReadGuard(self),
|
||||
Err(s) => state = s,
|
||||
}
|
||||
} else {
|
||||
// Start listening for "no writer" events.
|
||||
let listener = self.no_writer.listen();
|
||||
|
||||
// Check again if there's a writer.
|
||||
if self.state.load(Ordering::SeqCst) & WRITER_BIT != 0 {
|
||||
// Wait until the writer is dropped.
|
||||
listener.await;
|
||||
// Notify the next reader waiting in line.
|
||||
self.no_writer.notify(1);
|
||||
}
|
||||
|
||||
// Reload the state.
|
||||
state = self.state.load(Ordering::Acquire);
|
||||
}
|
||||
pub fn read(&self) -> Read<'_, T> {
|
||||
Read {
|
||||
lock: self,
|
||||
state: self.state.load(Ordering::Acquire),
|
||||
listener: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -289,33 +263,10 @@ impl<T: ?Sized> RwLock<T> {
|
|||
/// *writer = 2;
|
||||
/// # })
|
||||
/// ```
|
||||
pub async fn upgradable_read(&self) -> RwLockUpgradableReadGuard<'_, T> {
|
||||
// First grab the mutex.
|
||||
let lock = self.mutex.lock().await;
|
||||
|
||||
let mut state = self.state.load(Ordering::Acquire);
|
||||
|
||||
// Make sure the number of readers doesn't overflow.
|
||||
if state > std::isize::MAX as usize {
|
||||
process::abort();
|
||||
}
|
||||
|
||||
// Increment the number of readers.
|
||||
loop {
|
||||
match self.state.compare_exchange(
|
||||
state,
|
||||
state + ONE_READER,
|
||||
Ordering::AcqRel,
|
||||
Ordering::Acquire,
|
||||
) {
|
||||
Ok(_) => {
|
||||
return RwLockUpgradableReadGuard {
|
||||
reader: RwLockReadGuard(self),
|
||||
reserved: lock,
|
||||
}
|
||||
}
|
||||
Err(s) => state = s,
|
||||
}
|
||||
pub fn upgradable_read(&self) -> UpgradableRead<'_, T> {
|
||||
UpgradableRead {
|
||||
lock: self,
|
||||
acquire: self.mutex.lock(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -372,30 +323,11 @@ impl<T: ?Sized> RwLock<T> {
|
|||
/// assert!(lock.try_read().is_none());
|
||||
/// # })
|
||||
/// ```
|
||||
pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
|
||||
// First grab the mutex.
|
||||
let lock = self.mutex.lock().await;
|
||||
|
||||
// Set `WRITER_BIT` and create a guard that unsets it in case this future is canceled.
|
||||
self.state.fetch_or(WRITER_BIT, Ordering::SeqCst);
|
||||
let guard = RwLockWriteGuard {
|
||||
writer: RwLockWriteGuardInner(self),
|
||||
reserved: lock,
|
||||
};
|
||||
|
||||
// If there are readers, we need to wait for them to finish.
|
||||
while self.state.load(Ordering::SeqCst) != WRITER_BIT {
|
||||
// Start listening for "no readers" events.
|
||||
let listener = self.no_readers.listen();
|
||||
|
||||
// Check again if there are readers.
|
||||
if self.state.load(Ordering::Acquire) != WRITER_BIT {
|
||||
// Wait for the readers to finish.
|
||||
listener.await;
|
||||
}
|
||||
pub fn write(&self) -> Write<'_, T> {
|
||||
Write {
|
||||
lock: self,
|
||||
state: WriteState::Acquiring(self.mutex.lock()),
|
||||
}
|
||||
|
||||
guard
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the inner value.
|
||||
|
@ -448,6 +380,230 @@ impl<T: Default + ?Sized> Default for RwLock<T> {
|
|||
}
|
||||
}
|
||||
|
||||
/// The future returned by [`RwLock::read`].
|
||||
pub struct Read<'a, T: ?Sized> {
|
||||
/// The lock that is being acquired.
|
||||
lock: &'a RwLock<T>,
|
||||
|
||||
/// The last-observed state of the lock.
|
||||
state: usize,
|
||||
|
||||
/// The listener for the "no writers" event.
|
||||
listener: Option<EventListener>,
|
||||
}
|
||||
|
||||
impl<T: ?Sized> fmt::Debug for Read<'_, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("Read { .. }")
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ?Sized> Unpin for Read<'_, T> {}
|
||||
|
||||
impl<'a, T: ?Sized> Future for Read<'a, T> {
|
||||
type Output = RwLockReadGuard<'a, T>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
loop {
|
||||
if this.state & WRITER_BIT == 0 {
|
||||
// Make sure the number of readers doesn't overflow.
|
||||
if this.state > std::isize::MAX as usize {
|
||||
process::abort();
|
||||
}
|
||||
|
||||
// If nobody is holding a write lock or attempting to acquire it, increment the
|
||||
// number of readers.
|
||||
match this.lock.state.compare_exchange(
|
||||
this.state,
|
||||
this.state + ONE_READER,
|
||||
Ordering::AcqRel,
|
||||
Ordering::Acquire,
|
||||
) {
|
||||
Ok(_) => return Poll::Ready(RwLockReadGuard(this.lock)),
|
||||
Err(s) => this.state = s,
|
||||
}
|
||||
} else {
|
||||
// Start listening for "no writer" events.
|
||||
let load_ordering = match &mut this.listener {
|
||||
listener @ None => {
|
||||
*listener = Some(this.lock.no_writer.listen());
|
||||
|
||||
// Make sure there really is no writer.
|
||||
Ordering::SeqCst
|
||||
}
|
||||
|
||||
Some(ref mut listener) => {
|
||||
// Wait for the writer to finish.
|
||||
ready!(Pin::new(listener).poll(cx));
|
||||
this.listener = None;
|
||||
|
||||
// Notify the next reader waiting in list.
|
||||
this.lock.no_writer.notify(1);
|
||||
|
||||
// Check the state again.
|
||||
Ordering::Acquire
|
||||
}
|
||||
};
|
||||
|
||||
// Reload the state.
|
||||
this.state = this.lock.state.load(load_ordering);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The future returned by [`RwLock::upgradable_read`].
|
||||
pub struct UpgradableRead<'a, T: ?Sized> {
|
||||
/// The lock that is being acquired.
|
||||
lock: &'a RwLock<T>,
|
||||
|
||||
/// The mutex we are trying to acquire.
|
||||
acquire: Lock<'a, ()>,
|
||||
}
|
||||
|
||||
impl<T: ?Sized> fmt::Debug for UpgradableRead<'_, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("UpgradableRead { .. }")
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ?Sized> Unpin for UpgradableRead<'_, T> {}
|
||||
|
||||
impl<'a, T: ?Sized> Future for UpgradableRead<'a, T> {
|
||||
type Output = RwLockUpgradableReadGuard<'a, T>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// Acquire the mutex.
|
||||
let mutex_guard = ready!(Pin::new(&mut this.acquire).poll(cx));
|
||||
|
||||
let mut state = this.lock.state.load(Ordering::Acquire);
|
||||
|
||||
// Make sure the number of readers doesn't overflow.
|
||||
if state > std::isize::MAX as usize {
|
||||
process::abort();
|
||||
}
|
||||
|
||||
// Increment the number of readers.
|
||||
loop {
|
||||
match this.lock.state.compare_exchange(
|
||||
state,
|
||||
state + ONE_READER,
|
||||
Ordering::AcqRel,
|
||||
Ordering::Acquire,
|
||||
) {
|
||||
Ok(_) => {
|
||||
return Poll::Ready(RwLockUpgradableReadGuard {
|
||||
reader: RwLockReadGuard(this.lock),
|
||||
reserved: mutex_guard,
|
||||
});
|
||||
}
|
||||
Err(s) => state = s,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The future returned by [`RwLock::write`].
|
||||
pub struct Write<'a, T: ?Sized> {
|
||||
/// The lock that is being acquired.
|
||||
lock: &'a RwLock<T>,
|
||||
|
||||
/// Current state fof this future.
|
||||
state: WriteState<'a, T>,
|
||||
}
|
||||
|
||||
enum WriteState<'a, T: ?Sized> {
|
||||
/// We are currently acquiring the inner mutex.
|
||||
Acquiring(Lock<'a, ()>),
|
||||
|
||||
/// We are currently waiting for readers to finish.
|
||||
WaitingReaders {
|
||||
/// Our current write guard.
|
||||
guard: Option<RwLockWriteGuard<'a, T>>,
|
||||
|
||||
/// The listener for the "no readers" event.
|
||||
listener: Option<EventListener>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<T: ?Sized> fmt::Debug for Write<'_, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("Write { .. }")
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ?Sized> Unpin for Write<'_, T> {}
|
||||
|
||||
impl<'a, T: ?Sized> Future for Write<'a, T> {
|
||||
type Output = RwLockWriteGuard<'a, T>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
loop {
|
||||
match &mut this.state {
|
||||
WriteState::Acquiring(lock) => {
|
||||
// First grab the mutex.
|
||||
let mutex_guard = ready!(Pin::new(lock).poll(cx));
|
||||
|
||||
// Set `WRITER_BIT` and create a guard that unsets it in case this future is canceled.
|
||||
let new_state = this.lock.state.fetch_or(WRITER_BIT, Ordering::SeqCst);
|
||||
let guard = RwLockWriteGuard {
|
||||
writer: RwLockWriteGuardInner(this.lock),
|
||||
reserved: mutex_guard,
|
||||
};
|
||||
|
||||
// If we just acquired the writer lock, return it.
|
||||
if new_state == WRITER_BIT {
|
||||
return Poll::Ready(guard);
|
||||
}
|
||||
|
||||
// Start waiting for the readers to finish.
|
||||
this.state = WriteState::WaitingReaders {
|
||||
guard: Some(guard),
|
||||
listener: Some(this.lock.no_readers.listen()),
|
||||
};
|
||||
}
|
||||
|
||||
WriteState::WaitingReaders {
|
||||
guard,
|
||||
ref mut listener,
|
||||
} => {
|
||||
let load_ordering = if listener.is_some() {
|
||||
Ordering::Acquire
|
||||
} else {
|
||||
Ordering::SeqCst
|
||||
};
|
||||
|
||||
// Check the state again.
|
||||
if this.lock.state.load(load_ordering) == WRITER_BIT {
|
||||
// We are the only ones holding the lock, return it.
|
||||
return Poll::Ready(guard.take().unwrap());
|
||||
}
|
||||
|
||||
// Wait for the readers to finish.
|
||||
match listener {
|
||||
None => {
|
||||
// Register a listener.
|
||||
*listener = Some(this.lock.no_readers.listen());
|
||||
}
|
||||
|
||||
Some(ref mut evl) => {
|
||||
// Wait for the readers to finish.
|
||||
ready!(Pin::new(evl).poll(cx));
|
||||
*listener = None;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A guard that releases the read lock when dropped.
|
||||
pub struct RwLockReadGuard<'a, T: ?Sized>(&'a RwLock<T>);
|
||||
|
||||
|
@ -585,7 +741,7 @@ impl<'a, T: ?Sized> RwLockUpgradableReadGuard<'a, T> {
|
|||
/// *writer = 2;
|
||||
/// # })
|
||||
/// ```
|
||||
pub async fn upgrade(guard: Self) -> RwLockWriteGuard<'a, T> {
|
||||
pub fn upgrade(guard: Self) -> Upgrade<'a, T> {
|
||||
// Set `WRITER_BIT` and decrement the number of readers at the same time.
|
||||
guard
|
||||
.reader
|
||||
|
@ -596,19 +752,10 @@ impl<'a, T: ?Sized> RwLockUpgradableReadGuard<'a, T> {
|
|||
// Convert into a write guard that unsets `WRITER_BIT` in case this future is canceled.
|
||||
let guard = guard.into_writer();
|
||||
|
||||
// If there are readers, we need to wait for them to finish.
|
||||
while guard.writer.0.state.load(Ordering::SeqCst) != WRITER_BIT {
|
||||
// Start listening for "no readers" events.
|
||||
let listener = guard.writer.0.no_readers.listen();
|
||||
|
||||
// Check again if there are readers.
|
||||
if guard.writer.0.state.load(Ordering::Acquire) != WRITER_BIT {
|
||||
// Wait for the readers to finish.
|
||||
listener.await;
|
||||
}
|
||||
Upgrade {
|
||||
guard: Some(guard),
|
||||
listener: None,
|
||||
}
|
||||
|
||||
guard
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -632,6 +779,67 @@ impl<T: ?Sized> Deref for RwLockUpgradableReadGuard<'_, T> {
|
|||
}
|
||||
}
|
||||
|
||||
/// The future returned by [`RwLockUpgradableReadGuard::upgrade`].
|
||||
pub struct Upgrade<'a, T: ?Sized> {
|
||||
/// The guard that we are upgrading to.
|
||||
guard: Option<RwLockWriteGuard<'a, T>>,
|
||||
|
||||
/// The event listener we are waiting on.
|
||||
listener: Option<EventListener>,
|
||||
}
|
||||
|
||||
impl<T: ?Sized> fmt::Debug for Upgrade<'_, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Upgrade").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ?Sized> Unpin for Upgrade<'_, T> {}
|
||||
|
||||
impl<'a, T: ?Sized> Future for Upgrade<'a, T> {
|
||||
type Output = RwLockWriteGuard<'a, T>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
let guard = this
|
||||
.guard
|
||||
.as_mut()
|
||||
.expect("cannot poll future after completion");
|
||||
|
||||
// If there are readers, we need to wait for them to finish.
|
||||
loop {
|
||||
let load_ordering = if this.listener.is_some() {
|
||||
Ordering::Acquire
|
||||
} else {
|
||||
Ordering::SeqCst
|
||||
};
|
||||
|
||||
// See if the number of readers is zero.
|
||||
let state = guard.writer.0.state.load(load_ordering);
|
||||
if state == WRITER_BIT {
|
||||
break;
|
||||
}
|
||||
|
||||
// If there are readers, wait for them to finish.
|
||||
match &mut this.listener {
|
||||
listener @ None => {
|
||||
// Start listening for "no readers" events.
|
||||
*listener = Some(guard.writer.0.no_readers.listen());
|
||||
}
|
||||
|
||||
Some(ref mut listener) => {
|
||||
// Wait for the readers to finish.
|
||||
ready!(Pin::new(listener).poll(cx));
|
||||
this.listener = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We are done.
|
||||
Poll::Ready(this.guard.take().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
struct RwLockWriteGuardInner<'a, T: ?Sized>(&'a RwLock<T>);
|
||||
|
||||
impl<T: ?Sized> Drop for RwLockWriteGuardInner<'_, T> {
|
||||
|
|
133
src/semaphore.rs
133
src/semaphore.rs
|
@ -1,8 +1,12 @@
|
|||
use std::fmt;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use event_listener::Event;
|
||||
use event_listener::{Event, EventListener};
|
||||
use futures_lite::ready;
|
||||
|
||||
/// A counter for limiting the number of concurrent operations.
|
||||
#[derive(Debug)]
|
||||
|
@ -80,18 +84,10 @@ impl Semaphore {
|
|||
/// let guard = s.acquire().await;
|
||||
/// # });
|
||||
/// ```
|
||||
pub async fn acquire(&self) -> SemaphoreGuard<'_> {
|
||||
let mut listener = None;
|
||||
|
||||
loop {
|
||||
if let Some(guard) = self.try_acquire() {
|
||||
return guard;
|
||||
}
|
||||
|
||||
match listener.take() {
|
||||
None => listener = Some(self.event.listen()),
|
||||
Some(l) => l.await,
|
||||
}
|
||||
pub fn acquire(&self) -> Acquire<'_> {
|
||||
Acquire {
|
||||
semaphore: self,
|
||||
listener: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -136,21 +132,6 @@ impl Semaphore {
|
|||
}
|
||||
}
|
||||
|
||||
async fn acquire_arc_impl(self: Arc<Self>) -> SemaphoreGuardArc {
|
||||
let mut listener = None;
|
||||
|
||||
loop {
|
||||
if let Some(guard) = self.try_acquire_arc() {
|
||||
return guard;
|
||||
}
|
||||
|
||||
match listener.take() {
|
||||
None => listener = Some(self.event.listen()),
|
||||
Some(l) => l.await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Waits for an owned permit for a concurrent operation.
|
||||
///
|
||||
/// Returns a guard that releases the permit when dropped.
|
||||
|
@ -166,8 +147,100 @@ impl Semaphore {
|
|||
/// let guard = s.acquire_arc().await;
|
||||
/// # });
|
||||
/// ```
|
||||
pub fn acquire_arc(self: &Arc<Self>) -> impl Future<Output = SemaphoreGuardArc> {
|
||||
self.clone().acquire_arc_impl()
|
||||
pub fn acquire_arc(self: &Arc<Self>) -> AcquireArc {
|
||||
AcquireArc {
|
||||
semaphore: self.clone(),
|
||||
listener: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The future returned by [`Semaphore::acquire`].
|
||||
pub struct Acquire<'a> {
|
||||
/// The semaphore being acquired.
|
||||
semaphore: &'a Semaphore,
|
||||
|
||||
/// The listener waiting on the semaphore.
|
||||
listener: Option<EventListener>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for Acquire<'_> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("Acquire { .. }")
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpin for Acquire<'_> {}
|
||||
|
||||
impl<'a> Future for Acquire<'a> {
|
||||
type Output = SemaphoreGuard<'a>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
loop {
|
||||
match this.semaphore.try_acquire() {
|
||||
Some(guard) => return Poll::Ready(guard),
|
||||
None => {
|
||||
// Wait on the listener.
|
||||
match &mut this.listener {
|
||||
listener @ None => {
|
||||
*listener = Some(this.semaphore.event.listen());
|
||||
}
|
||||
Some(ref mut listener) => {
|
||||
ready!(Pin::new(listener).poll(cx));
|
||||
this.listener = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The future returned by [`Semaphore::acquire_arc`].
|
||||
pub struct AcquireArc {
|
||||
/// The semaphore being acquired.
|
||||
semaphore: Arc<Semaphore>,
|
||||
|
||||
/// The listener waiting on the semaphore.
|
||||
listener: Option<EventListener>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for AcquireArc {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("AcquireArc { .. }")
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpin for AcquireArc {}
|
||||
|
||||
impl Future for AcquireArc {
|
||||
type Output = SemaphoreGuardArc;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.get_mut();
|
||||
|
||||
loop {
|
||||
match this.semaphore.try_acquire_arc() {
|
||||
Some(guard) => {
|
||||
this.listener = None;
|
||||
return Poll::Ready(guard);
|
||||
}
|
||||
None => {
|
||||
// Wait on the listener.
|
||||
match &mut this.listener.take() {
|
||||
listener @ None => {
|
||||
*listener = Some(this.semaphore.event.listen());
|
||||
}
|
||||
Some(ref mut listener) => {
|
||||
ready!(Pin::new(listener).poll(cx));
|
||||
this.listener = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue