Support racy initialization of an Executor's state
Fixes #89. Uses @notgull's suggestion of using a `AtomicPtr` with a racy initialization instead of a `OnceCell`. For the addition of more `unsafe`, I added the `clippy::undocumented_unsafe_blocks` lint at a warn, and fixed a few of the remaining open clippy issues (i.e. `Waker::clone_from` already handling the case where they're equal). Removing `async_lock` as a dependency shouldn't be a SemVer breaking change.
This commit is contained in:
parent
4b37c612f6
commit
649bdfda23
|
@ -15,7 +15,6 @@ categories = ["asynchronous", "concurrency"]
|
|||
exclude = ["/.*"]
|
||||
|
||||
[dependencies]
|
||||
async-lock = "3.0.0"
|
||||
async-task = "4.4.0"
|
||||
concurrent-queue = "2.0.0"
|
||||
fastrand = "2.0.0"
|
||||
|
|
130
src/lib.rs
130
src/lib.rs
|
@ -25,7 +25,12 @@
|
|||
//! future::block_on(ex.run(task));
|
||||
//! ```
|
||||
|
||||
#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
|
||||
#![warn(
|
||||
missing_docs,
|
||||
missing_debug_implementations,
|
||||
rust_2018_idioms,
|
||||
clippy::undocumented_unsafe_blocks
|
||||
)]
|
||||
#![doc(
|
||||
html_favicon_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
|
||||
)]
|
||||
|
@ -37,11 +42,10 @@ use std::fmt;
|
|||
use std::marker::PhantomData;
|
||||
use std::panic::{RefUnwindSafe, UnwindSafe};
|
||||
use std::rc::Rc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
|
||||
use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
||||
use std::task::{Poll, Waker};
|
||||
|
||||
use async_lock::OnceCell;
|
||||
use async_task::{Builder, Runnable};
|
||||
use concurrent_queue::ConcurrentQueue;
|
||||
use futures_lite::{future, prelude::*};
|
||||
|
@ -76,13 +80,15 @@ pub use async_task::Task;
|
|||
/// ```
|
||||
pub struct Executor<'a> {
|
||||
/// The executor state.
|
||||
state: OnceCell<Arc<State>>,
|
||||
state: AtomicPtr<State>,
|
||||
|
||||
/// Makes the `'a` lifetime invariant.
|
||||
_marker: PhantomData<std::cell::UnsafeCell<&'a ()>>,
|
||||
}
|
||||
|
||||
// SAFETY: Executor stores no thread local state that can be accessed via other thread.
|
||||
unsafe impl Send for Executor<'_> {}
|
||||
// SAFETY: Executor internally synchronizes all of it's operations internally.
|
||||
unsafe impl Sync for Executor<'_> {}
|
||||
|
||||
impl UnwindSafe for Executor<'_> {}
|
||||
|
@ -106,7 +112,7 @@ impl<'a> Executor<'a> {
|
|||
/// ```
|
||||
pub const fn new() -> Executor<'a> {
|
||||
Executor {
|
||||
state: OnceCell::new(),
|
||||
state: AtomicPtr::new(std::ptr::null_mut()),
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -231,7 +237,7 @@ impl<'a> Executor<'a> {
|
|||
// Remove the task from the set of active tasks when the future finishes.
|
||||
let entry = active.vacant_entry();
|
||||
let index = entry.key();
|
||||
let state = self.state().clone();
|
||||
let state = self.state_as_arc();
|
||||
let future = async move {
|
||||
let _guard = CallOnDrop(move || drop(state.active.lock().unwrap().try_remove(index)));
|
||||
future.await
|
||||
|
@ -361,7 +367,7 @@ impl<'a> Executor<'a> {
|
|||
|
||||
/// Returns a function that schedules a runnable task when it gets woken up.
|
||||
fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
|
||||
let state = self.state().clone();
|
||||
let state = self.state_as_arc();
|
||||
|
||||
// TODO: If possible, push into the current local queue and notify the ticker.
|
||||
move |runnable| {
|
||||
|
@ -370,34 +376,73 @@ impl<'a> Executor<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Returns a reference to the inner state.
|
||||
fn state(&self) -> &Arc<State> {
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
{
|
||||
return self.state.get_or_init_blocking(|| Arc::new(State::new()));
|
||||
/// Returns a pointer to the inner state.
|
||||
#[inline]
|
||||
fn state_ptr(&self) -> *const State {
|
||||
#[cold]
|
||||
fn alloc_state(atomic_ptr: &AtomicPtr<State>) -> *mut State {
|
||||
let state = Arc::new(State::new());
|
||||
// TODO: Switch this to use cast_mut once the MSRV can be bumped past 1.65
|
||||
let ptr = Arc::into_raw(state) as *mut State;
|
||||
if let Err(actual) = atomic_ptr.compare_exchange(
|
||||
std::ptr::null_mut(),
|
||||
ptr,
|
||||
Ordering::AcqRel,
|
||||
Ordering::Acquire,
|
||||
) {
|
||||
// SAFETY: This was just created from Arc::into_raw.
|
||||
drop(unsafe { Arc::from_raw(ptr) });
|
||||
actual
|
||||
} else {
|
||||
ptr
|
||||
}
|
||||
}
|
||||
|
||||
// Some projects use this on WASM for some reason. In this case get_or_init_blocking
|
||||
// doesn't work. Just poll the future once and panic if there is contention.
|
||||
#[cfg(target_family = "wasm")]
|
||||
future::block_on(future::poll_once(
|
||||
self.state.get_or_init(|| async { Arc::new(State::new()) }),
|
||||
))
|
||||
.expect("encountered contention on WASM")
|
||||
let mut ptr = self.state.load(Ordering::Acquire);
|
||||
if ptr.is_null() {
|
||||
ptr = alloc_state(&self.state);
|
||||
}
|
||||
ptr
|
||||
}
|
||||
|
||||
/// Returns a reference to the inner state.
|
||||
#[inline]
|
||||
fn state(&self) -> &State {
|
||||
// SAFETY: So long as an Executor lives, it's state pointer will always be valid
|
||||
// when accessed through state_ptr.
|
||||
unsafe { &*self.state_ptr() }
|
||||
}
|
||||
|
||||
// Clones the inner state Arc
|
||||
#[inline]
|
||||
fn state_as_arc(&self) -> Arc<State> {
|
||||
// SAFETY: So long as an Executor lives, it's state pointer will always be a valid
|
||||
// Arc when accessed through state_ptr.
|
||||
let arc = unsafe { Arc::from_raw(self.state_ptr()) };
|
||||
let clone = arc.clone();
|
||||
std::mem::forget(arc);
|
||||
clone
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Executor<'_> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(state) = self.state.get() {
|
||||
let mut active = state.active.lock().unwrap_or_else(|e| e.into_inner());
|
||||
for w in active.drain() {
|
||||
w.wake();
|
||||
}
|
||||
drop(active);
|
||||
|
||||
while state.queue.pop().is_ok() {}
|
||||
let ptr = *self.state.get_mut();
|
||||
if ptr.is_null() {
|
||||
return;
|
||||
}
|
||||
|
||||
// SAFETY: As ptr is not null, it was allocated via Arc::new and converted
|
||||
// via Arc::into_raw in state_ptr.
|
||||
let state = unsafe { Arc::from_raw(ptr) };
|
||||
|
||||
let mut active = state.active.lock().unwrap_or_else(|e| e.into_inner());
|
||||
for w in active.drain() {
|
||||
w.wake();
|
||||
}
|
||||
drop(active);
|
||||
|
||||
while state.queue.pop().is_ok() {}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -718,9 +763,7 @@ impl Sleepers {
|
|||
fn update(&mut self, id: usize, waker: &Waker) -> bool {
|
||||
for item in &mut self.wakers {
|
||||
if item.0 == id {
|
||||
if !item.1.will_wake(waker) {
|
||||
item.1.clone_from(waker);
|
||||
}
|
||||
item.1.clone_from(waker);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -1006,21 +1049,24 @@ fn steal<T>(src: &ConcurrentQueue<T>, dest: &ConcurrentQueue<T>) {
|
|||
/// Debug implementation for `Executor` and `LocalExecutor`.
|
||||
fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
// Get a reference to the state.
|
||||
let state = match executor.state.get() {
|
||||
Some(state) => state,
|
||||
None => {
|
||||
// The executor has not been initialized.
|
||||
struct Uninitialized;
|
||||
let ptr = executor.state.load(Ordering::Acquire);
|
||||
if ptr.is_null() {
|
||||
// The executor has not been initialized.
|
||||
struct Uninitialized;
|
||||
|
||||
impl fmt::Debug for Uninitialized {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("<uninitialized>")
|
||||
}
|
||||
impl fmt::Debug for Uninitialized {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("<uninitialized>")
|
||||
}
|
||||
|
||||
return f.debug_tuple(name).field(&Uninitialized).finish();
|
||||
}
|
||||
};
|
||||
|
||||
return f.debug_tuple(name).field(&Uninitialized).finish();
|
||||
}
|
||||
|
||||
// SAFETY: If the state pointer is not null, it must have been
|
||||
// allocated properly by Arc::new and converted via Arc::into_raw
|
||||
// in state_ptr.
|
||||
let state = unsafe { &*ptr };
|
||||
|
||||
/// Debug wrapper for the number of active tasks.
|
||||
struct ActiveTasks<'a>(&'a Mutex<Slab<Waker>>);
|
||||
|
|
Loading…
Reference in New Issue