From 3e16c34504f132a1a0c44106068354a8e6c44cc5 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Wed, 27 May 2020 16:18:12 +0200 Subject: [PATCH] feat: allow recursive block_on calls --- src/block_on.rs | 2 +- src/context.rs | 19 ++++++++++++++++++- tests/block_on.rs | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 2 deletions(-) create mode 100644 tests/block_on.rs diff --git a/src/block_on.rs b/src/block_on.rs index 1636462..9b9a72b 100644 --- a/src/block_on.rs +++ b/src/block_on.rs @@ -57,7 +57,7 @@ pub fn block_on(future: impl Future) -> T { CACHE.with(|cache| { // Panic if `block_on()` is called recursively. - let (parker, waker) = &mut *cache.try_borrow_mut().expect("recursive `block_on()`"); + let (parker, waker) = &*cache.borrow(); // If enabled, set up tokio before execution begins. context::enter(|| { diff --git a/src/context.rs b/src/context.rs index 630b81b..9af790e 100644 --- a/src/context.rs +++ b/src/context.rs @@ -10,10 +10,27 @@ pub(crate) fn enter(f: impl FnOnce() -> T) -> T { #[cfg(feature = "tokio02")] { use once_cell::sync::Lazy; + use std::cell::Cell; use tokio::runtime::Runtime; + thread_local! { + /// The level of nested `enter` calls we are in, to ensure that the outer most always has a + /// runtime spawned. + static NESTING: Cell = Cell::new(0); + } + static RT: Lazy = Lazy::new(|| Runtime::new().expect("cannot initialize tokio")); - RT.enter(f) + NESTING.with(|nesting| { + let res = if nesting.get() == 0 { + nesting.replace(1); + RT.enter(f) + } else { + nesting.replace(nesting.get() + 1); + f() + }; + nesting.replace(nesting.get() - 1); + res + }) } } diff --git a/tests/block_on.rs b/tests/block_on.rs new file mode 100644 index 0000000..8ceab27 --- /dev/null +++ b/tests/block_on.rs @@ -0,0 +1,37 @@ +use futures_util::future; + +#[test] +fn smoke() { + std::thread::spawn(|| { + smol::run(future::pending::<()>()); + }); + let res = smol::block_on(async { 1 + 2 }); + assert_eq!(res, 3); +} + +#[test] +#[should_panic = "boom"] +fn panic() { + std::thread::spawn(|| { + smol::run(future::pending::<()>()); + }); + smol::block_on(async { + // This panic should get propagated into the parent thread. + panic!("boom"); + }); +} + +#[test] +fn nested_block_on() { + std::thread::spawn(|| { + smol::run(future::pending::<()>()); + }); + + let x = smol::block_on(async { + let a = smol::block_on(async { smol::block_on(async { future::ready(3).await }) }); + let b = smol::block_on(async { smol::block_on(async { future::ready(2).await }) }); + a + b + }); + + assert_eq!(x, 3 + 2); +}