diff --git a/src/lib.rs b/src/lib.rs index aa1dbf6..531420c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -53,6 +53,7 @@ use std::future::Future; +use async_channel::{unbounded, Receiver}; use async_executor::{Executor, LocalExecutor}; use easy_parallel::Parallel; @@ -107,21 +108,21 @@ pub mod prelude { /// }) /// ``` pub fn run(future: impl Future) -> T { - setup(num_cpus::get(), future) -} - -#[cfg(not(feature = "tokio02"))] -fn setup(num_threads: usize, future: impl Future) -> T { - let ex = Executor::new(); - let local_ex = LocalExecutor::new(); - // A channel that coordinates shutdown when the main future completes. - let (trigger, shutdown) = async_channel::unbounded::<()>(); + let (trigger, shutdown) = unbounded::<()>(); let future = async move { let _trigger = trigger; // Dropped at the end of this async block. future.await }; + setup(num_cpus::get(), shutdown, future) +} + +#[cfg(not(feature = "tokio02"))] +fn setup(num_threads: usize, shutdown: Receiver<()>, future: impl Future) -> T { + let ex = Executor::new(); + let local_ex = LocalExecutor::new(); + Parallel::new() .each(0..num_threads, |_| ex.run(shutdown.recv())) .finish(|| ex.enter(|| local_ex.run(future))) @@ -129,17 +130,7 @@ fn setup(num_threads: usize, future: impl Future) -> T { } #[cfg(feature = "tokio02")] -fn setup(num_threads: usize, future: impl Future) -> T { - let ex = Executor::new(); - let local_ex = LocalExecutor::new(); - - // A channel that signals shutdown to the thread pool when the main future completes. - let (s, shutdown) = async_channel::unbounded::<()>(); - let future = async move { - let _s = s; // Drops sender at the end of this async block. - future.await - }; - +fn setup(num_threads: usize, shutdown: Receiver<()>, future: impl Future) -> T { // A minimal tokio runtime. let mut rt = tokio::runtime::Builder::new() .enable_all() @@ -148,6 +139,9 @@ fn setup(num_threads: usize, future: impl Future) -> T { .expect("cannot start tokio runtime"); let handle = rt.handle().clone(); + let ex = Executor::new(); + let local_ex = LocalExecutor::new(); + Parallel::new() .add(|| ex.enter(|| rt.block_on(shutdown.recv()))) .each(0..num_threads, |_| handle.enter(|| ex.run(shutdown.recv())))