diff --git a/src/lib.rs b/src/lib.rs index d8e59ca..904803f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ #![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)] +use std::cell::RefCell; use std::fmt; use std::future::Future; use std::marker::PhantomData; @@ -229,29 +230,56 @@ impl<'a> Executor<'a> { let runner = Runner::new(self.state()); let mut rng = fastrand::Rng::new(); - // A future that runs tasks forever. - let run_forever = async { - loop { - for _ in 0..200 { - let runnable = runner.runnable(&mut rng).await; - runnable.run(); - } - future::yield_now().await; - } - }; + // Set the local queue while we're running. + LocalQueue::set(self.state(), &runner.local, { + let runner = &runner; + async move { + // A future that runs tasks forever. + let run_forever = async { + loop { + for _ in 0..200 { + let runnable = runner.runnable(&mut rng).await; + runnable.run(); + } + future::yield_now().await; + } + }; - // Run `future` and `run_forever` concurrently until `future` completes. - future.or(run_forever).await + // Run `future` and `run_forever` concurrently until `future` completes. + future.or(run_forever).await + } + }) + .await } /// Returns a function that schedules a runnable task when it gets woken up. fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static { let state = self.state().clone(); - // TODO(stjepang): If possible, push into the current local queue and notify the ticker. + // If possible, push into the current local queue and notify the ticker. move |runnable| { - state.queue.push(runnable).unwrap(); - state.notify(); + let mut runnable = Some(runnable); + + // Try to push into the local queue. + LocalQueue::with(|local_queue| { + // Make sure that we don't accidentally push to an executor that isn't ours. + if !std::ptr::eq(local_queue.state, &*state) { + return; + } + + if let Err(e) = local_queue.queue.push(runnable.take().unwrap()) { + runnable = Some(e.into_inner()); + return; + } + + local_queue.waker.wake_by_ref(); + }); + + // If the local queue push failed, just push to the global queue. + if let Some(runnable) = runnable { + state.queue.push(runnable).unwrap(); + state.notify(); + } } } @@ -819,6 +847,97 @@ impl Drop for Runner<'_> { } } +/// The state of the currently running local queue. +struct LocalQueue { + /// The pointer to the state of the executor. + /// + /// Used to make sure we don't push runnables to the wrong executor. + state: *const State, + + /// The concurrent queue. + queue: Arc>, + + /// The waker for the runnable. + waker: Waker, +} + +impl LocalQueue { + /// Run a function with the current local queue. + fn with(f: impl FnOnce(&LocalQueue) -> R) -> Option { + std::thread_local! { + /// The current local queue. + static LOCAL_QUEUE: RefCell> = RefCell::new(None); + } + + impl LocalQueue { + /// Run a function with a set local queue. + async fn set( + state: &State, + queue: &Arc>, + fut: F, + ) -> F::Output + where + F: Future, + { + // Store the local queue and the current waker. + let mut old = with_waker(|waker| { + LOCAL_QUEUE.with(move |slot| { + slot.borrow_mut().replace(LocalQueue { + state: state as *const State, + queue: queue.clone(), + waker: waker.clone(), + }) + }) + }) + .await; + + // Restore the old local queue on drop. + let _guard = CallOnDrop(move || { + let old = old.take(); + let _ = LOCAL_QUEUE.try_with(move |slot| { + *slot.borrow_mut() = old; + }); + }); + + // Pin the future. + futures_lite::pin!(fut); + + // Run it such that the waker is updated every time it's polled. + future::poll_fn(move |cx| { + LOCAL_QUEUE + .try_with({ + let waker = cx.waker(); + move |slot| { + let mut slot = slot.borrow_mut(); + let qaw = slot.as_mut().expect("missing local queue"); + + // If we've been replaced, just ignore the slot. + if !Arc::ptr_eq(&qaw.queue, queue) { + return; + } + + // Update the waker, if it has changed. + if !qaw.waker.will_wake(waker) { + qaw.waker = waker.clone(); + } + } + }) + .ok(); + + // Poll the future. + fut.as_mut().poll(cx) + }) + .await + } + } + + LOCAL_QUEUE + .try_with(|local_queue| local_queue.borrow().as_ref().map(f)) + .ok() + .flatten() + } +} + /// Steals some items from one queue into another. fn steal(src: &ConcurrentQueue, dest: &ConcurrentQueue) { // Half of `src`'s length rounded up. @@ -911,10 +1030,19 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_ } /// Runs a closure when dropped. -struct CallOnDrop(F); +struct CallOnDrop(F); -impl Drop for CallOnDrop { +impl Drop for CallOnDrop { fn drop(&mut self) { (self.0)(); } } + +/// Run a closure with the current waker. +fn with_waker R, R>(f: F) -> impl Future { + let mut f = Some(f); + future::poll_fn(move |cx| { + let f = f.take().unwrap(); + Poll::Ready(f(cx.waker())) + }) +}