diff --git a/src/lib.rs b/src/lib.rs index 06a885e..ff2212f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,7 +32,7 @@ use std::sync::{Arc, Mutex, RwLock}; use std::task::{Context, Poll, Waker}; use concurrent_queue::ConcurrentQueue; -use futures_lite::{future, ready, FutureExt}; +use futures_lite::{future, FutureExt}; /// A runnable future, ready for execution. /// @@ -384,13 +384,13 @@ impl Executor { pub fn try_tick(&self) -> bool { match self.state().queue.pop() { Err(_) => false, - Ok(r) => { + Ok(runnable) => { // Notify another ticker now to pick up where this ticker left off, just in case // running the task takes a long time. self.state().notify(); // Run the task. - r.run(); + runnable.run(); true } } @@ -416,11 +416,9 @@ impl Executor { /// future::block_on(ex.tick()); // runs the task /// ``` pub async fn tick(&self) { - // Create a ticker that doesn't use sharding. - let ticker = Ticker::new(self.state()); - - // Keep trying until a single `poll_tick()` is successful. - future::poll_fn(|cx| ticker.poll_tick(cx)).await + let state = self.state(); + let runnable = Ticker::new(state).runnable(|| state.queue.pop().ok()).await; + runnable.run(); } /// Runs the executor until the given future completes. @@ -439,23 +437,21 @@ impl Executor { /// assert_eq!(res, 6); /// ``` pub async fn run(&self, future: impl Future) -> T { - // Create a ticker that uses sharding. let runner = Runner::new(self.state()); - // A future that ticks the executor forever. - let tick_forever = future::poll_fn(|cx| { - // Run a batch of tasks. - for _ in 0..200 { - ready!(runner.poll_tick(cx)); + // A future that runs tasks forever. + let run_forever = async { + loop { + for _ in 0..200 { + let runnable = runner.runnable().await; + runnable.run(); + } + future::yield_now().await; } + }; - // If there are more tasks, yield. - cx.waker().wake_by_ref(); - Poll::Pending - }); - - // Run `future` and `tick_forever` concurrently until `future` completes. - future.or(tick_forever).await + // Run `future` and `run_forever` concurrently until `future` completes. + future.or(run_forever).await } /// Returns a function that schedules a runnable task when it gets woken up. @@ -550,33 +546,32 @@ impl Ticker<'_> { } } - /// Attempts to execute a single task. - /// - /// This method takes a scheduled task and polls its future. - fn poll_tick(&self, cx: &mut Context<'_>) -> Poll<()> { - loop { - match self.state.queue.pop() { - Err(_) => { - // Move to sleeping and unnotified state. - if !self.sleep(cx.waker()) { - // If already sleeping and unnotified, return. - return Poll::Pending; + /// Finds a task to run. + async fn runnable(&self, mut search: impl FnMut() -> Option) -> Runnable { + future::poll_fn(|cx| { + loop { + match search() { + None => { + // Move to sleeping and unnotified state. + if !self.sleep(cx.waker()) { + // If already sleeping and unnotified, return. + return Poll::Pending; + } + } + Some(r) => { + // Wake up. + self.wake(); + + // Notify another ticker now to pick up where this ticker left off, just in + // case running the task takes a long time. + self.state.notify(); + + return Poll::Ready(r); } } - Ok(r) => { - // Wake up. - self.wake(); - - // Notify another ticker now to pick up where this ticker left off, just in - // case running the task takes a long time. - self.state.notify(); - - // Run the task. - r.run(); - return Poll::Ready(()); - } } - } + }) + .await } } @@ -629,77 +624,55 @@ impl Runner<'_> { runner } - /// Attempts to execute a single task. - /// - /// This method takes a scheduled task and polls its future. - fn poll_tick(&self, cx: &mut Context<'_>) -> Poll<()> { - loop { - match self.search() { - None => { - // Move to sleeping and unnotified state. - if !self.ticker.sleep(cx.waker()) { - // If already sleeping and unnotified, return. - return Poll::Pending; + /// Finds a task to run. + async fn runnable(&self) -> Runnable { + let runnable = self + .ticker + .runnable(|| { + // Try the shard. + if let Ok(r) = self.shard.pop() { + return Some(r); + } + + // Try stealing from the global queue. + if let Ok(r) = self.state.queue.pop() { + steal(&self.state.queue, &self.shard); + return Some(r); + } + + // Try stealing from other shards. + let shards = self.state.shards.read().unwrap(); + + // Pick a random starting point in the iterator list and rotate the list. + let n = shards.len(); + let start = fastrand::usize(..n); + let iter = shards.iter().chain(shards.iter()).skip(start).take(n); + + // Remove this ticker's shard. + let iter = iter.filter(|shard| !Arc::ptr_eq(shard, &self.shard)); + + // Try stealing from each shard in the list. + for shard in iter { + steal(shard, &self.shard); + if let Ok(r) = self.shard.pop() { + return Some(r); } } - Some(r) => { - // Wake up. - self.ticker.wake(); - // Notify another ticker now to pick up where this ticker left off, just in - // case running the task takes a long time. - self.state.notify(); + None + }) + .await; - // Bump the ticker. - let ticks = self.ticks.get(); - self.ticks.set(ticks.wrapping_add(1)); + // Bump the ticker. + let ticks = self.ticks.get(); + self.ticks.set(ticks.wrapping_add(1)); - // Steal tasks from the global queue to ensure fair task scheduling. - if ticks % 64 == 0 { - steal(&self.state.queue, &self.shard); - } - - // Run the task. - r.run(); - return Poll::Ready(()); - } - } - } - } - - /// Finds the next task to run. - fn search(&self) -> Option { - // Try the shard. - if let Ok(r) = self.shard.pop() { - return Some(r); - } - - // Try stealing from the global queue. - if let Ok(r) = self.state.queue.pop() { + // Steal tasks from the global queue to ensure fair task scheduling. + if ticks % 64 == 0 { steal(&self.state.queue, &self.shard); - return Some(r); } - // Try stealing from other shards. - let shards = self.state.shards.read().unwrap(); - - // Pick a random starting point in the iterator list and rotate the list. - let n = shards.len(); - let start = fastrand::usize(..n); - let iter = shards.iter().chain(shards.iter()).skip(start).take(n); - - // Remove this ticker's shard. - let iter = iter.filter(|shard| !Arc::ptr_eq(shard, &self.shard)); - - // Try stealing from each shard in the list. - for shard in iter { - steal(shard, &self.shard); - if let Ok(r) = self.shard.pop() { - return Some(r); - } - } - - None + runnable } }