m(windows): Reimplement Wepoll in Rust (#88)

Reimplements the C-based wepoll backend in Rust, using some handwritten code. This PR also implements bindings to the I/O Completion Ports and \Device\Afd APIs. For more information on the latter, see my blog post on the subject: https://notgull.github.io/device-afd/

Note that the IOCP API is wrapped using a `Pin`-oriented "CompletionHandle" system that is relatively brittle. This should be replaced with a better model when one becomes available.
This commit is contained in:
John Nunley 2023-03-05 16:25:25 -08:00 committed by GitHub
parent e85331c437
commit 24900fb662
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1826 additions and 267 deletions

View File

@ -7,10 +7,10 @@ version = "2.5.2"
authors = ["Stjepan Glavina <stjepang@gmail.com>"] authors = ["Stjepan Glavina <stjepang@gmail.com>"]
edition = "2018" edition = "2018"
rust-version = "1.47" rust-version = "1.47"
description = "Portable interface to epoll, kqueue, event ports, and wepoll" description = "Portable interface to epoll, kqueue, event ports, and IOCP"
license = "Apache-2.0 OR MIT" license = "Apache-2.0 OR MIT"
repository = "https://github.com/smol-rs/polling" repository = "https://github.com/smol-rs/polling"
keywords = ["mio", "epoll", "kqueue", "iocp", "wepoll"] keywords = ["mio", "epoll", "kqueue", "iocp"]
categories = ["asynchronous", "network-programming", "os"] categories = ["asynchronous", "network-programming", "os"]
exclude = ["/.*"] exclude = ["/.*"]
@ -32,13 +32,19 @@ autocfg = "1"
libc = "0.2.77" libc = "0.2.77"
[target.'cfg(windows)'.dependencies] [target.'cfg(windows)'.dependencies]
wepoll-ffi = { version = "0.1.2", features = ["null-overlapped-wakeups-patch"] } bitflags = "1.3.2"
concurrent-queue = "2.1.0"
pin-project-lite = "0.2.9"
[target.'cfg(windows)'.dependencies.windows-sys] [target.'cfg(windows)'.dependencies.windows-sys]
version = "0.45" version = "0.45"
features = [ features = [
"Win32_Networking_WinSock",
"Win32_System_IO", "Win32_System_IO",
"Win32_Foundation" "Win32_System_LibraryLoader",
"Win32_System_WindowsProgramming",
"Win32_Storage_FileSystem",
"Win32_Foundation",
] ]
[dev-dependencies] [dev-dependencies]

View File

@ -9,7 +9,7 @@ https://crates.io/crates/polling)
[![Documentation](https://docs.rs/polling/badge.svg)]( [![Documentation](https://docs.rs/polling/badge.svg)](
https://docs.rs/polling) https://docs.rs/polling)
Portable interface to epoll, kqueue, event ports, and wepoll. Portable interface to epoll, kqueue, event ports, and IOCP.
Supported platforms: Supported platforms:
- [epoll](https://en.wikipedia.org/wiki/Epoll): Linux, Android - [epoll](https://en.wikipedia.org/wiki/Epoll): Linux, Android
@ -17,7 +17,7 @@ Supported platforms:
DragonFly BSD DragonFly BSD
- [event ports](https://illumos.org/man/port_create): illumos, Solaris - [event ports](https://illumos.org/man/port_create): illumos, Solaris
- [poll](https://en.wikipedia.org/wiki/Poll_(Unix)): VxWorks, Fuchsia, other Unix systems - [poll](https://en.wikipedia.org/wiki/Poll_(Unix)): VxWorks, Fuchsia, other Unix systems
- [wepoll](https://github.com/piscisaureus/wepoll): Windows, Wine (version 7.13+) - [IOCP](https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports): Windows, Wine (version 7.13+)
Polling is done in oneshot mode, which means interest in I/O events needs to be reset after Polling is done in oneshot mode, which means interest in I/O events needs to be reset after
an event is delivered if we're interested in the next event of the same kind. an event is delivered if we're interested in the next event of the same kind.

608
src/iocp/afd.rs Normal file
View File

@ -0,0 +1,608 @@
//! Safe wrapper around \Device\Afd
use super::port::{Completion, CompletionHandle};
use std::cell::UnsafeCell;
use std::fmt;
use std::io;
use std::marker::{PhantomData, PhantomPinned};
use std::mem::{size_of, transmute, MaybeUninit};
use std::os::windows::prelude::{AsRawHandle, RawHandle, RawSocket};
use std::pin::Pin;
use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Once;
use windows_sys::Win32::Foundation::{
CloseHandle, HANDLE, HINSTANCE, NTSTATUS, STATUS_NOT_FOUND, STATUS_PENDING, STATUS_SUCCESS,
UNICODE_STRING,
};
use windows_sys::Win32::Networking::WinSock::{
WSAIoctl, SIO_BASE_HANDLE, SIO_BSP_HANDLE_POLL, SOCKET_ERROR,
};
use windows_sys::Win32::Storage::FileSystem::{
FILE_OPEN, FILE_SHARE_READ, FILE_SHARE_WRITE, SYNCHRONIZE,
};
use windows_sys::Win32::System::LibraryLoader::{GetModuleHandleW, GetProcAddress};
use windows_sys::Win32::System::WindowsProgramming::{IO_STATUS_BLOCK, OBJECT_ATTRIBUTES};
/// Input/output buffer for an `IOCTL_AFD_POLL` operation.
///
/// `#[repr(C)]` because the AFD driver reads and writes this structure
/// directly; the field order and sizes must not change.
#[derive(Default)]
#[repr(C)]
pub(super) struct AfdPollInfo {
    /// The timeout for this poll.
    timeout: i64,
    /// The number of handles being polled.
    ///
    /// Always set to 1 by this crate (see `Afd::poll`).
    handle_count: u32,
    /// Whether or not this poll is exclusive for this handle.
    exclusive: u32,
    /// The handles to poll.
    ///
    /// The driver-side structure ends in a variable-length array; since we
    /// only ever poll one handle per operation, a one-element array suffices.
    handles: [AfdPollHandleInfo; 1],
}
/// Per-handle entry inside an `AfdPollInfo` request.
///
/// `#[repr(C)]` for the same reason as `AfdPollInfo`: the driver reads and
/// writes this layout directly.
#[derive(Default)]
#[repr(C)]
struct AfdPollHandleInfo {
    /// The handle to poll.
    handle: HANDLE,
    /// The events to poll for (on input) / the events observed (on output).
    events: AfdPollMask,
    /// The status of the poll.
    status: NTSTATUS,
}
impl AfdPollInfo {
    /// Returns the number of handles the driver reported on.
    pub(super) fn handle_count(&self) -> u32 {
        self.handle_count
    }

    /// Returns the event mask reported for the (single) polled handle.
    pub(super) fn events(&self) -> AfdPollMask {
        self.handles[0].events
    }
}
bitflags::bitflags! {
    /// Event mask for an AFD poll operation.
    ///
    /// NOTE(review): the values appear to mirror the kernel's `AFD_POLL_*`
    /// constants as also used by wepoll and libuv — confirm against those
    /// sources before adding or changing flags.
    #[derive(Default)]
    #[repr(transparent)]
    pub(super) struct AfdPollMask: u32 {
        const RECEIVE = 0x001;
        const RECEIVE_EXPEDITED = 0x002;
        const SEND = 0x004;
        const DISCONNECT = 0x008;
        const ABORT = 0x010;
        const LOCAL_CLOSE = 0x020;
        const ACCEPT = 0x080;
        const CONNECT_FAIL = 0x100;
    }
}
/// Types that embed an `AfdPollInfo` buffer the driver can write into.
pub(super) trait HasAfdInfo {
    /// Projects a pinned reference to the inner `AfdPollInfo` cell.
    fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell<AfdPollInfo>>;
}
// Generates the `NtdllImports` struct: one function-pointer field per
// declared function, a `load` constructor that resolves each symbol with
// `GetProcAddress`, and a thin wrapper method per function so call sites
// look like ordinary method calls.
macro_rules! define_ntdll_import {
    (
        $(
            $(#[$attr:meta])*
            fn $name:ident($($arg:ident: $arg_ty:ty),*) -> $ret:ty;
        )*
    ) => {
        /// Imported functions from ntdll.dll.
        #[allow(non_snake_case)]
        pub(super) struct NtdllImports {
            $(
                $(#[$attr])*
                $name: unsafe extern "system" fn($($arg_ty),*) -> $ret,
            )*
        }

        #[allow(non_snake_case)]
        impl NtdllImports {
            unsafe fn load(ntdll: HINSTANCE) -> io::Result<Self> {
                $(
                    let $name = {
                        // NUL-terminated ANSI symbol name for GetProcAddress.
                        const NAME: &str = concat!(stringify!($name), "\0");
                        let addr = GetProcAddress(ntdll, NAME.as_ptr() as *const _);

                        let addr = match addr {
                            Some(addr) => addr,
                            None => {
                                log::error!("Failed to load ntdll function {}", NAME);
                                return Err(io::Error::last_os_error());
                            },
                        };

                        // SAFETY(review): assumes the signature declared in the
                        // macro invocation matches the real ntdll export.
                        transmute::<_, unsafe extern "system" fn($($arg_ty),*) -> $ret>(addr)
                    };
                )*

                Ok(Self {
                    $(
                        $name,
                    )*
                })
            }

            $(
                $(#[$attr])*
                unsafe fn $name(&self, $($arg: $arg_ty),*) -> $ret {
                    (self.$name)($($arg),*)
                }
            )*
        }
    };
}
// These are undocumented/unstable ntdll entry points; the signatures below
// must match the real exports exactly (see the SAFETY note in the macro).
define_ntdll_import! {
    /// Cancels an ongoing I/O operation.
    fn NtCancelIoFileEx(
        FileHandle: HANDLE,
        IoRequestToCancel: *mut IO_STATUS_BLOCK,
        IoStatusBlock: *mut IO_STATUS_BLOCK
    ) -> NTSTATUS;

    /// Opens or creates a file handle.
    #[allow(clippy::too_many_arguments)]
    fn NtCreateFile(
        FileHandle: *mut HANDLE,
        DesiredAccess: u32,
        ObjectAttributes: *mut OBJECT_ATTRIBUTES,
        IoStatusBlock: *mut IO_STATUS_BLOCK,
        AllocationSize: *mut i64,
        FileAttributes: u32,
        ShareAccess: u32,
        CreateDisposition: u32,
        CreateOptions: u32,
        EaBuffer: *mut (),
        EaLength: u32
    ) -> NTSTATUS;

    /// Runs an I/O control on a file handle.
    ///
    /// Practically equivalent to `ioctl`.
    #[allow(clippy::too_many_arguments)]
    fn NtDeviceIoControlFile(
        FileHandle: HANDLE,
        Event: HANDLE,
        ApcRoutine: *mut (),
        ApcContext: *mut (),
        IoStatusBlock: *mut IO_STATUS_BLOCK,
        IoControlCode: u32,
        InputBuffer: *mut (),
        InputBufferLength: u32,
        OutputBuffer: *mut (),
        OutputBufferLength: u32
    ) -> NTSTATUS;

    /// Converts `NTSTATUS` to a DOS error code.
    fn RtlNtStatusToDosError(
        Status: NTSTATUS
    ) -> u32;
}
impl NtdllImports {
    /// Returns the lazily-initialized ntdll function table.
    ///
    /// The table is resolved at most once per process; every later call
    /// returns the cached result. If the initial load failed, an equivalent
    /// error is reproduced for each caller.
    fn get() -> io::Result<&'static Self> {
        macro_rules! s {
            ($e:expr) => {{
                $e as u16
            }};
        }

        // "ntdll.dll" as a NUL-terminated UTF-16 string.
        static NTDLL_NAME: &[u16] = &[
            s!('n'),
            s!('t'),
            s!('d'),
            s!('l'),
            s!('l'),
            s!('.'),
            s!('d'),
            s!('l'),
            s!('l'),
            s!('\0'),
        ];

        static NTDLL_IMPORTS: OnceCell<io::Result<NtdllImports>> = OnceCell::new();

        NTDLL_IMPORTS
            .get_or_init(|| unsafe {
                // ntdll.dll should be mapped into every Win32 process, so a
                // failure here indicates a very unusual environment.
                let ntdll = GetModuleHandleW(NTDLL_NAME.as_ptr() as *const _);
                if ntdll == 0 {
                    log::error!("Failed to load ntdll.dll");
                    return Err(io::Error::last_os_error());
                }

                NtdllImports::load(ntdll)
            })
            .as_ref()
            .map_err(|e| {
                // `io::Error` is not `Clone`, so rebuild an equivalent error.
                // Prefer the raw OS error code (when present) so callers see
                // the original cause rather than just a generic `ErrorKind`.
                match e.raw_os_error() {
                    Some(code) => io::Error::from_raw_os_error(code),
                    None => io::Error::from(e.kind()),
                }
            })
    }

    /// Eagerly loads the ntdll function table, returning any load error.
    ///
    /// Called at poller construction time to fail fast on systems that lack
    /// the required unstable functions (old Windows or Wine).
    pub(super) fn force_load() -> io::Result<()> {
        Self::get()?;
        Ok(())
    }
}
/// The handle to the AFD device.
pub(super) struct Afd<T> {
    /// The handle to the AFD device.
    handle: HANDLE,
    /// We own `T`.
    ///
    /// `Afd` never stores a `T` directly, but packets of type `T` are
    /// logically handed to the device while a poll operation is in flight.
    _marker: PhantomData<T>,
}
impl<T> fmt::Debug for Afd<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
struct WriteAsHex(HANDLE);
impl fmt::Debug for WriteAsHex {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:010x}", self.0)
}
}
f.debug_struct("Afd")
.field("handle", &WriteAsHex(self.handle))
.finish()
}
}
impl<T> Drop for Afd<T> {
    fn drop(&mut self) {
        // SAFETY: we exclusively own `handle`; it was opened by `Afd::new`
        // and is closed exactly once here. Errors from `CloseHandle` are
        // ignored — there is no reasonable way to report them from `drop`.
        unsafe {
            CloseHandle(self.handle);
        }
    }
}
impl<T> AsRawHandle for Afd<T> {
    /// Returns the raw device handle without transferring ownership.
    fn as_raw_handle(&self) -> RawHandle {
        self.handle as _
    }
}
impl<T: CompletionHandle> Afd<T>
where
    T::Completion: AsIoStatusBlock + HasAfdInfo,
{
    /// Create a new AFD device.
    ///
    /// Opens a fresh handle to `\Device\Afd` with `NtCreateFile`. Each handle
    /// can carry multiple concurrent poll operations (see `Poller::afd_handle`).
    pub(super) fn new() -> io::Result<Self> {
        // Convert a char literal to a UTF-16 code unit.
        macro_rules! s {
            ($e:expr) => {
                ($e) as u16
            };
        }

        /// \Device\Afd\Smol
        const AFD_NAME: &[u16] = &[
            s!('\\'),
            s!('D'),
            s!('e'),
            s!('v'),
            s!('i'),
            s!('c'),
            s!('e'),
            s!('\\'),
            s!('A'),
            s!('f'),
            s!('d'),
            s!('\\'),
            s!('S'),
            s!('m'),
            s!('o'),
            s!('l'),
            s!('\0'),
        ];

        // Set up device attributes.
        // NOTE(review): `UNICODE_STRING::Length` is conventionally the byte
        // length *excluding* the terminating NUL; here the NUL is counted in
        // both `Length` and `MaximumLength` — confirm NtCreateFile accepts
        // this for the AFD device path.
        let mut device_name = UNICODE_STRING {
            Length: (AFD_NAME.len() * size_of::<u16>()) as u16,
            MaximumLength: (AFD_NAME.len() * size_of::<u16>()) as u16,
            Buffer: AFD_NAME.as_ptr() as *mut _,
        };
        let mut device_attributes = OBJECT_ATTRIBUTES {
            Length: size_of::<OBJECT_ATTRIBUTES>() as u32,
            RootDirectory: 0,
            ObjectName: &mut device_name,
            Attributes: 0,
            SecurityDescriptor: ptr::null_mut(),
            SecurityQualityOfService: ptr::null_mut(),
        };

        let mut handle = MaybeUninit::<HANDLE>::uninit();
        let mut iosb = MaybeUninit::<IO_STATUS_BLOCK>::zeroed();
        let ntdll = NtdllImports::get()?;

        // SAFETY: every pointer passed references a live stack local for the
        // duration of the (synchronous) call.
        let result = unsafe {
            ntdll.NtCreateFile(
                handle.as_mut_ptr(),
                SYNCHRONIZE,
                &mut device_attributes,
                iosb.as_mut_ptr(),
                ptr::null_mut(),
                0,
                FILE_SHARE_READ | FILE_SHARE_WRITE,
                FILE_OPEN,
                0,
                ptr::null_mut(),
                0,
            )
        };

        if result != STATUS_SUCCESS {
            // Translate the NTSTATUS into a Win32 error code for `io::Error`.
            let real_code = unsafe { ntdll.RtlNtStatusToDosError(result) };
            return Err(io::Error::from_raw_os_error(real_code as i32));
        }

        // SAFETY: NtCreateFile succeeded, so `handle` was initialized.
        let handle = unsafe { handle.assume_init() };
        Ok(Self {
            handle,
            _marker: PhantomData,
        })
    }

    /// Begin polling with the provided handle.
    ///
    /// Submits an `IOCTL_AFD_POLL` for `base_socket` with the given event
    /// mask. Ownership of `packet` is transferred to the kernel until the
    /// operation completes or is cancelled; completion is delivered through
    /// whatever IOCP this AFD handle is registered with.
    pub(super) fn poll(
        &self,
        packet: T,
        base_socket: RawSocket,
        afd_events: AfdPollMask,
    ) -> io::Result<()> {
        const IOCTL_AFD_POLL: u32 = 0x00012024;

        // Lock the packet.
        // A packet can only host one in-flight operation at a time.
        if !packet.get().try_lock() {
            return Err(io::Error::new(
                io::ErrorKind::WouldBlock,
                "packet is already in use",
            ));
        }

        // Set up the AFD poll info.
        let poll_info = unsafe {
            let poll_info = Pin::into_inner_unchecked(packet.get().afd_info()).get();

            // Initialize the AFD poll info.
            // No driver-side timeout (`i64::MAX`): cancellation is handled
            // explicitly via `cancel`, so the poll may pend indefinitely.
            (*poll_info).exclusive = false.into();
            (*poll_info).handle_count = 1;
            (*poll_info).timeout = std::i64::MAX;
            (*poll_info).handles[0].handle = base_socket as HANDLE;
            (*poll_info).handles[0].status = 0;
            (*poll_info).handles[0].events = afd_events;

            poll_info
        };

        // `T::Completion: AsIoStatusBlock` guarantees the completion starts
        // with an `IO_STATUS_BLOCK`, making this cast valid.
        let iosb = T::into_ptr(packet).cast::<IO_STATUS_BLOCK>();

        // Set Status to pending
        unsafe {
            (*iosb).Anonymous.Status = STATUS_PENDING;
        }

        let ntdll = NtdllImports::get()?;
        // The IOSB pointer is passed as both the ApcContext and the
        // IoStatusBlock argument; NOTE(review): presumably the IOCP hands the
        // context back on completion so the packet can be recovered — confirm
        // against the port-side dequeue code.
        let result = unsafe {
            ntdll.NtDeviceIoControlFile(
                self.handle,
                0,
                ptr::null_mut(),
                iosb.cast(),
                iosb.cast(),
                IOCTL_AFD_POLL,
                poll_info.cast(),
                size_of::<AfdPollInfo>() as u32,
                poll_info.cast(),
                size_of::<AfdPollInfo>() as u32,
            )
        };

        match result {
            STATUS_SUCCESS => Ok(()),
            // The poll was queued; it will surface through the IOCP later.
            STATUS_PENDING => Err(io::ErrorKind::WouldBlock.into()),
            status => {
                let real_code = unsafe { ntdll.RtlNtStatusToDosError(status) };
                Err(io::Error::from_raw_os_error(real_code as i32))
            }
        }
    }

    /// Cancel an ongoing poll operation.
    ///
    /// # Safety
    ///
    /// The poll operation must currently be in progress for this AFD.
    pub(super) unsafe fn cancel(&self, packet: &T) -> io::Result<()> {
        let ntdll = NtdllImports::get()?;

        let result = {
            // First, check if the packet is still in use.
            // If the status is no longer pending, the operation has already
            // completed and there is nothing to cancel.
            let iosb = packet.as_ptr().cast::<IO_STATUS_BLOCK>();
            if (*iosb).Anonymous.Status != STATUS_PENDING {
                return Ok(());
            }

            // Cancel the packet.
            let mut cancel_iosb = MaybeUninit::<IO_STATUS_BLOCK>::zeroed();
            ntdll.NtCancelIoFileEx(self.handle, iosb, cancel_iosb.as_mut_ptr())
        };

        // STATUS_NOT_FOUND means the operation completed before we could
        // cancel it; that is not an error for our purposes.
        if result == STATUS_SUCCESS || result == STATUS_NOT_FOUND {
            Ok(())
        } else {
            let real_code = ntdll.RtlNtStatusToDosError(result);
            Err(io::Error::from_raw_os_error(real_code as i32))
        }
    }
}
/// A minimal one-time initialization cell.
///
/// Stand-in for `once_cell`/`std::sync::OnceLock` supporting only the single
/// operation this module needs: `get_or_init`. The stored value is never
/// dropped, which is acceptable for the `static` use in this file.
struct OnceCell<T> {
    /// Storage for the value; written at most once.
    value: UnsafeCell<MaybeUninit<T>>,
    /// Guards the one-time initialization of `value`.
    once: Once,
}

unsafe impl<T: Send + Sync> Send for OnceCell<T> {}
unsafe impl<T: Send + Sync> Sync for OnceCell<T> {}

impl<T> OnceCell<T> {
    /// Creates an empty `OnceCell`.
    pub const fn new() -> Self {
        Self {
            once: Once::new(),
            value: UnsafeCell::new(MaybeUninit::uninit()),
        }
    }

    /// Returns the stored value, running `f` to produce it on first use.
    pub fn get_or_init<F>(&self, f: F) -> &T
    where
        F: FnOnce() -> T,
    {
        self.once.call_once(|| {
            let slot = self.value.get();
            // SAFETY: `call_once` guarantees exclusive access to the slot
            // while the initialization closure runs.
            unsafe {
                slot.write(MaybeUninit::new(f()));
            }
        });

        // SAFETY: `call_once` has returned, so the slot is initialized and
        // will never be written again.
        unsafe { &*self.value.get().cast() }
    }
}
pin_project_lite::pin_project! {
    /// An I/O status block paired with some auxiliary data.
    ///
    /// A pointer to this structure is handed to the kernel as the
    /// `IO_STATUS_BLOCK` for an operation; `#[repr(C)]` guarantees `iosb`
    /// sits at offset zero, making that cast valid (see `AsIoStatusBlock`).
    #[repr(C)]
    pub(super) struct IoStatusBlock<T> {
        // The I/O status block.
        //
        // Written by the kernel while an operation is in flight, hence the
        // `UnsafeCell`.
        iosb: UnsafeCell<IO_STATUS_BLOCK>,

        // Whether or not the block is in use.
        in_use: AtomicBool,

        // The auxiliary data.
        #[pin]
        data: T,

        // This block is not allowed to move.
        #[pin]
        _marker: PhantomPinned,
    }
}
impl<T: fmt::Debug> fmt::Debug for IoStatusBlock<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The raw `IO_STATUS_BLOCK` is not meaningfully printable; elide it.
        let mut out = f.debug_struct("IoStatusBlock");
        out.field("iosb", &"..");
        out.field("in_use", &self.in_use);
        out.field("data", &self.data);
        out.finish()
    }
}
impl<T> From<T> for IoStatusBlock<T> {
    /// Wraps auxiliary data in a fresh, unused status block.
    fn from(data: T) -> Self {
        Self {
            // SAFETY: `IO_STATUS_BLOCK` is a plain C struct for which the
            // all-zero bit pattern is a valid (empty) value.
            iosb: UnsafeCell::new(unsafe { std::mem::zeroed() }),
            in_use: AtomicBool::new(false),
            data,
            _marker: PhantomPinned,
        }
    }
}
impl<T> IoStatusBlock<T> {
    /// Projects the inner `IO_STATUS_BLOCK` cell.
    pub(super) fn iosb(self: Pin<&Self>) -> &UnsafeCell<IO_STATUS_BLOCK> {
        self.project_ref().iosb
    }

    /// Projects the pinned auxiliary data.
    pub(super) fn data(self: Pin<&Self>) -> Pin<&T> {
        self.project_ref().data
    }
}
impl<T: HasAfdInfo> HasAfdInfo for IoStatusBlock<T> {
    /// Delegates to the auxiliary data's embedded `AfdPollInfo`.
    fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell<AfdPollInfo>> {
        self.project_ref().data.afd_info()
    }
}
/// Can be transmuted to an I/O status block.
///
/// # Safety
///
/// A pointer to `T` must be able to be converted to a pointer to `IO_STATUS_BLOCK`
/// without any issues.
pub(super) unsafe trait AsIoStatusBlock {}

// SAFETY: `IoStatusBlock` is `#[repr(C)]` with `iosb` as its first field,
// so a pointer to it is also a valid pointer to an `IO_STATUS_BLOCK`.
unsafe impl<T> AsIoStatusBlock for IoStatusBlock<T> {}
// SAFETY: `try_lock`/`unlock` implement an exclusive-ownership flag with a
// sequentially-consistent atomic swap/store.
unsafe impl<T> Completion for IoStatusBlock<T> {
    /// Attempts to claim the block; returns `true` if it was previously free.
    fn try_lock(self: Pin<&Self>) -> bool {
        !self.in_use.swap(true, Ordering::SeqCst)
    }

    /// Releases a block previously claimed by `try_lock`.
    unsafe fn unlock(self: Pin<&Self>) {
        self.in_use.store(false, Ordering::SeqCst);
    }
}
/// Get the base socket associated with a socket.
///
/// Layered Service Providers (LSPs) can wrap sockets in additional layers;
/// AFD needs the underlying base handle, which this resolves via `WSAIoctl`.
pub(super) fn base_socket(sock: RawSocket) -> io::Result<RawSocket> {
    // First, try the SIO_BASE_HANDLE ioctl.
    let result = unsafe { try_socket_ioctl(sock, SIO_BASE_HANDLE) };
    match result {
        Ok(sock) => return Ok(sock),
        // NOTE(review): InvalidInput is treated as non-retryable — presumably
        // it means `sock` is not a socket at all; confirm the WSA error
        // mapping.
        Err(e) if e.kind() == io::ErrorKind::InvalidInput => return Err(e),
        Err(_) => {}
    }

    // Some poorly coded LSPs may not handle SIO_BASE_HANDLE properly, but in some cases may
    // handle SIO_BSP_HANDLE_POLL better. Try that.
    let result = unsafe { try_socket_ioctl(sock, SIO_BSP_HANDLE_POLL)? };
    if result == sock {
        // The ioctl returned the same socket, so nothing was unwrapped.
        return Err(io::Error::from(io::ErrorKind::InvalidInput));
    }

    // Try `SIO_BASE_HANDLE` again, in case the LSP fixed itself.
    unsafe { try_socket_ioctl(result, SIO_BASE_HANDLE) }
}
/// Run an IOCTL on a socket and return a socket.
///
/// # Safety
///
/// The `ioctl` parameter must be a valid I/O control that returns a valid socket.
unsafe fn try_socket_ioctl(sock: RawSocket, ioctl: u32) -> io::Result<RawSocket> {
    let mut out = MaybeUninit::<RawSocket>::uninit();
    let mut bytes = 0u32;

    // The output buffer is exactly one `RawSocket`; on success `WSAIoctl`
    // writes the resulting handle into it.
    let result = WSAIoctl(
        sock as _,
        ioctl,
        ptr::null_mut(),
        0,
        out.as_mut_ptr().cast(),
        size_of::<RawSocket>() as u32,
        &mut bytes,
        ptr::null_mut(),
        None,
    );

    if result == SOCKET_ERROR {
        return Err(io::Error::last_os_error());
    }

    // SAFETY: WSAIoctl succeeded, so the output buffer has been written.
    Ok(out.assume_init())
}

834
src/iocp/mod.rs Normal file
View File

@ -0,0 +1,834 @@
//! Bindings to Windows I/O Completion Ports.
//!
//! I/O Completion Ports is a completion-based API rather than a polling-based API, like
//! epoll or kqueue. Therefore, we have to adapt the IOCP API to the crate's API.
//!
//! WinSock is powered by the Ancillary Function Driver (AFD) subsystem, which can be
//! accessed directly by using unstable `ntdll` functions. AFD exposes features that are not
//! available through the normal WinSock interface, such as IOCTL_AFD_POLL. This function is
//! similar to the exposed `WSAPoll` method. However, once the targeted socket is "ready",
//! a completion packet is queued to an I/O completion port.
//!
//! We take advantage of IOCTL_AFD_POLL to "translate" this crate's polling-based API
//! to the one Windows expects. When a device is added to the `Poller`, an IOCTL_AFD_POLL
//! operation is started and queued to the IOCP. To modify a currently registered device
//! (e.g. with `modify()` or `delete()`), the ongoing POLL is cancelled and then restarted
//! with new parameters. When the POLL eventually completes, the packet is posted to the IOCP.
//! From here it's a simple matter of using `GetQueuedCompletionStatusEx` to read the packets
//! from the IOCP and react accordingly. Notifying the poller is trivial, because we can
//! simply post a packet to the IOCP to wake it up.
//!
//! The main disadvantage of this strategy is that it relies on unstable Windows APIs.
//! However, as `libuv` (the backing I/O library for Node.JS) relies on the same unstable
//! AFD strategy, it is unlikely to be broken without plenty of advanced warning.
//!
//! Previously, this crate used the `wepoll` library for polling. `wepoll` uses a similar
//! AFD-based strategy for polling.
mod afd;
mod port;
use afd::{base_socket, Afd, AfdPollInfo, AfdPollMask, HasAfdInfo, IoStatusBlock};
use port::{IoCompletionPort, OverlappedEntry};
use windows_sys::Win32::Foundation::{ERROR_INVALID_HANDLE, ERROR_IO_PENDING, STATUS_CANCELLED};
use crate::{Event, PollMode};
use concurrent_queue::ConcurrentQueue;
use pin_project_lite::pin_project;
use std::cell::UnsafeCell;
use std::collections::hash_map::{Entry, HashMap};
use std::fmt;
use std::io;
use std::marker::PhantomPinned;
use std::os::windows::io::{AsRawHandle, RawHandle, RawSocket};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak};
use std::time::{Duration, Instant};
#[cfg(not(polling_no_io_safety))]
use std::os::windows::io::{AsHandle, BorrowedHandle};
/// Macro to lock and ignore lock poisoning.
///
/// A poisoned lock only means another thread panicked while holding it; the
/// state protected here remains usable, so the guard is taken regardless.
macro_rules! lock {
    ($lock_result:expr) => {{
        $lock_result.unwrap_or_else(|e| e.into_inner())
    }};
}
/// Interface to I/O completion ports.
#[derive(Debug)]
pub(super) struct Poller {
    /// The I/O completion port.
    port: IoCompletionPort<Packet>,

    /// List of currently active AFD instances.
    ///
    /// Weak references are kept here so that the AFD handle is automatically dropped
    /// when the last associated socket is dropped.
    afd: Mutex<Vec<Weak<Afd<Packet>>>>,

    /// The state of the sources registered with this poller.
    sources: RwLock<HashMap<RawSocket, Packet>>,

    /// Sockets with pending updates.
    pending_updates: ConcurrentQueue<Packet>,

    /// Are we currently polling?
    ///
    /// While set, `update_packet` applies updates directly instead of
    /// queueing them onto `pending_updates`.
    polling: AtomicBool,

    /// A list of completion packets.
    ///
    /// Reused across `wait` calls to avoid per-wait allocation.
    packets: Mutex<Vec<OverlappedEntry<Packet>>>,

    /// The packet used to notify the poller.
    notifier: Packet,
}

// SAFETY(review): these impls are presumably needed because `Packet` contains
// `UnsafeCell`s; all interior state is reached through Mutex/RwLock/atomics
// or the IOCP itself — confirm every access path is synchronized.
unsafe impl Send for Poller {}
unsafe impl Sync for Poller {}
impl Poller {
    /// Creates a new poller.
    ///
    /// Fails with an "unsupported" error when the unstable ntdll functions
    /// or the `\Device\Afd` device are unavailable (old Windows or Wine).
    pub(super) fn new() -> io::Result<Self> {
        // Make sure AFD is able to be used.
        if let Err(e) = afd::NtdllImports::force_load() {
            return Err(crate::unsupported_error(format!(
                "Failed to initialize unstable Windows functions: {}\nThis usually only happens for old Windows or Wine.",
                e
            )));
        }

        // Create and destroy a single AFD to test if we support it.
        Afd::<Packet>::new().map_err(|e| crate::unsupported_error(format!(
            "Failed to initialize \\Device\\Afd: {}\nThis usually only happens for old Windows or Wine.",
            e,
        )))?;

        let port = IoCompletionPort::new(0)?;
        log::trace!("new: handle={:?}", &port);

        Ok(Poller {
            port,
            afd: Mutex::new(vec![]),
            sources: RwLock::new(HashMap::new()),
            pending_updates: ConcurrentQueue::bounded(1024),
            polling: AtomicBool::new(false),
            packets: Mutex::new(Vec::with_capacity(1024)),
            notifier: Arc::pin(
                PacketInner::Wakeup {
                    _pinned: PhantomPinned,
                }
                .into(),
            ),
        })
    }

    /// Whether this poller supports level-triggered events.
    pub(super) fn supports_level(&self) -> bool {
        true
    }

    /// Whether this poller supports edge-triggered events.
    ///
    /// Edge mode is rejected in `add`/`modify` below.
    pub(super) fn supports_edge(&self) -> bool {
        false
    }

    /// Add a new source to the poller.
    ///
    /// Registers `socket` with the given interest and queues an AFD poll for
    /// it. Fails with `AlreadyExists` if the socket is already registered.
    pub(super) fn add(&self, socket: RawSocket, interest: Event, mode: PollMode) -> io::Result<()> {
        log::trace!(
            "add: handle={:?}, sock={}, ev={:?}",
            self.port,
            socket,
            interest
        );

        // We don't support edge-triggered events.
        if matches!(mode, PollMode::Edge) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "edge-triggered events are not supported",
            ));
        }

        // Create a new packet.
        let socket_state = {
            let state = SocketState {
                socket,
                // AFD polls the LSP-unwrapped base handle, not the socket
                // the caller gave us.
                base_socket: base_socket(socket)?,
                interest,
                interest_error: true,
                afd: self.afd_handle()?,
                mode,
                waiting_on_delete: false,
                status: SocketStatus::Idle,
            };

            Arc::pin(IoStatusBlock::from(PacketInner::Socket {
                packet: UnsafeCell::new(AfdPollInfo::default()),
                socket: Mutex::new(state),
            }))
        };

        // Keep track of the source in the poller.
        {
            let mut sources = lock!(self.sources.write());

            match sources.entry(socket) {
                Entry::Vacant(v) => {
                    v.insert(Pin::<Arc<_>>::clone(&socket_state));
                }

                Entry::Occupied(_) => {
                    return Err(io::Error::from(io::ErrorKind::AlreadyExists));
                }
            }
        }

        // Update the packet.
        self.update_packet(socket_state)
    }

    /// Update a source in the poller.
    ///
    /// Replaces the interest of an already-registered socket, restarting the
    /// in-flight AFD poll only when the new interest requires events the
    /// current poll does not cover (see `set_events`).
    pub(super) fn modify(
        &self,
        socket: RawSocket,
        interest: Event,
        mode: PollMode,
    ) -> io::Result<()> {
        log::trace!(
            "modify: handle={:?}, sock={}, ev={:?}",
            self.port,
            socket,
            interest
        );

        // We don't support edge-triggered events.
        if matches!(mode, PollMode::Edge) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "edge-triggered events are not supported",
            ));
        }

        // Get a reference to the source.
        let source = {
            let sources = lock!(self.sources.read());

            sources
                .get(&socket)
                .cloned()
                .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?
        };

        // Set the new event.
        if source.as_ref().set_events(interest, mode) {
            self.update_packet(source)?;
        }

        Ok(())
    }

    /// Delete a source from the poller.
    ///
    /// Deleting an unknown socket is a no-op — it may already have been
    /// removed internally (e.g. on a LOCAL_CLOSE event).
    pub(super) fn delete(&self, socket: RawSocket) -> io::Result<()> {
        log::trace!("remove: handle={:?}, sock={}", self.port, socket);

        // Get a reference to the source.
        let source = {
            let mut sources = lock!(self.sources.write());

            match sources.remove(&socket) {
                Some(s) => s,

                None => {
                    // If the source has already been removed, then we can just return.
                    return Ok(());
                }
            }
        };

        // Indicate to the source that it is being deleted.
        // This cancels any ongoing AFD_IOCTL_POLL operations.
        source.begin_delete()
    }

    /// Wait for events.
    ///
    /// Blocks until at least one event arrives, the poller is notified, or
    /// the timeout elapses; collected events are appended to `events`.
    ///
    /// NOTE(review): the `debug_assert!(!was_polling)` below implies `wait`
    /// is never entered concurrently — confirm the caller guarantees this.
    pub(super) fn wait(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
        log::trace!("wait: handle={:?}, timeout={:?}", self.port, timeout);
        let deadline = timeout.and_then(|timeout| Instant::now().checked_add(timeout));
        let mut packets = lock!(self.packets.lock());
        let mut notified = false;
        events.packets.clear();

        loop {
            let mut new_events = 0;

            // Indicate that we are now polling.
            // While the flag is set, `update_packet` applies updates directly
            // instead of pushing to the queue.
            let was_polling = self.polling.swap(true, Ordering::SeqCst);
            debug_assert!(!was_polling);

            // Ensure the flag is cleared even on early `?` return.
            let guard = CallOnDrop(|| {
                let was_polling = self.polling.swap(false, Ordering::SeqCst);
                debug_assert!(was_polling);
            });

            // Process every entry in the queue before we start polling.
            self.drain_update_queue(false)?;

            // Get the time to wait for.
            let timeout = deadline.map(|t| t.saturating_duration_since(Instant::now()));

            // Wait for I/O events.
            let len = self.port.wait(&mut packets, timeout)?;
            log::trace!("new events: handle={:?}, len={}", self.port, len);

            // We are no longer polling.
            drop(guard);

            // Process all of the events.
            for entry in packets.drain(..) {
                let packet = entry.into_packet();

                // Feed the event into the packet.
                match packet.feed_event(self)? {
                    FeedEventResult::NoEvent => {}
                    FeedEventResult::Event(event) => {
                        events.packets.push(event);
                        new_events += 1;
                    }
                    FeedEventResult::Notified => {
                        notified = true;
                    }
                }
            }

            // Break if there was a notification or at least one event, or if deadline is reached.
            // (A zero remaining timeout at this point means the deadline has passed.)
            let timeout_is_empty =
                timeout.map_or(false, |t| t.as_secs() == 0 && t.subsec_nanos() == 0);
            if notified || new_events > 0 || timeout_is_empty {
                break;
            }

            log::trace!("wait: no events found, re-entering polling loop");
        }

        Ok(())
    }

    /// Notify this poller.
    ///
    /// Posts the dedicated wakeup packet to the IOCP; `wait` surfaces it as
    /// `FeedEventResult::Notified`.
    pub(super) fn notify(&self) -> io::Result<()> {
        // Push the notify packet into the IOCP.
        self.port.post(0, 0, self.notifier.clone())
    }

    /// Run an update on a packet.
    ///
    /// Applies the update immediately if a `wait` is in progress, otherwise
    /// queues it so the next `wait` picks it up before polling.
    fn update_packet(&self, mut packet: Packet) -> io::Result<()> {
        loop {
            // If we are currently polling, we need to update the packet immediately.
            if self.polling.load(Ordering::Acquire) {
                packet.update()?;
                return Ok(());
            }

            // Try to queue the update.
            match self.pending_updates.push(packet) {
                Ok(()) => return Ok(()),
                // Queue full: recover the packet and drain below, then retry.
                Err(p) => packet = p.into_inner(),
            }

            // If we failed to queue the update, we need to drain the queue first.
            self.drain_update_queue(true)?;
        }
    }

    /// Drain the update queue.
    ///
    /// With `limit` set, at most one queue's worth of packets is processed,
    /// since concurrent pushes could otherwise keep this loop alive forever.
    fn drain_update_queue(&self, limit: bool) -> io::Result<()> {
        let max = if limit {
            // The queue is bounded, so `capacity()` is always `Some`.
            self.pending_updates.capacity().unwrap()
        } else {
            std::usize::MAX
        };

        // Only drain the queue's capacity, since this could in theory run forever.
        core::iter::from_fn(|| self.pending_updates.pop().ok())
            .take(max)
            .try_for_each(|packet| packet.update())
    }

    /// Get a handle to the AFD reference.
    ///
    /// Reuses an existing AFD instance with spare capacity, pruning dead
    /// entries along the way; creates and IOCP-registers a new instance when
    /// none is available.
    fn afd_handle(&self) -> io::Result<Arc<Afd<Packet>>> {
        // Maximum strong count (i.e. sharing sockets) per AFD handle.
        const AFD_MAX_SIZE: usize = 32;

        // Crawl the list and see if there are any existing AFD instances that we can use.
        // Remove any unused AFD pointers.
        let mut afd_handles = lock!(self.afd.lock());
        let mut i = 0;
        while i < afd_handles.len() {
            // Get the reference count of the AFD instance.
            let refcount = Weak::strong_count(&afd_handles[i]);

            match refcount {
                0 => {
                    // Prune the AFD pointer if it has no references.
                    afd_handles.swap_remove(i);
                }

                refcount if refcount >= AFD_MAX_SIZE => {
                    // Skip this one, since it is already at the maximum size.
                    i += 1;
                }

                _ => {
                    // We can use this AFD instance.
                    match afd_handles[i].upgrade() {
                        Some(afd) => return Ok(afd),
                        None => {
                            // The last socket dropped the AFD before we could acquire it.
                            // Prune the AFD pointer and continue.
                            afd_handles.swap_remove(i);
                        }
                    }
                }
            }
        }

        // No available handles, create a new AFD instance.
        let afd = Arc::new(Afd::new()?);

        // Register the AFD instance with the I/O completion port.
        self.port.register(&*afd, true)?;

        // Insert a weak pointer to the AFD instance into the list.
        afd_handles.push(Arc::downgrade(&afd));

        Ok(afd)
    }
}
impl AsRawHandle for Poller {
    /// Returns the raw IOCP handle without transferring ownership.
    fn as_raw_handle(&self) -> RawHandle {
        self.port.as_raw_handle()
    }
}
#[cfg(not(polling_no_io_safety))]
impl AsHandle for Poller {
    fn as_handle(&self) -> BorrowedHandle<'_> {
        // SAFETY: the IOCP handle remains open for the lifetime of `self`.
        unsafe { BorrowedHandle::borrow_raw(self.as_raw_handle()) }
    }
}
/// The container for events.
pub(super) struct Events {
    /// List of IOCP packets.
    packets: Vec<Event>,
}

// NOTE(review): `Events` only contains a `Vec<Event>`; confirm why an
// explicit `Send` impl is required here (i.e. why `Event` is not `Send`).
unsafe impl Send for Events {}

impl Events {
    /// Creates an empty list of events.
    pub(super) fn new() -> Events {
        Events {
            // Matches the poller's per-wait packet capacity.
            packets: Vec::with_capacity(1024),
        }
    }

    /// Iterate over I/O events.
    pub(super) fn iter(&self) -> impl Iterator<Item = Event> + '_ {
        self.packets.iter().copied()
    }
}
/// The type of our completion packet.
type Packet = Pin<Arc<PacketUnwrapped>>;
type PacketUnwrapped = IoStatusBlock<PacketInner>;
pin_project! {
    /// The inner type of the packet.
    ///
    /// A packet is either per-socket AFD poll state or the dedicated
    /// wakeup packet posted by `Poller::notify`.
    #[project_ref = PacketInnerProj]
    #[project = PacketInnerProjMut]
    enum PacketInner {
        // A packet for a socket.
        Socket {
            // The AFD packet state.
            //
            // Written by the kernel while a poll is in flight, hence the
            // `UnsafeCell`.
            #[pin]
            packet: UnsafeCell<AfdPollInfo>,

            // The socket state.
            socket: Mutex<SocketState>
        },

        // A packet used to wake up the poller.
        Wakeup { #[pin] _pinned: PhantomPinned },
    }
}
impl fmt::Debug for PacketInner {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The AFD poll buffer holds raw kernel data; don't try to print it.
        match self {
            Self::Socket { socket, .. } => {
                let mut out = f.debug_struct("Socket");
                out.field("packet", &"..");
                out.field("socket", socket);
                out.finish()
            }
            Self::Wakeup { .. } => f.write_str("Wakeup { .. }"),
        }
    }
}
impl HasAfdInfo for PacketInner {
    fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell<AfdPollInfo>> {
        match self.project_ref() {
            PacketInnerProj::Socket { packet, .. } => packet,
            // Wakeup packets are never submitted to AFD, so this projection
            // is never requested for them.
            PacketInnerProj::Wakeup { .. } => unreachable!(),
        }
    }
}
impl PacketUnwrapped {
/// Set the new events that this socket is waiting on.
///
/// Returns `true` if we need to be updated.
fn set_events(self: Pin<&Self>, interest: Event, mode: PollMode) -> bool {
let mut socket = match self.socket_state() {
Some(s) => s,
None => return false,
};
socket.interest = interest;
socket.mode = mode;
socket.interest_error = true;
match socket.status {
SocketStatus::Polling { readable, writable } => {
(interest.readable && !readable) || (interest.writable && !writable)
}
_ => true,
}
}
/// Update the socket and install the new status in AFD.
fn update(self: Pin<Arc<Self>>) -> io::Result<()> {
let mut socket = match self.as_ref().socket_state() {
Some(s) => s,
None => return Err(io::Error::new(io::ErrorKind::Other, "invalid socket state")),
};
// If we are waiting on a delete, just return, dropping the packet.
if socket.waiting_on_delete {
return Ok(());
}
// Check the current status.
match socket.status {
SocketStatus::Polling { readable, writable } => {
// If we need to poll for events aside from what we are currently polling, we need
// to update the packet. Cancel the ongoing poll.
if (socket.interest.readable && !readable)
|| (socket.interest.writable && !writable)
{
return self.cancel(socket);
}
// All events that we are currently waiting on are accounted for.
Ok(())
}
SocketStatus::Cancelled => {
// The ongoing operation was cancelled, and we're still waiting for it to return.
// For now, wait until the top-level loop calls feed_event().
Ok(())
}
SocketStatus::Idle => {
// Start a new poll.
let result = socket.afd.poll(
self.clone(),
socket.base_socket,
event_to_afd_mask(
socket.interest.readable,
socket.interest.writable,
socket.interest_error,
),
);
match result {
Ok(()) => {}
Err(err)
if err.raw_os_error() == Some(ERROR_IO_PENDING as i32)
|| err.kind() == io::ErrorKind::WouldBlock =>
{
// The operation is pending.
}
Err(err) if err.raw_os_error() == Some(ERROR_INVALID_HANDLE as i32) => {
// The socket was closed. We need to delete it.
// This should happen after we drop it here.
}
Err(err) => return Err(err),
}
// We are now polling for the current events.
socket.status = SocketStatus::Polling {
readable: socket.interest.readable,
writable: socket.interest.writable,
};
Ok(())
}
}
}
/// This socket state was notified; see if we need to update it.
///
/// Called by the top-level wait loop when this packet's `OVERLAPPED` comes out
/// of the completion port. Translates the AFD poll result into a user-visible
/// `Event`, handles oneshot disarming, and requeues the packet for updating.
fn feed_event(self: Pin<Arc<Self>>, poller: &Poller) -> io::Result<FeedEventResult> {
    let inner = self.as_ref().data().project_ref();
    let (afd_info, socket) = match inner {
        PacketInnerProj::Socket { packet, socket } => (packet, socket),
        PacketInnerProj::Wakeup { .. } => {
            // The poller was notified.
            return Ok(FeedEventResult::Notified);
        }
    };
    let mut socket_state = lock!(socket.lock());
    let mut event = Event::none(socket_state.interest.key);
    // Put ourselves into the idle state.
    // The completed operation is no longer in flight, so any prior
    // Polling/Cancelled status is stale at this point.
    socket_state.status = SocketStatus::Idle;
    // If we are waiting to be deleted, just return and let the drop handler do their thing.
    if socket_state.waiting_on_delete {
        return Ok(FeedEventResult::NoEvent);
    }
    unsafe {
        // SAFETY: The packet is not in transit.
        let iosb = &mut *self.as_ref().iosb().get();
        // Check the status.
        match iosb.Anonymous.Status {
            STATUS_CANCELLED => {
                // Poll request was cancelled.
            }
            status if status < 0 => {
                // There was an error, so we signal both ends.
                // NTSTATUS failure codes are negative; report the error on both
                // readable and writable so the caller re-polls and observes it.
                event.readable = true;
                event.writable = true;
            }
            _ => {
                // Check in on the AFD data.
                let afd_data = &*afd_info.get();
                if afd_data.handle_count() >= 1 {
                    let events = afd_data.events();
                    // If we closed the socket, remove it from being polled.
                    if events.contains(AfdPollMask::LOCAL_CLOSE) {
                        let source = lock!(poller.sources.write())
                            .remove(&socket_state.socket)
                            .unwrap();
                        return source.begin_delete().map(|()| FeedEventResult::NoEvent);
                    }
                    // Report socket-related events.
                    let (readable, writable) = afd_mask_to_event(events);
                    event.readable = readable;
                    event.writable = writable;
                }
            }
        }
    }
    // Filter out events that the user didn't ask for.
    event.readable &= socket_state.interest.readable;
    event.writable &= socket_state.interest.writable;
    // If this event doesn't have anything that interests us, don't return or
    // update the oneshot state.
    let return_value = if event.readable || event.writable {
        // If we are in oneshot mode, remove the interest.
        if matches!(socket_state.mode, PollMode::Oneshot) {
            socket_state.interest = Event::none(socket_state.interest.key);
            socket_state.interest_error = false;
        }
        FeedEventResult::Event(event)
    } else {
        FeedEventResult::NoEvent
    };
    // Put ourselves in the update queue.
    // Drop the guard first: update_packet re-locks this socket's state.
    drop(socket_state);
    poller.update_packet(self)?;
    // Return the event.
    Ok(return_value)
}
/// Begin deleting this socket.
///
/// Marks the socket state so that subsequent completions are ignored, and
/// cancels any in-flight AFD poll so the packet can eventually be dropped.
fn begin_delete(self: Pin<Arc<Self>>) -> io::Result<()> {
    let mut state = self
        .as_ref()
        .socket_state()
        .expect("can't delete packet that doesn't belong to a socket");
    // A second call is a no-op; deletion is already underway.
    if state.waiting_on_delete {
        return Ok(());
    }
    state.waiting_on_delete = true;
    // An in-flight poll has to be cancelled before the packet can go away.
    if matches!(state.status, SocketStatus::Polling { .. }) {
        self.cancel(state)?;
    }
    // Either drop it now or wait for it to be dropped later.
    Ok(())
}
/// Cancel the in-flight poll operation on this packet.
///
/// The caller must hold the socket-state lock; the guard is consumed so the
/// status transition happens under the same critical section.
fn cancel(self: &Pin<Arc<Self>>, mut socket: MutexGuard<'_, SocketState>) -> io::Result<()> {
    // Only an in-flight poll can be cancelled.
    assert!(matches!(socket.status, SocketStatus::Polling { .. }));
    // Send the cancel request.
    // SAFETY: per the assert above, this packet is currently in flight on
    // this AFD handle.
    unsafe { socket.afd.cancel(self)? };
    // Record that we're now waiting for the cancellation to complete.
    socket.status = SocketStatus::Cancelled;
    Ok(())
}
/// Lock and return this packet's socket state, if it has one.
///
/// Returns `None` for wakeup packets, which carry no socket state.
fn socket_state(self: Pin<&Self>) -> Option<MutexGuard<'_, SocketState>> {
    match self.data().project_ref() {
        PacketInnerProj::Socket { socket, .. } => Some(lock!(socket.lock())),
        PacketInnerProj::Wakeup { .. } => None,
    }
}
}
/// Per-socket state.
///
/// Guarded by a mutex inside the socket's completion packet; `feed_event`,
/// `begin_delete` and `cancel` all lock it before touching these fields.
#[derive(Debug)]
struct SocketState {
    /// The raw socket handle.
    socket: RawSocket,
    /// The base socket handle.
    ///
    /// NOTE(review): presumably the underlying provider socket that AFD polls
    /// (may differ from `socket` when layered providers are installed) —
    /// confirm against where this field is populated.
    base_socket: RawSocket,
    /// The event that this socket is currently waiting on.
    interest: Event,
    /// Whether to listen for error events.
    interest_error: bool,
    /// The current poll mode.
    mode: PollMode,
    /// The AFD instance that this socket is registered with.
    afd: Arc<Afd<Packet>>,
    /// Whether this socket is waiting to be deleted.
    ///
    /// Once set, completions are ignored and the packet is dropped instead of
    /// being re-armed.
    waiting_on_delete: bool,
    /// The current status of the socket.
    status: SocketStatus,
}
/// The current I/O status of a socket's poll operation.
///
/// State machine: `Idle` -> `Polling` when a poll is issued;
/// `Polling` -> `Cancelled` when the poll is cancelled;
/// `Polling`/`Cancelled` -> `Idle` when the completion is fed back.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum SocketStatus {
    /// We are currently not polling.
    Idle,
    /// We are currently polling these events.
    Polling {
        /// We are currently polling for readable events.
        readable: bool,
        /// We are currently polling for writable events.
        writable: bool,
    },
    /// The last poll operation was cancelled, and we're waiting for it to
    /// complete.
    Cancelled,
}
/// The result of calling `feed_event`.
///
/// Tells the top-level wait loop whether a completion produced a user-visible
/// event, nothing of interest, or a wakeup notification.
#[derive(Debug)]
enum FeedEventResult {
    /// No event was yielded.
    NoEvent,
    /// An event was yielded.
    Event(Event),
    /// The poller has been notified.
    Notified,
}
/// Convert user-level interest flags into the AFD poll mask to submit.
///
/// Any interest at all also subscribes to abort/connect-failure events so
/// errors are surfaced.
fn event_to_afd_mask(readable: bool, writable: bool, error: bool) -> afd::AfdPollMask {
    use afd::AfdPollMask as AfdPoll;
    // All conditions that count as "readable" for a socket.
    let read_mask =
        AfdPoll::RECEIVE | AfdPoll::ACCEPT | AfdPoll::DISCONNECT | AfdPoll::RECEIVE_EXPEDITED;
    // Error conditions, reported regardless of direction.
    let error_mask = AfdPoll::ABORT | AfdPoll::CONNECT_FAIL;
    let mut mask = AfdPoll::empty();
    if readable || writable || error {
        mask |= error_mask;
    }
    if readable {
        mask |= read_mask;
    }
    if writable {
        mask |= AfdPoll::SEND;
    }
    mask
}
/// Convert an AFD poll result mask back into (readable, writable) flags.
///
/// Error conditions (abort / failed connect) are reported on both directions
/// so that either kind of waiter wakes up and observes the failure.
fn afd_mask_to_event(mask: afd::AfdPollMask) -> (bool, bool) {
    use afd::AfdPollMask as AfdPoll;
    let error = mask.intersects(AfdPoll::ABORT | AfdPoll::CONNECT_FAIL);
    let readable = error
        || mask.intersects(
            AfdPoll::RECEIVE | AfdPoll::ACCEPT | AfdPoll::DISCONNECT | AfdPoll::RECEIVE_EXPEDITED,
        );
    let writable = error || mask.intersects(AfdPoll::SEND);
    (readable, writable)
}
/// Runs the wrapped closure when dropped, for scope-exit cleanup.
struct CallOnDrop<F: FnMut()>(F);
impl<F: FnMut()> Drop for CallOnDrop<F> {
    fn drop(&mut self) {
        (self.0)();
    }
}

327
src/iocp/port.rs Normal file
View File

@ -0,0 +1,327 @@
//! A safe wrapper around the Windows I/O API.
use std::convert::{TryFrom, TryInto};
use std::fmt;
use std::io;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::ops::Deref;
use std::os::windows::io::{AsRawHandle, RawHandle};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use windows_sys::Win32::Foundation::{CloseHandle, HANDLE, INVALID_HANDLE_VALUE};
use windows_sys::Win32::Storage::FileSystem::SetFileCompletionNotificationModes;
use windows_sys::Win32::System::WindowsProgramming::{FILE_SKIP_SET_EVENT_ON_HANDLE, INFINITE};
use windows_sys::Win32::System::IO::{
CreateIoCompletionPort, GetQueuedCompletionStatusEx, PostQueuedCompletionStatus, OVERLAPPED,
OVERLAPPED_ENTRY,
};
/// A completion block which can be used with I/O completion ports.
///
/// Implementors guard against a block being submitted to the port while a
/// previous operation on it is still outstanding.
///
/// # Safety
///
/// This must be a valid completion block.
pub(super) unsafe trait Completion {
    /// Signal to the completion block that we are about to start an operation.
    ///
    /// Returns `false` if the block is already in use by another operation.
    fn try_lock(self: Pin<&Self>) -> bool;
    /// Unlock the completion block.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably only sound after the corresponding operation
    /// has completed or been cancelled — confirm against callers.
    unsafe fn unlock(self: Pin<&Self>);
}
/// The pointer to a completion block.
///
/// Abstracts over the two ownership shapes used with the port: a borrowed
/// `Pin<&T>` and an owned `Pin<Arc<T>>`. `into_ptr`/`from_ptr` round-trip the
/// handle through the `*mut OVERLAPPED` pointer that Win32 passes around.
///
/// # Safety
///
/// This must be a valid completion block.
pub(super) unsafe trait CompletionHandle: Deref + Sized {
    /// Type of the completion block.
    type Completion: Completion;
    /// Get a pointer to the completion block.
    ///
    /// The pointer is pinned since the underlying object should not be moved
    /// after creation. This prevents it from being invalidated while it's
    /// used in an overlapped operation.
    fn get(&self) -> Pin<&Self::Completion>;
    /// Convert this block into a pointer that can be passed as `*mut OVERLAPPED`.
    ///
    /// Ownership (if any) is transferred into the raw pointer; pair with
    /// `from_ptr` to recover it.
    fn into_ptr(this: Self) -> *mut OVERLAPPED;
    /// Convert a pointer that was passed as `*mut OVERLAPPED` into a pointer to this block.
    ///
    /// # Safety
    ///
    /// This must be a valid pointer to a completion block.
    unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self;
    /// Convert to a pointer without losing ownership.
    fn as_ptr(&self) -> *mut OVERLAPPED;
}
// Borrowed handles: the OVERLAPPED pointer is just the address of the pinned
// block; no ownership changes hands in `into_ptr`/`from_ptr`.
unsafe impl<'a, T: Completion> CompletionHandle for Pin<&'a T> {
    type Completion = T;
    fn get(&self) -> Pin<&Self::Completion> {
        *self
    }
    fn into_ptr(this: Self) -> *mut OVERLAPPED {
        // SAFETY: the reference is never moved out of; we only take its address.
        unsafe { Pin::into_inner_unchecked(this) as *const T as *mut OVERLAPPED }
    }
    unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self {
        // SAFETY: caller guarantees `ptr` came from `into_ptr`/`as_ptr` on a
        // pinned, still-live `T`.
        Pin::new_unchecked(&*(ptr as *const T))
    }
    fn as_ptr(&self) -> *mut OVERLAPPED {
        self.get_ref() as *const T as *mut OVERLAPPED
    }
}
// Owned handles: `into_ptr` leaks the Arc's strong count into the raw pointer
// and `from_ptr` reclaims it, so the block stays alive while the OS holds the
// OVERLAPPED pointer.
unsafe impl<T: Completion> CompletionHandle for Pin<Arc<T>> {
    type Completion = T;
    fn get(&self) -> Pin<&Self::Completion> {
        self.as_ref()
    }
    fn into_ptr(this: Self) -> *mut OVERLAPPED {
        // SAFETY: the Arc's contents are pinned and `Arc::into_raw` does not
        // move them; the strong reference is transferred into the pointer.
        unsafe { Arc::into_raw(Pin::into_inner_unchecked(this)) as *const T as *mut OVERLAPPED }
    }
    unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self {
        // SAFETY: caller guarantees `ptr` came from `into_ptr`, so this
        // reconstitutes the strong reference exactly once.
        Pin::new_unchecked(Arc::from_raw(ptr as *const T))
    }
    fn as_ptr(&self) -> *mut OVERLAPPED {
        self.as_ref().get_ref() as *const T as *mut OVERLAPPED
    }
}
/// A handle to the I/O completion port.
///
/// Typed over the completion-handle kind `T` so that packets posted to and
/// dequeued from the port keep their ownership semantics.
pub(super) struct IoCompletionPort<T> {
    /// The underlying handle.
    handle: HANDLE,
    /// We own the status block.
    _marker: PhantomData<T>,
}
impl<T> Drop for IoCompletionPort<T> {
    fn drop(&mut self) {
        // SAFETY: we own the handle; it was returned by CreateIoCompletionPort
        // and is closed exactly once here. The return value is ignored since
        // there is no way to report a failure from drop.
        unsafe {
            CloseHandle(self.handle);
        }
    }
}
impl<T> AsRawHandle for IoCompletionPort<T> {
    /// Expose the raw port handle without transferring ownership.
    fn as_raw_handle(&self) -> RawHandle {
        self.handle as _
    }
}
impl<T> fmt::Debug for IoCompletionPort<T> {
    /// Debug-format the port, rendering the raw handle as zero-padded hex.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Helper that prints the handle value in hexadecimal rather than
        // as a decimal integer.
        struct HexHandle(HANDLE);
        impl fmt::Debug for HexHandle {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "{:010x}", self.0)
            }
        }
        f.debug_struct("IoCompletionPort")
            .field("handle", &HexHandle(self.handle))
            .finish()
    }
}
impl<T: CompletionHandle> IoCompletionPort<T> {
    /// Create a new I/O completion port.
    ///
    /// `threads` is the maximum number of threads the OS lets run completion
    /// callbacks concurrently (0 means "number of processors").
    pub(super) fn new(threads: usize) -> io::Result<Self> {
        // Passing INVALID_HANDLE_VALUE with no existing port creates a fresh port.
        let handle = unsafe {
            CreateIoCompletionPort(
                INVALID_HANDLE_VALUE,
                0,
                0,
                threads.try_into().expect("too many threads"),
            )
        };
        if handle == 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(Self {
                handle,
                _marker: PhantomData,
            })
        }
    }
    /// Register a handle with this I/O completion port.
    ///
    /// The handle's own raw value is used as the completion key. If
    /// `skip_set_event_on_handle` is set, the kernel is told not to signal
    /// the file object's event on completion.
    pub(super) fn register(
        &self,
        handle: &impl AsRawHandle, // TODO change to AsHandle
        skip_set_event_on_handle: bool,
    ) -> io::Result<()> {
        let handle = handle.as_raw_handle();
        // Associating an existing handle with an existing port: the first
        // argument is the handle, the second the port.
        let result =
            unsafe { CreateIoCompletionPort(handle as _, self.handle, handle as usize, 0) };
        if result == 0 {
            return Err(io::Error::last_os_error());
        }
        if skip_set_event_on_handle {
            // Set the skip event on handle.
            let result = unsafe {
                SetFileCompletionNotificationModes(handle as _, FILE_SKIP_SET_EVENT_ON_HANDLE as _)
            };
            if result == 0 {
                return Err(io::Error::last_os_error());
            }
        }
        Ok(())
    }
    /// Post a completion packet to this port.
    ///
    /// Ownership of `packet` is transferred into the OVERLAPPED pointer; it is
    /// reclaimed when the packet is dequeued in `wait`.
    pub(super) fn post(&self, bytes_transferred: usize, id: usize, packet: T) -> io::Result<()> {
        let result = unsafe {
            PostQueuedCompletionStatus(
                self.handle,
                bytes_transferred
                    .try_into()
                    .expect("too many bytes transferred"),
                id,
                T::into_ptr(packet),
            )
        };
        if result == 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(())
        }
    }
    /// Wait for completion packets to arrive.
    ///
    /// Fills `packets` up to its current *capacity* with dequeued entries and
    /// returns how many arrived. A timeout yields `Ok(0)` rather than an error.
    pub(super) fn wait(
        &self,
        packets: &mut Vec<OverlappedEntry<T>>,
        timeout: Option<Duration>,
    ) -> io::Result<usize> {
        // Drop the current packets.
        // Clearing runs OverlappedEntry::drop, releasing any packets still
        // owned from the previous wait.
        packets.clear();
        let mut count = MaybeUninit::<u32>::uninit();
        let timeout = timeout.map_or(INFINITE, dur2timeout);
        let result = unsafe {
            // OverlappedEntry is #[repr(transparent)] over OVERLAPPED_ENTRY,
            // so the spare capacity can be used as the output buffer directly.
            GetQueuedCompletionStatusEx(
                self.handle,
                packets.as_mut_ptr() as _,
                packets.capacity().try_into().expect("too many packets"),
                count.as_mut_ptr(),
                timeout,
                0,
            )
        };
        if result == 0 {
            let io_error = io::Error::last_os_error();
            if io_error.kind() == io::ErrorKind::TimedOut {
                Ok(0)
            } else {
                Err(io_error)
            }
        } else {
            let count = unsafe { count.assume_init() };
            // SAFETY: the kernel initialized exactly `count` entries, and
            // `count` cannot exceed the capacity we passed in.
            unsafe {
                packets.set_len(count as _);
            }
            Ok(count as _)
        }
    }
}
/// An `OVERLAPPED_ENTRY` resulting from an I/O completion port.
///
/// Owns the completion packet referenced by `entry.lpOverlapped`; see
/// `into_packet`/`Drop` for how ownership is released.
#[repr(transparent)] // required: `wait` reinterprets OVERLAPPED_ENTRY buffers as this type
pub(super) struct OverlappedEntry<T: CompletionHandle> {
    /// The underlying entry.
    entry: OVERLAPPED_ENTRY,
    /// We own the status block.
    _marker: PhantomData<T>,
}
impl<T: CompletionHandle> fmt::Debug for OverlappedEntry<T> {
    /// Opaque debug output; the raw entry contents are not meaningful to print.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "OverlappedEntry {{ .. }}")
    }
}
impl<T: CompletionHandle> OverlappedEntry<T> {
    /// Convert into the completion packet.
    ///
    /// Consumes the entry; `mem::forget` prevents `Drop` from releasing the
    /// packet a second time.
    pub(super) fn into_packet(self) -> T {
        let packet = unsafe { self.packet() };
        std::mem::forget(self);
        packet
    }
    /// Get the packet reference that this entry refers to.
    ///
    /// # Safety
    ///
    /// This function should only be called once, since it moves
    /// out the `T` from the `OVERLAPPED_ENTRY`.
    unsafe fn packet(&self) -> T {
        let packet = T::from_ptr(self.entry.lpOverlapped);
        // The operation has completed, so release the completion block's lock.
        packet.get().unlock();
        packet
    }
}
impl<T: CompletionHandle> Drop for OverlappedEntry<T> {
    fn drop(&mut self) {
        // SAFETY: `packet` is called at most once per entry — `into_packet`
        // forgets `self` before this destructor could run.
        drop(unsafe { self.packet() });
    }
}
// Implementation taken from https://github.com/rust-lang/rust/blob/db5476571d9b27c862b95c1e64764b0ac8980e23/src/libstd/sys/windows/mod.rs
/// Convert a `Duration` into the `u32` millisecond timeout Windows APIs expect.
///
/// Sub-millisecond remainders are rounded *up* (so a short, nonzero duration
/// never becomes a zero timeout), and any value that overflows `u32` (or `u64`
/// along the way) saturates to `INFINITE` (never time out).
fn dur2timeout(dur: Duration) -> u32 {
    let millis = dur
        .as_secs()
        .checked_mul(1000)
        .and_then(|ms| ms.checked_add(u64::from(dur.subsec_nanos()) / 1_000_000))
        .and_then(|ms| {
            // Round any leftover nanoseconds up to a whole millisecond.
            if dur.subsec_nanos() % 1_000_000 > 0 {
                ms.checked_add(1)
            } else {
                Some(ms)
            }
        });
    match millis.and_then(|ms| u32::try_from(ms).ok()) {
        Some(timeout) => timeout,
        None => INFINITE,
    }
}
/// Runs the wrapped closure when dropped, for scope-exit cleanup.
struct CallOnDrop<F: FnMut()>(F);
impl<F: FnMut()> Drop for CallOnDrop<F> {
    fn drop(&mut self) {
        (self.0)();
    }
}

View File

@ -1,4 +1,4 @@
//! Portable interface to epoll, kqueue, event ports, and wepoll. //! Portable interface to epoll, kqueue, event ports, and IOCP.
//! //!
//! Supported platforms: //! Supported platforms:
//! - [epoll](https://en.wikipedia.org/wiki/Epoll): Linux, Android //! - [epoll](https://en.wikipedia.org/wiki/Epoll): Linux, Android
@ -6,7 +6,7 @@
//! DragonFly BSD //! DragonFly BSD
//! - [event ports](https://illumos.org/man/port_create): illumos, Solaris //! - [event ports](https://illumos.org/man/port_create): illumos, Solaris
//! - [poll](https://en.wikipedia.org/wiki/Poll_(Unix)): VxWorks, Fuchsia, other Unix systems //! - [poll](https://en.wikipedia.org/wiki/Poll_(Unix)): VxWorks, Fuchsia, other Unix systems
//! - [wepoll](https://github.com/piscisaureus/wepoll): Windows, Wine (version 7.13+) //! - [IOCP](https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports): Windows, Wine (version 7.13+)
//! //!
//! By default, polling is done in oneshot mode, which means interest in I/O events needs to //! By default, polling is done in oneshot mode, which means interest in I/O events needs to
//! be re-enabled after an event is delivered if we're interested in the next event of the same //! be re-enabled after an event is delivered if we're interested in the next event of the same
@ -113,8 +113,8 @@ cfg_if! {
mod poll; mod poll;
use poll as sys; use poll as sys;
} else if #[cfg(target_os = "windows")] { } else if #[cfg(target_os = "windows")] {
mod wepoll; mod iocp;
use wepoll as sys; use iocp as sys;
} else { } else {
compile_error!("polling does not support this target OS"); compile_error!("polling does not support this target OS");
} }

View File

@ -1,254 +0,0 @@
//! Bindings to wepoll (Windows).
use std::convert::TryInto;
use std::io;
use std::os::raw::c_int;
use std::os::windows::io::{AsRawHandle, RawHandle, RawSocket};
use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
#[cfg(not(polling_no_io_safety))]
use std::os::windows::io::{AsHandle, BorrowedHandle};
use wepoll_ffi as we;
use crate::{Event, PollMode};
/// Calls a wepoll function and results in `io::Result`.
///
/// Wepoll follows the C convention of returning -1 on failure with the error
/// available via `GetLastError`, which `io::Error::last_os_error()` reads.
macro_rules! wepoll {
    ($fn:ident $args:tt) => {{
        let res = unsafe { we::$fn $args };
        if res == -1 {
            Err(std::io::Error::last_os_error())
        } else {
            Ok(res)
        }
    }};
}
/// Interface to wepoll.
#[derive(Debug)]
pub struct Poller {
    /// Raw wepoll handle returned by `epoll_create1`.
    handle: we::HANDLE,
    /// Set by `notify()`, cleared by `wait()`, to distinguish wakeups from
    /// ordinary events.
    notified: AtomicBool,
}
// SAFETY: NOTE(review) — presumably the raw `we::HANDLE` is merely what stops
// these impls from being derived; wepoll handles are used from multiple
// threads throughout this module. Confirm against wepoll's documentation.
unsafe impl Send for Poller {}
unsafe impl Sync for Poller {}
impl Poller {
    /// Creates a new poller.
    ///
    /// Fails with an "unsupported" error on platforms where wepoll cannot
    /// initialize (old Windows or Wine).
    pub fn new() -> io::Result<Poller> {
        let handle = unsafe { we::epoll_create1(0) };
        if handle.is_null() {
            return Err(crate::unsupported_error(
                format!(
                    "Failed to initialize Wepoll: {}\nThis usually only happens for old Windows or Wine.",
                    io::Error::last_os_error()
                )
            ));
        }
        let notified = AtomicBool::new(false);
        log::trace!("new: handle={:?}", handle);
        Ok(Poller { handle, notified })
    }
    /// Whether this poller supports level-triggered events.
    pub fn supports_level(&self) -> bool {
        true
    }
    /// Whether this poller supports edge-triggered events.
    pub fn supports_edge(&self) -> bool {
        false
    }
    /// Adds a socket.
    pub fn add(&self, sock: RawSocket, ev: Event, mode: PollMode) -> io::Result<()> {
        log::trace!("add: handle={:?}, sock={}, ev={:?}", self.handle, sock, ev);
        self.ctl(we::EPOLL_CTL_ADD, sock, Some((ev, mode)))
    }
    /// Modifies a socket.
    pub fn modify(&self, sock: RawSocket, ev: Event, mode: PollMode) -> io::Result<()> {
        log::trace!(
            "modify: handle={:?}, sock={}, ev={:?}",
            self.handle,
            sock,
            ev
        );
        self.ctl(we::EPOLL_CTL_MOD, sock, Some((ev, mode)))
    }
    /// Deletes a socket.
    pub fn delete(&self, sock: RawSocket) -> io::Result<()> {
        log::trace!("remove: handle={:?}, sock={}", self.handle, sock);
        self.ctl(we::EPOLL_CTL_DEL, sock, None)
    }
    /// Waits for I/O events with an optional timeout.
    ///
    /// Returns the number of processed I/O events.
    ///
    /// If a notification occurs, this method will return but the notification event will not be
    /// included in the `events` list nor contribute to the returned count.
    pub fn wait(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
        log::trace!("wait: handle={:?}, timeout={:?}", self.handle, timeout);
        let deadline = timeout.and_then(|t| Instant::now().checked_add(t));
        loop {
            // Convert the timeout to milliseconds.
            let timeout_ms = match deadline.map(|d| d.saturating_duration_since(Instant::now())) {
                None => -1,
                Some(t) => {
                    // Round up to a whole millisecond.
                    let mut ms = t.as_millis().try_into().unwrap_or(std::u64::MAX);
                    if Duration::from_millis(ms) < t {
                        ms = ms.saturating_add(1);
                    }
                    ms.try_into().unwrap_or(std::i32::MAX)
                }
            };
            // Wait for I/O events.
            events.len = wepoll!(epoll_wait(
                self.handle,
                events.list.as_mut_ptr(),
                events.list.len() as c_int,
                timeout_ms,
            ))? as usize;
            log::trace!("new events: handle={:?}, len={}", self.handle, events.len);
            // Break if there was a notification or at least one event, or if deadline is reached.
            // Otherwise this was a spurious wakeup and we loop, recomputing the
            // remaining time from the saved deadline.
            if self.notified.swap(false, Ordering::SeqCst) || events.len > 0 || timeout_ms == 0 {
                break;
            }
        }
        Ok(())
    }
    /// Sends a notification to wake up the current or next `wait()` call.
    ///
    /// The compare-exchange ensures only one completion packet is posted per
    /// wakeup, even if `notify()` races with itself.
    pub fn notify(&self) -> io::Result<()> {
        log::trace!("notify: handle={:?}", self.handle);
        if self
            .notified
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_ok()
        {
            unsafe {
                // This call errors if a notification has already been posted, but that's okay - we
                // can just ignore the error.
                //
                // The original wepoll does not support notifications triggered this way, which is
                // why wepoll-sys includes a small patch to support them.
                windows_sys::Win32::System::IO::PostQueuedCompletionStatus(
                    self.handle as _,
                    0,
                    0,
                    ptr::null_mut(),
                );
            }
        }
        Ok(())
    }
    /// Passes arguments to `epoll_ctl`.
    ///
    /// `ev` is `Some((event, mode))` for add/modify, `None` for delete.
    fn ctl(&self, op: u32, sock: RawSocket, ev: Option<(Event, PollMode)>) -> io::Result<()> {
        let mut ev = ev
            .map(|(ev, mode)| {
                let mut flags = match mode {
                    PollMode::Level => 0,
                    PollMode::Oneshot => we::EPOLLONESHOT,
                    PollMode::Edge => {
                        return Err(crate::unsupported_error(
                            "edge-triggered events are not supported with wepoll",
                        ));
                    }
                };
                if ev.readable {
                    flags |= READ_FLAGS;
                }
                if ev.writable {
                    flags |= WRITE_FLAGS;
                }
                Ok(we::epoll_event {
                    events: flags as u32,
                    data: we::epoll_data {
                        u64_: ev.key as u64,
                    },
                })
            })
            .transpose()?;
        wepoll!(epoll_ctl(
            self.handle,
            op as c_int,
            sock as we::SOCKET,
            ev.as_mut()
                .map(|ev| ev as *mut we::epoll_event)
                .unwrap_or(ptr::null_mut()),
        ))?;
        Ok(())
    }
}
impl AsRawHandle for Poller {
    /// Expose the raw wepoll handle without transferring ownership.
    fn as_raw_handle(&self) -> RawHandle {
        self.handle as RawHandle
    }
}
// Only available when the I/O-safety traits exist (Rust 1.63+); the cfg is set
// by the build script otherwise.
#[cfg(not(polling_no_io_safety))]
impl AsHandle for Poller {
    fn as_handle(&self) -> BorrowedHandle<'_> {
        // SAFETY: lifetime is bound by "self"
        unsafe { BorrowedHandle::borrow_raw(self.as_raw_handle()) }
    }
}
impl Drop for Poller {
    fn drop(&mut self) {
        log::trace!("drop: handle={:?}", self.handle);
        // SAFETY: we own the handle; it is closed exactly once here. Errors
        // cannot be reported from drop and are ignored.
        unsafe {
            we::epoll_close(self.handle);
        }
    }
}
/// Wepoll flags for all possible readability events.
/// HUP/ERR are included in both masks so errors wake either direction.
const READ_FLAGS: u32 = we::EPOLLIN | we::EPOLLRDHUP | we::EPOLLHUP | we::EPOLLERR | we::EPOLLPRI;
/// Wepoll flags for all possible writability events.
const WRITE_FLAGS: u32 = we::EPOLLOUT | we::EPOLLHUP | we::EPOLLERR;
/// A list of reported I/O events.
pub struct Events {
    /// Fixed-capacity buffer handed to `epoll_wait`.
    list: Box<[we::epoll_event; 1024]>,
    /// Number of valid entries in `list` after the last `wait()`.
    len: usize,
}
// SAFETY: NOTE(review) — presumably only needed because `we::epoll_event`
// contains a union and so is not auto-Send; the buffer itself is plain owned
// data. Confirm against the wepoll-ffi type definitions.
unsafe impl Send for Events {}
impl Events {
    /// Creates an empty list.
    pub fn new() -> Events {
        let ev = we::epoll_event {
            events: 0,
            data: we::epoll_data { u64_: 0 },
        };
        let list = Box::new([ev; 1024]);
        Events { list, len: 0 }
    }
    /// Iterates over I/O events.
    ///
    /// Each entry's readiness is recovered by intersecting its flags with the
    /// same READ_FLAGS/WRITE_FLAGS masks used when registering interest.
    pub fn iter(&self) -> impl Iterator<Item = Event> + '_ {
        self.list[..self.len].iter().map(|ev| Event {
            // The key was stashed in the epoll_data union by `ctl`.
            key: unsafe { ev.data.u64_ } as usize,
            readable: (ev.events & READ_FLAGS) != 0,
            writable: (ev.events & WRITE_FLAGS) != 0,
        })
    }
}

View File

@ -43,7 +43,7 @@ fn concurrent_modify() -> io::Result<()> {
Parallel::new() Parallel::new()
.add(|| { .add(|| {
poller.wait(&mut events, None)?; poller.wait(&mut events, Some(Duration::from_secs(10)))?;
Ok(()) Ok(())
}) })
.add(|| { .add(|| {

38
tests/io.rs Normal file
View File

@ -0,0 +1,38 @@
use polling::{Event, Poller};
use std::io::{self, Write};
use std::net::{TcpListener, TcpStream};
use std::time::Duration;
// Smoke test: a readable event is only reported after data is written.
#[test]
fn basic_io() {
    let poller = Poller::new().unwrap();
    let (read, mut write) = tcp_pair().unwrap();
    poller.add(&read, Event::readable(1)).unwrap();
    // Nothing should be available at first.
    let mut events = vec![];
    assert_eq!(
        poller
            .wait(&mut events, Some(Duration::from_secs(0)))
            .unwrap(),
        0
    );
    assert!(events.is_empty());
    // After a write, the event should be available now.
    write.write_all(&[1]).unwrap();
    assert_eq!(
        poller
            .wait(&mut events, Some(Duration::from_secs(1)))
            .unwrap(),
        1
    );
    assert_eq!(&*events, &[Event::readable(1)]);
}
// Build a connected TCP socket pair over loopback (no fixed port needed).
fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> {
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let a = TcpStream::connect(listener.local_addr()?)?;
    let (b, _) = listener.accept()?;
    Ok((a, b))
}

View File

@ -18,7 +18,7 @@ fn below_ms() -> io::Result<()> {
let elapsed = now.elapsed(); let elapsed = now.elapsed();
assert_eq!(n, 0); assert_eq!(n, 0);
assert!(elapsed >= dur); assert!(elapsed >= dur, "{:?} < {:?}", elapsed, dur);
lowest = lowest.min(elapsed); lowest = lowest.min(elapsed);
} }
@ -54,7 +54,7 @@ fn above_ms() -> io::Result<()> {
let elapsed = now.elapsed(); let elapsed = now.elapsed();
assert_eq!(n, 0); assert_eq!(n, 0);
assert!(elapsed >= dur); assert!(elapsed >= dur, "{:?} < {:?}", elapsed, dur);
lowest = lowest.min(elapsed); lowest = lowest.min(elapsed);
} }