Add work-stealing

This commit is contained in:
Stjepan Glavina 2020-02-16 14:43:40 +01:00
parent 5e1a64ff54
commit 57ef197d0e
3 changed files with 93 additions and 20 deletions

View File

@ -8,17 +8,18 @@ license = "MIT OR Apache-2.0"
[dependencies]
async-task = "1.3.0"
crossbeam-channel = "0.4.0"
crossbeam-deque = "0.7.2"
crossbeam-utils = "0.7.0"
futures-core = "0.3.3"
futures-io = "0.3.3"
futures-util = { version = "0.3.3", default-features = false, features = ["std", "io"] }
libc = "0.2.66"
once_cell = "1.3.1"
parking_lot = "0.10.0"
pin-utils = "0.1.0-alpha.4"
scopeguard = "1.0.0"
slab = "0.4.2"
socket2 = "0.3.11"
libc = "0.2.66"
[target.'cfg(unix)'.dependencies]
nix = "0.16.1"

View File

@ -10,6 +10,11 @@ async fn hello(_: Request<Body>) -> Result<Response<Body>, Infallible> {
}
pub fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// Create a thread pool.
for _ in 0..num_cpus::get().max(1) {
std::thread::spawn(|| smol::run(futures::future::pending::<()>()));
}
smol::run(async {
let addr = "127.0.0.1:3000";
let listener = Async::<TcpListener>::bind(addr)?;

View File

@ -26,7 +26,7 @@ use std::panic::catch_unwind;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use std::thread;
use std::thread::{self, ThreadId};
use std::time::{Duration, Instant};
#[cfg(unix)]
@ -35,12 +35,13 @@ use std::{
path::Path,
};
use crossbeam_channel as channel;
use crossbeam_utils::sync::Parker;
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use crossbeam_utils::sync::{Parker, ShardedLock};
use futures_core::stream::Stream;
use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite};
use futures_util::future;
use futures_util::io::{AsyncReadExt, AsyncWriteExt};
use futures_util::lock;
use futures_util::stream::{self, StreamExt};
use once_cell::sync::Lazy;
use parking_lot::{Condvar, Mutex, MutexGuard};
@ -62,34 +63,78 @@ static EXECUTOR: Lazy<Executor> = Lazy::new(|| Executor::create().expect("cannot
/// A runnable future, ready for execution.
type Runnable = async_task::Task<()>;
thread_local! {
static WORKER: RefCell<Option<Worker<Runnable>>> = RefCell::new(None);
}
struct Executor {
receiver: channel::Receiver<Runnable>,
queue: channel::Sender<Runnable>,
injector: Injector<Runnable>,
stealers: ShardedLock<Vec<(ThreadId, Stealer<Runnable>)>>,
interrupt: Async<IoFlag>,
}
impl Executor {
fn create() -> io::Result<Executor> {
let (sender, receiver) = channel::unbounded::<Runnable>();
Ok(Executor {
receiver,
queue: sender,
injector: Injector::new(),
stealers: ShardedLock::new(Vec::new()),
interrupt: IoFlag::create().and_then(Async::nonblocking)?,
})
}
fn schedule(&self, runnable: Runnable) {
self.queue.send(runnable).unwrap();
WORKER.with(|worker| match worker.borrow().as_ref() {
Some(w) => w.push(runnable),
None => self.injector.push(runnable),
});
self.interrupt();
}
fn find_quick(&self) -> Option<Runnable> {
WORKER.with(|worker| {
let worker = worker.borrow();
let worker = worker.as_ref().unwrap();
if let Some(r) = worker.pop() {
return Some(r);
}
loop {
match self.injector.steal_batch_and_pop(&worker) {
Steal::Success(r) => return Some(r),
Steal::Empty => return None,
Steal::Retry => {}
}
}
})
}
fn find_runnable(&self) -> Option<Runnable> {
self.receiver.try_recv().ok()
if let Some(r) = self.find_quick() {
return Some(r);
}
self.poll_quick().unwrap();
if let Some(r) = self.find_quick() {
return Some(r);
}
WORKER.with(|worker| {
let worker = worker.borrow();
let worker = worker.as_ref().unwrap();
let stealers = self.stealers.read().unwrap();
for (_, s) in stealers.iter() {
loop {
match s.steal_batch_and_pop(&worker) {
Steal::Success(r) => return Some(r),
Steal::Empty => break,
Steal::Retry => {}
}
}
}
None
})
}
fn run_until(&self, io_flag: &IoFlag) -> io::Result<()> {
let mut runs = 0;
let mut fails = 0;
while !io_flag.get() {
if runs > 50 {
@ -97,11 +142,7 @@ impl Executor {
self.poll_quick()?;
} else if let Some(runnable) = self.find_runnable() {
runs += 1;
fails = 0;
let _ = catch_unwind(|| runnable.run());
} else if fails == 0 {
fails += 1;
self.poll_quick()?;
} else {
break;
}
@ -172,6 +213,32 @@ pub fn run<T>(future: impl Future<Output = T>) -> T {
let f = io_flag.clone();
let waker = async_task::waker_fn(move || f.get_ref().set());
WORKER.with(|worker| {
let mut worker = worker.borrow_mut();
if worker.is_some() {
// TODO: already registered, panic because recursive run()
}
let w = Worker::new_fifo();
let s = w.stealer();
*worker = Some(w);
let id = thread::current().id();
EXECUTOR.stealers.write().unwrap().push((id, s));
});
scopeguard::defer! {
WORKER.with(|worker| {
let worker = worker.borrow_mut().take().unwrap();
while let Some(r) = worker.pop() {
EXECUTOR.injector.push(r);
}
let id = thread::current().id();
EXECUTOR.stealers.write().unwrap().retain(|pair| pair.0 != id);
})
}
loop {
match future.as_mut().poll(&mut Context::from_waker(&waker)) {
Poll::Ready(val) => return val,
@ -510,7 +577,7 @@ struct Source {
struct Reactor {
sys: sys::Reactor,
sources: Mutex<Slab<Arc<Source>>>,
events: futures_util::lock::Mutex<sys::Events>,
events: lock::Mutex<sys::Events>,
}
impl Reactor {
@ -518,7 +585,7 @@ impl Reactor {
Ok(Reactor {
sys: sys::Reactor::create()?,
sources: Mutex::new(Slab::new()),
events: futures_util::lock::Mutex::new(sys::Events::new()),
events: lock::Mutex::new(sys::Events::new()),
})
}
@ -562,7 +629,7 @@ impl Reactor {
struct Poller<'a> {
reactor: &'a Reactor,
events: futures_util::lock::MutexGuard<'a, sys::Events>,
events: lock::MutexGuard<'a, sys::Events>,
}
impl Poller<'_> {