Support racy initialization of an Executor's state

Fixes #89. Uses @notgull's suggestion of using an `AtomicPtr` with racy initialization instead of a `OnceCell`.

To account for the additional `unsafe`, I added the `clippy::undocumented_unsafe_blocks` lint at the warn level and fixed a few of the remaining open clippy issues (e.g. `Waker::clone_from` already handles the case where the old and new wakers would wake the same task, so the manual `will_wake` check in `Sleepers::update` can be dropped).

Removing `async_lock` as a dependency shouldn't be a SemVer breaking change.
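
For reference, the racy-initialization pattern at the core of this change can be illustrated with a minimal standalone sketch (the `Lazy` wrapper and dummy `State` below are illustrative names, not this crate's types): every thread that finds a null pointer allocates its own `Arc<State>`, but only the first `compare_exchange` wins; losers drop their allocation and use the published pointer, so the state is created at most once and is reclaimed on drop.

```rust
use std::sync::atomic::{AtomicPtr, Ordering};
use std::sync::Arc;

/// Stand-in for the lazily-allocated state (illustrative only).
struct State {
    value: u64,
}

/// A lazily, racily initialized `Arc<State>` stored as a raw pointer.
struct Lazy {
    ptr: AtomicPtr<State>,
}

impl Lazy {
    const fn new() -> Self {
        Lazy {
            ptr: AtomicPtr::new(std::ptr::null_mut()),
        }
    }

    /// Returns a pointer to the state, allocating it on first use.
    /// Several threads may race here; losers free their own allocation.
    fn get(&self) -> *const State {
        let mut ptr = self.ptr.load(Ordering::Acquire);
        if ptr.is_null() {
            let new = Arc::into_raw(Arc::new(State { value: 42 })) as *mut State;
            match self.ptr.compare_exchange(
                std::ptr::null_mut(),
                new,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => ptr = new,
                Err(actual) => {
                    // Another thread won the race: drop our allocation and
                    // use the pointer that was already published.
                    // SAFETY: `new` was just created by Arc::into_raw above.
                    drop(unsafe { Arc::from_raw(new) });
                    ptr = actual;
                }
            }
        }
        ptr
    }
}

impl Drop for Lazy {
    fn drop(&mut self) {
        // Reclaim the Arc that was leaked into the AtomicPtr, if any.
        let ptr = *self.ptr.get_mut();
        if !ptr.is_null() {
            // SAFETY: a non-null pointer here always came from Arc::into_raw.
            drop(unsafe { Arc::from_raw(ptr) });
        }
    }
}

fn main() {
    let lazy = Lazy::new();
    // SAFETY (in this sketch): the pointer stays valid for as long as `lazy` lives.
    let value = unsafe { (*lazy.get()).value };
    assert_eq!(value, 42);
}
```

The executor applies the same idea: `state_ptr` hands out the raw pointer, and `state_as_arc` temporarily reconstructs the `Arc` to clone it (bumping the reference count) before forgetting the temporary.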
James Liu 2024-04-08 19:41:14 -07:00 committed by GitHub
parent 4b37c612f6
commit 649bdfda23
2 changed files with 88 additions and 43 deletions

Cargo.toml

@@ -15,7 +15,6 @@ categories = ["asynchronous", "concurrency"]
exclude = ["/.*"]
[dependencies]
async-lock = "3.0.0"
async-task = "4.4.0"
concurrent-queue = "2.0.0"
fastrand = "2.0.0"

src/lib.rs

@@ -25,7 +25,12 @@
//! future::block_on(ex.run(task));
//! ```
#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
#![warn(
missing_docs,
missing_debug_implementations,
rust_2018_idioms,
clippy::undocumented_unsafe_blocks
)]
#![doc(
html_favicon_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
)]
@@ -37,11 +42,10 @@ use std::fmt;
use std::marker::PhantomData;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
use std::sync::{Arc, Mutex, RwLock, TryLockError};
use std::task::{Poll, Waker};
use async_lock::OnceCell;
use async_task::{Builder, Runnable};
use concurrent_queue::ConcurrentQueue;
use futures_lite::{future, prelude::*};
@@ -76,13 +80,15 @@ pub use async_task::Task;
/// ```
pub struct Executor<'a> {
/// The executor state.
state: OnceCell<Arc<State>>,
state: AtomicPtr<State>,
/// Makes the `'a` lifetime invariant.
_marker: PhantomData<std::cell::UnsafeCell<&'a ()>>,
}
// SAFETY: Executor stores no thread-local state that can be accessed from other threads.
unsafe impl Send for Executor<'_> {}
// SAFETY: Executor synchronizes all of its operations internally.
unsafe impl Sync for Executor<'_> {}
impl UnwindSafe for Executor<'_> {}
@@ -106,7 +112,7 @@ impl<'a> Executor<'a> {
/// ```
pub const fn new() -> Executor<'a> {
Executor {
state: OnceCell::new(),
state: AtomicPtr::new(std::ptr::null_mut()),
_marker: PhantomData,
}
}
@@ -231,7 +237,7 @@ impl<'a> Executor<'a> {
// Remove the task from the set of active tasks when the future finishes.
let entry = active.vacant_entry();
let index = entry.key();
let state = self.state().clone();
let state = self.state_as_arc();
let future = async move {
let _guard = CallOnDrop(move || drop(state.active.lock().unwrap().try_remove(index)));
future.await
@@ -361,7 +367,7 @@ impl<'a> Executor<'a> {
/// Returns a function that schedules a runnable task when it gets woken up.
fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
let state = self.state().clone();
let state = self.state_as_arc();
// TODO: If possible, push into the current local queue and notify the ticker.
move |runnable| {
@@ -370,34 +376,73 @@
}
}
/// Returns a reference to the inner state.
fn state(&self) -> &Arc<State> {
#[cfg(not(target_family = "wasm"))]
{
return self.state.get_or_init_blocking(|| Arc::new(State::new()));
/// Returns a pointer to the inner state.
#[inline]
fn state_ptr(&self) -> *const State {
#[cold]
fn alloc_state(atomic_ptr: &AtomicPtr<State>) -> *mut State {
let state = Arc::new(State::new());
// TODO: Switch this to use cast_mut once the MSRV can be bumped past 1.65
let ptr = Arc::into_raw(state) as *mut State;
if let Err(actual) = atomic_ptr.compare_exchange(
std::ptr::null_mut(),
ptr,
Ordering::AcqRel,
Ordering::Acquire,
) {
// SAFETY: This was just created from Arc::into_raw.
drop(unsafe { Arc::from_raw(ptr) });
actual
} else {
ptr
}
}
// Some projects use this on WASM for some reason. In this case get_or_init_blocking
// doesn't work. Just poll the future once and panic if there is contention.
#[cfg(target_family = "wasm")]
future::block_on(future::poll_once(
self.state.get_or_init(|| async { Arc::new(State::new()) }),
))
.expect("encountered contention on WASM")
let mut ptr = self.state.load(Ordering::Acquire);
if ptr.is_null() {
ptr = alloc_state(&self.state);
}
ptr
}
/// Returns a reference to the inner state.
#[inline]
fn state(&self) -> &State {
// SAFETY: So long as an Executor lives, its state pointer will always be valid
// when accessed through state_ptr.
unsafe { &*self.state_ptr() }
}
/// Clones the inner state `Arc`.
#[inline]
fn state_as_arc(&self) -> Arc<State> {
// SAFETY: So long as an Executor lives, its state pointer will always be a valid
// Arc when accessed through state_ptr.
let arc = unsafe { Arc::from_raw(self.state_ptr()) };
let clone = arc.clone();
std::mem::forget(arc);
clone
}
}
impl Drop for Executor<'_> {
fn drop(&mut self) {
if let Some(state) = self.state.get() {
let mut active = state.active.lock().unwrap_or_else(|e| e.into_inner());
for w in active.drain() {
w.wake();
}
drop(active);
while state.queue.pop().is_ok() {}
}
let ptr = *self.state.get_mut();
if ptr.is_null() {
return;
}
// SAFETY: As ptr is not null, it was allocated via Arc::new and converted
// via Arc::into_raw in state_ptr.
let state = unsafe { Arc::from_raw(ptr) };
let mut active = state.active.lock().unwrap_or_else(|e| e.into_inner());
for w in active.drain() {
w.wake();
}
drop(active);
while state.queue.pop().is_ok() {}
}
}
@@ -718,9 +763,7 @@ impl Sleepers {
fn update(&mut self, id: usize, waker: &Waker) -> bool {
for item in &mut self.wakers {
if item.0 == id {
if !item.1.will_wake(waker) {
item.1.clone_from(waker);
}
item.1.clone_from(waker);
return false;
}
}
@@ -1006,21 +1049,24 @@ fn steal<T>(src: &ConcurrentQueue<T>, dest: &ConcurrentQueue<T>) {
/// Debug implementation for `Executor` and `LocalExecutor`.
fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// Get a reference to the state.
let state = match executor.state.get() {
Some(state) => state,
None => {
// The executor has not been initialized.
struct Uninitialized;
impl fmt::Debug for Uninitialized {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("<uninitialized>")
}
}
return f.debug_tuple(name).field(&Uninitialized).finish();
}
};
let ptr = executor.state.load(Ordering::Acquire);
if ptr.is_null() {
// The executor has not been initialized.
struct Uninitialized;
impl fmt::Debug for Uninitialized {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("<uninitialized>")
}
}
return f.debug_tuple(name).field(&Uninitialized).finish();
}
// SAFETY: If the state pointer is not null, it must have been
// allocated properly by Arc::new and converted via Arc::into_raw
// in state_ptr.
let state = unsafe { &*ptr };
/// Debug wrapper for the number of active tasks.
struct ActiveTasks<'a>(&'a Mutex<Slab<Waker>>);