Add windows support

This commit is contained in:
Stjepan Glavina 2020-02-09 23:35:17 +01:00
parent e66d3587ec
commit 390ee1f8ce
2 changed files with 243 additions and 95 deletions

View File

@ -13,13 +13,18 @@ crossbeam-utils = "0.7.0"
futures-core = "0.3.3"
futures-io = "0.3.3"
futures-util = { version = "0.3.3", default-features = false, features = [] }
nix = "0.16.1"
once_cell = "1.3.1"
parking_lot = "0.10.0"
pin-utils = "0.1.0-alpha.4"
slab = "0.4.2"
socket2 = "0.3.11"
[target.'cfg(unix)'.dependencies]
nix = "0.16.1"
[target.'cfg(windows)'.dependencies]
wepoll-binding = "1.0.5"
[dev-dependencies]
futures = { version = "0.3.3", default-features = false, features = ["std"] }
hyper = { version = "0.13", default-features = false }

View File

@ -1,17 +1,18 @@
#![forbid(unsafe_code)]
#![cfg_attr(docsrs, feature(doc_cfg))]
// TODO: #![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
#[cfg(not(any(target_os = "linux", target_os = "android", target_os = "windows")))]
compile_error!("smol does not support this target OS");
use std::cell::RefCell;
use std::collections::{BTreeMap, VecDeque};
use std::convert::TryInto;
use std::error::Error;
use std::fmt::Debug;
use std::future::Future;
use std::io::{self, Read, Write};
use std::mem;
use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket};
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::{SocketAddr as UnixSocketAddr, UnixDatagram, UnixListener, UnixStream};
use std::panic::catch_unwind;
use std::path::Path;
use std::pin::Pin;
@ -21,6 +22,15 @@ use std::task::{Context, Poll, Waker};
use std::thread;
use std::time::{Duration, Instant};
#[cfg(unix)]
use std::os::unix::{
io::{AsRawFd, RawFd},
net::{SocketAddr as UnixSocketAddr, UnixDatagram, UnixListener, UnixStream},
};
#[cfg(target_os = "windows")]
use std::os::windows::io::{AsRawSocket, RawSocket};
use crossbeam_channel as channel;
use crossbeam_utils::sync::Parker;
use futures_core::stream::Stream;
@ -28,16 +38,18 @@ use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite};
use futures_util::future;
use futures_util::io::{AsyncReadExt, AsyncWriteExt};
use futures_util::stream::{self, StreamExt};
use nix::sys::epoll::{
epoll_create1, epoll_ctl, epoll_wait, EpollCreateFlags, EpollEvent, EpollFlags, EpollOp,
};
use once_cell::sync::Lazy;
use parking_lot::{Condvar, Mutex};
use slab::Slab;
use socket2::{Domain, Protocol, Socket, Type};
#[cfg(not(any(target_os = "linux", target_os = "android")))]
compile_error!("smol does not support this target OS");
#[cfg(target_os = "linux")]
use nix::sys::epoll::{
epoll_create1, epoll_ctl, epoll_wait, EpollCreateFlags, EpollEvent, EpollFlags, EpollOp,
};
#[cfg(target_os = "windows")]
use wepoll_binding::{Epoll, EpollFlags, Events};
// TODO: fix unwraps
// TODO: if epoll/kqueue/wepoll gets EINTR, then retry - or maybe just call notify()
@ -81,7 +93,7 @@ impl Poller {
sock2.set_recv_buffer_size(1)?;
let registry = Registry::create()?;
registry.register(sock2.as_raw_fd())?;
registry.register(&sock2)?;
Ok(Poller {
registry,
@ -91,7 +103,7 @@ impl Poller {
})
}
fn poll(&self) {
fn poll(&self) -> io::Result<()> {
let interrupted = self.reset();
let next_timer = self.registry.poll_timers();
@ -100,12 +112,14 @@ impl Poller {
} else {
next_timer.map(|when| Instant::now().saturating_duration_since(when))
};
self.registry.wait_io(timeout);
self.registry.wait_io(timeout)?;
Ok(())
}
fn poll_quick(&self) {
fn poll_quick(&self) -> io::Result<()> {
self.registry.poll_timers();
self.registry.wait_io(Some(Duration::from_secs(0)));
self.registry.wait_io(Some(Duration::from_secs(0)))?;
Ok(())
}
/// Sets the interrupt flag and writes to the wakeup socket.
@ -144,15 +158,42 @@ impl Poller {
// ----- Registry -----
struct Entry {
#[cfg(unix)]
fd: RawFd,
#[cfg(windows)]
socket: RawSocket,
index: usize,
readers: Mutex<Vec<Waker>>,
writers: Mutex<Vec<Waker>>,
}
#[cfg(unix)]
impl AsRawFd for Entry {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
#[cfg(windows)]
impl AsRawSocket for Entry {
fn as_raw_socket(&self) -> RawSocket {
self.socket
}
}
struct Registry {
#[cfg(target_os = "linux")]
epoll: RawFd,
#[cfg(target_os = "linux")]
events: Mutex<Box<[EpollEvent]>>,
#[cfg(target_os = "windows")]
epoll: Epoll,
#[cfg(target_os = "windows")]
events: Mutex<Events>,
io_handles: Mutex<Slab<Arc<Entry>>>,
timers: Mutex<BTreeMap<(Instant, usize), Waker>>,
}
@ -160,49 +201,170 @@ struct Registry {
impl Registry {
fn create() -> io::Result<Registry> {
Ok(Registry {
#[cfg(target_os = "linux")]
epoll: epoll_create1(EpollCreateFlags::EPOLL_CLOEXEC).map_err(io_err)?,
#[cfg(target_os = "linux")]
events: Mutex::new(vec![EpollEvent::empty(); 1000].into_boxed_slice()),
#[cfg(target_os = "windows")]
epoll: Epoll::new()?,
#[cfg(target_os = "windows")]
events: Mutex::new(Events::with_capacity(1000)),
io_handles: Mutex::new(Slab::new()),
timers: Mutex::new(BTreeMap::new()),
})
}
// TODO: insert/delete terminology?
fn register(&self, fd: RawFd) -> io::Result<Arc<Entry>> {
let entry = {
let mut io_handles = self.io_handles.lock();
let vacant = io_handles.vacant_entry();
let entry = Arc::new(Entry {
fd,
index: vacant.key(),
readers: Mutex::new(Vec::new()),
writers: Mutex::new(Vec::new()),
});
vacant.insert(entry.clone());
entry
};
fn register(
&self,
#[cfg(unix)] source: &dyn AsRawFd,
#[cfg(windows)] source: &dyn RawSocket,
) -> io::Result<Arc<Entry>> {
let mut io_handles = self.io_handles.lock();
let vacant = io_handles.vacant_entry();
let index = vacant.key();
#[cfg(target_os = "linux")]
epoll_ctl(
self.epoll,
EpollOp::EpollCtlAdd,
fd,
source.as_raw_fd(),
Some(&mut EpollEvent::new(
EpollFlags::EPOLLET
EpollFlags::EPOLLONESHOT
| EpollFlags::EPOLLIN
| EpollFlags::EPOLLOUT
| EpollFlags::EPOLLRDHUP,
entry.index as u64,
index as u64,
)),
)
.map_err(io_err)?;
// TODO: if epoll fails, remove the entry
#[cfg(target_os = "windows")]
self.epoll.register(
source,
EventFlag::ONESHOT | EventFlag::IN | EventFlag::OUT | EventFlag::RDHUP,
index as u64,
)?;
let entry = Arc::new(Entry {
#[cfg(unix)]
fd: source.as_raw_fd(),
#[cfg(windows)]
socket: source.as_raw_socket(),
index,
readers: Mutex::new(Vec::new()),
writers: Mutex::new(Vec::new()),
});
vacant.insert(entry.clone());
Ok(entry)
}
fn unregister(&self, entry: &Entry) -> io::Result<()> {
self.io_handles.lock().remove(entry.index);
epoll_ctl(self.epoll, EpollOp::EpollCtlDel, entry.fd, None).map_err(io_err)?;
fn deregister(&self, entry: &Entry) -> io::Result<()> {
let mut io_handles = self.io_handles.lock();
io_handles.remove(entry.index);
#[cfg(target_os = "linux")]
epoll_ctl(self.epoll, EpollOp::EpollCtlDel, entry.as_raw_fd(), None).map_err(io_err)?;
#[cfg(target_os = "windows")]
self.epoll.deregister(entry)?;
Ok(())
}
fn wait_io(&self, timeout: Option<Duration>) -> io::Result<()> {
let mut events = if timeout == Some(Duration::from_secs(0)) {
match self.events.try_lock() {
None => return Ok(()),
Some(e) => e,
}
} else {
self.events.lock()
};
#[cfg(target_os = "linux")]
let (n, iter) = {
let timeout_ms = timeout
.and_then(|t| t.as_millis().try_into().ok())
.unwrap_or(-1);
let n = epoll_wait(self.epoll, &mut events, timeout_ms).map_err(io_err)?;
(n, &events[..n])
};
#[cfg(target_os = "windows")]
let (n, iter) = {
events.clear();
let n = self.epoll.poll(&mut events, timeout)?;
(n, events.iter())
};
let mut wakers = VecDeque::new();
if n > 0 {
let io_handles = self.io_handles.lock();
for ev in iter {
#[cfg(target_os = "linux")]
let (is_read, is_write, index) = (
ev.events() != EpollFlags::EPOLLOUT,
ev.events() != EpollFlags::EPOLLIN,
ev.data() as usize,
);
#[cfg(target_os = "windows")]
let (is_read, is_write, index) = (
ev.flags() != EventFlag::OUT,
ev.flags() != EventFlag::IN,
ev.data() as usize,
);
// In order to minimize latencies, wake writers before readers.
// Source: https://twitter.com/kingprotty/status/1222152589405384705?s=19
if let Some(entry) = io_handles.get(index) {
if is_read {
for w in entry.readers.lock().drain(..) {
wakers.push_back(w);
}
}
if is_write {
for w in entry.writers.lock().drain(..) {
wakers.push_front(w);
}
}
#[cfg(target_os = "linux")]
epoll_ctl(
self.epoll,
EpollOp::EpollCtlMod,
entry.fd,
Some(&mut EpollEvent::new(
EpollFlags::EPOLLONESHOT
| EpollFlags::EPOLLIN
| EpollFlags::EPOLLOUT
| EpollFlags::EPOLLRDHUP,
entry.index as u64,
)),
)
.map_err(io_err)?;
#[cfg(target_os = "windows")]
self.epoll.reregister(
entry,
EventFlag::ONESHOT | EventFlag::IN | EventFlag::OUT | EventFlag::RDHUP,
entry.index as u64
)?;
}
}
}
// Wake up ready I/O.
for waker in wakers {
waker.wake();
}
Ok(())
}
@ -223,59 +385,11 @@ impl Registry {
next_timer
}
fn wait_io(&self, timeout: Option<Duration>) {
let mut events = if timeout == Some(Duration::from_secs(0)) {
match self.events.try_lock() {
None => return,
Some(e) => e,
}
} else {
self.events.lock()
};
let timeout_ms = timeout
.and_then(|t| t.as_millis().try_into().ok())
.unwrap_or(-1);
// TODO: handle unwrap
let n = epoll_wait(self.epoll, &mut events, timeout_ms).unwrap();
let mut wakers = VecDeque::new();
if n > 0 {
let io_handles = self.io_handles.lock();
for ev in &events[..n] {
let is_read = ev.events() != EpollFlags::EPOLLOUT;
let is_write = ev.events() != EpollFlags::EPOLLIN;
let index = ev.data() as usize;
// In order to minimize latencies, wake writers before readers.
// Source: https://twitter.com/kingprotty/status/1222152589405384705?s=19
if let Some(entry) = io_handles.get(index) {
if is_read {
for w in entry.readers.lock().drain(..) {
wakers.push_back(w);
}
}
if is_write {
for w in entry.writers.lock().drain(..) {
wakers.push_front(w);
}
}
}
}
}
// Wake up ready I/O.
for waker in wakers {
waker.wake();
}
}
}
/// Converts any error into an I/O error.
fn io_err(err: impl Error + Send + Sync + 'static) -> io::Error {
#[cfg(unix)]
fn io_err(err: impl std::error::Error + Send + Sync + 'static) -> io::Error {
io::Error::new(io::ErrorKind::Other, Box::new(err))
}
@ -329,7 +443,7 @@ pub fn run<T>(future: impl Future<Output = T>) -> T {
while !ready.load(Ordering::SeqCst) {
if runs >= 64 {
runs = 0;
POLLER.poll_quick();
POLLER.poll_quick().unwrap();
}
match EXECUTOR.receiver.try_recv() {
@ -341,7 +455,7 @@ pub fn run<T>(future: impl Future<Output = T>) -> T {
Err(_) => {
runs = 0;
fails += 1;
POLLER.poll_quick();
POLLER.poll_quick().unwrap();
if fails <= 1 {
continue;
@ -362,13 +476,22 @@ pub fn run<T>(future: impl Future<Output = T>) -> T {
*m = true;
drop(m);
// TODO: if this panics, set m to false and notify
POLLER.poll();
let _guard = {
struct OnDrop<F: FnMut()>(F);
impl<F: FnMut()> Drop for OnDrop<F> {
fn drop(&mut self) {
(self.0)();
}
}
OnDrop(|| {
let mut m = EXECUTOR.mutex.lock();
*m = false;
EXECUTOR.cvar.notify_one();
})
};
m = EXECUTOR.mutex.lock();
*m = false;
EXECUTOR.cvar.notify_one();
POLLER.poll().unwrap();
}
}
}
@ -686,11 +809,25 @@ pub struct Async<T> {
entry: Arc<Entry>,
}
#[cfg(any(unix, docsrs))]
#[cfg_attr(docsrs, doc(cfg(unix)))]
impl<T: AsRawFd> Async<T> {
/// Converts a non-blocking I/O handle into an async I/O handle.
pub fn nonblocking(source: T) -> io::Result<Async<T>> {
Ok(Async {
entry: POLLER.registry.register(source.as_raw_fd())?,
entry: POLLER.registry.register(&source)?,
source: Box::new(source),
})
}
}
#[cfg(any(windows, docsrs))]
#[cfg_attr(docsrs, doc(cfg(windows)))]
impl<T: AsRawSocket> Async<T> {
/// Converts a non-blocking I/O handle into an async I/O handle.
pub fn nonblocking(source: T) -> io::Result<Async<T>> {
Ok(Async {
entry: POLLER.registry.register(&source)?,
source: Box::new(source),
})
}
@ -776,8 +913,8 @@ impl<T> Async<T> {
impl<T> Drop for Async<T> {
fn drop(&mut self) {
// Ignore errors because an event in oneshot mode may unregister the fd before we do.
let _ = POLLER.registry.unregister(&self.entry);
// Ignore errors because an event in oneshot mode may deregister the fd before we do.
let _ = POLLER.registry.deregister(&self.entry);
}
}
@ -988,6 +1125,8 @@ impl Async<UdpSocket> {
}
}
#[cfg(any(unix, docsrs))]
#[cfg_attr(docsrs, doc(cfg(unix)))]
impl Async<UnixListener> {
/// Creates a listener bound to the specified path.
pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Async<UnixListener>> {
@ -1012,6 +1151,8 @@ impl Async<UnixListener> {
}
}
#[cfg(any(unix, docsrs))]
#[cfg_attr(docsrs, doc(cfg(unix)))]
impl Async<UnixStream> {
/// Connects to the specified path.
pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<Async<UnixStream>> {
@ -1029,6 +1170,8 @@ impl Async<UnixStream> {
}
}
#[cfg(any(unix, docsrs))]
#[cfg_attr(docsrs, doc(cfg(unix)))]
impl Async<UnixDatagram> {
/// Creates a socket bound to the specified path.
pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Async<UnixDatagram>> {