m(windows): Reimplement Wepoll in Rust (#88)

Reimplements the C-based wepoll backend in handwritten Rust. This PR also implements bindings to the I/O Completion Ports and \Device\Afd APIs. For more information on the latter, see my blog post on the subject: https://notgull.github.io/device-afd/

Note that the IOCP API is wrapped using a `Pin`-oriented "CompletionHandle" system that is relatively brittle. This should be replaced with a better model when one becomes available.
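The user-facing API is unchanged by this rewrite. A minimal sketch of Windows usage, mirroring the new tests/io.rs added in this PR (the key value 1 is arbitrary):

use polling::{Event, Poller};
use std::net::TcpListener;
use std::time::Duration;

fn main() -> std::io::Result<()> {
    // The poller is now backed by IOCP and \Device\Afd rather than the bundled wepoll C code.
    let poller = Poller::new()?;
    let listener = TcpListener::bind("127.0.0.1:0")?;
    poller.add(&listener, Event::readable(1))?;

    // Wait for the listener to become ready to accept, or time out.
    let mut events = Vec::new();
    let n = poller.wait(&mut events, Some(Duration::from_secs(1)))?;
    println!("{} event(s)", n);
    Ok(())
}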
John Nunley 2023-03-05 16:25:25 -08:00 committed by GitHub
parent e85331c437
commit 24900fb662
10 changed files with 1826 additions and 267 deletions

View File

@@ -7,10 +7,10 @@ version = "2.5.2"
authors = ["Stjepan Glavina <stjepang@gmail.com>"]
edition = "2018"
rust-version = "1.47"
description = "Portable interface to epoll, kqueue, event ports, and wepoll"
description = "Portable interface to epoll, kqueue, event ports, and IOCP"
license = "Apache-2.0 OR MIT"
repository = "https://github.com/smol-rs/polling"
keywords = ["mio", "epoll", "kqueue", "iocp", "wepoll"]
keywords = ["mio", "epoll", "kqueue", "iocp"]
categories = ["asynchronous", "network-programming", "os"]
exclude = ["/.*"]
@@ -32,13 +32,19 @@ autocfg = "1"
libc = "0.2.77"
[target.'cfg(windows)'.dependencies]
wepoll-ffi = { version = "0.1.2", features = ["null-overlapped-wakeups-patch"] }
bitflags = "1.3.2"
concurrent-queue = "2.1.0"
pin-project-lite = "0.2.9"
[target.'cfg(windows)'.dependencies.windows-sys]
version = "0.45"
features = [
"Win32_Networking_WinSock",
"Win32_System_IO",
"Win32_Foundation"
"Win32_System_LibraryLoader",
"Win32_System_WindowsProgramming",
"Win32_Storage_FileSystem",
"Win32_Foundation",
]
[dev-dependencies]

View File

@@ -9,7 +9,7 @@ https://crates.io/crates/polling)
[![Documentation](https://docs.rs/polling/badge.svg)](
https://docs.rs/polling)
Portable interface to epoll, kqueue, event ports, and wepoll.
Portable interface to epoll, kqueue, event ports, and IOCP.
Supported platforms:
- [epoll](https://en.wikipedia.org/wiki/Epoll): Linux, Android
@@ -17,7 +17,7 @@ Supported platforms:
DragonFly BSD
- [event ports](https://illumos.org/man/port_create): illumos, Solaris
- [poll](https://en.wikipedia.org/wiki/Poll_(Unix)): VxWorks, Fuchsia, other Unix systems
- [wepoll](https://github.com/piscisaureus/wepoll): Windows, Wine (version 7.13+)
- [IOCP](https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports): Windows, Wine (version 7.13+)
Polling is done in oneshot mode, which means interest in I/O events needs to be reset after
an event is delivered if we're interested in the next event of the same kind.
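For example, a caller that wants every readability event must re-arm its interest between waits; a minimal sketch (the key 7 is arbitrary):

use polling::{Event, Poller};
use std::net::TcpListener;
use std::time::Duration;

fn main() -> std::io::Result<()> {
    let socket = TcpListener::bind("127.0.0.1:0")?;
    let poller = Poller::new()?;
    poller.add(&socket, Event::readable(7))?;

    let mut events = Vec::new();
    poller.wait(&mut events, Some(Duration::from_secs(1)))?;
    events.clear();

    // The oneshot registration is now disarmed; renew interest in readability
    // before the next wait, or no further events will be delivered.
    poller.modify(&socket, Event::readable(7))?;
    poller.wait(&mut events, Some(Duration::from_secs(1)))?;
    Ok(())
}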

608
src/iocp/afd.rs Normal file
View File

@@ -0,0 +1,608 @@
//! Safe wrapper around \Device\Afd
use super::port::{Completion, CompletionHandle};
use std::cell::UnsafeCell;
use std::fmt;
use std::io;
use std::marker::{PhantomData, PhantomPinned};
use std::mem::{size_of, transmute, MaybeUninit};
use std::os::windows::prelude::{AsRawHandle, RawHandle, RawSocket};
use std::pin::Pin;
use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Once;
use windows_sys::Win32::Foundation::{
CloseHandle, HANDLE, HINSTANCE, NTSTATUS, STATUS_NOT_FOUND, STATUS_PENDING, STATUS_SUCCESS,
UNICODE_STRING,
};
use windows_sys::Win32::Networking::WinSock::{
WSAIoctl, SIO_BASE_HANDLE, SIO_BSP_HANDLE_POLL, SOCKET_ERROR,
};
use windows_sys::Win32::Storage::FileSystem::{
FILE_OPEN, FILE_SHARE_READ, FILE_SHARE_WRITE, SYNCHRONIZE,
};
use windows_sys::Win32::System::LibraryLoader::{GetModuleHandleW, GetProcAddress};
use windows_sys::Win32::System::WindowsProgramming::{IO_STATUS_BLOCK, OBJECT_ATTRIBUTES};
#[derive(Default)]
#[repr(C)]
pub(super) struct AfdPollInfo {
/// The timeout for this poll.
timeout: i64,
/// The number of handles being polled.
handle_count: u32,
/// Whether or not this poll is exclusive for this handle.
exclusive: u32,
/// The handles to poll.
handles: [AfdPollHandleInfo; 1],
}
#[derive(Default)]
#[repr(C)]
struct AfdPollHandleInfo {
/// The handle to poll.
handle: HANDLE,
/// The events to poll for.
events: AfdPollMask,
/// The status of the poll.
status: NTSTATUS,
}
impl AfdPollInfo {
pub(super) fn handle_count(&self) -> u32 {
self.handle_count
}
pub(super) fn events(&self) -> AfdPollMask {
self.handles[0].events
}
}
bitflags::bitflags! {
#[derive(Default)]
#[repr(transparent)]
pub(super) struct AfdPollMask: u32 {
const RECEIVE = 0x001;
const RECEIVE_EXPEDITED = 0x002;
const SEND = 0x004;
const DISCONNECT = 0x008;
const ABORT = 0x010;
const LOCAL_CLOSE = 0x020;
const ACCEPT = 0x080;
const CONNECT_FAIL = 0x100;
}
}
pub(super) trait HasAfdInfo {
fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell<AfdPollInfo>>;
}
macro_rules! define_ntdll_import {
(
$(
$(#[$attr:meta])*
fn $name:ident($($arg:ident: $arg_ty:ty),*) -> $ret:ty;
)*
) => {
/// Imported functions from ntdll.dll.
#[allow(non_snake_case)]
pub(super) struct NtdllImports {
$(
$(#[$attr])*
$name: unsafe extern "system" fn($($arg_ty),*) -> $ret,
)*
}
#[allow(non_snake_case)]
impl NtdllImports {
unsafe fn load(ntdll: HINSTANCE) -> io::Result<Self> {
$(
let $name = {
const NAME: &str = concat!(stringify!($name), "\0");
let addr = GetProcAddress(ntdll, NAME.as_ptr() as *const _);
let addr = match addr {
Some(addr) => addr,
None => {
log::error!("Failed to load ntdll function {}", NAME);
return Err(io::Error::last_os_error());
},
};
transmute::<_, unsafe extern "system" fn($($arg_ty),*) -> $ret>(addr)
};
)*
Ok(Self {
$(
$name,
)*
})
}
$(
$(#[$attr])*
unsafe fn $name(&self, $($arg: $arg_ty),*) -> $ret {
(self.$name)($($arg),*)
}
)*
}
};
}
define_ntdll_import! {
/// Cancels an ongoing I/O operation.
fn NtCancelIoFileEx(
FileHandle: HANDLE,
IoRequestToCancel: *mut IO_STATUS_BLOCK,
IoStatusBlock: *mut IO_STATUS_BLOCK
) -> NTSTATUS;
/// Opens or creates a file handle.
#[allow(clippy::too_many_arguments)]
fn NtCreateFile(
FileHandle: *mut HANDLE,
DesiredAccess: u32,
ObjectAttributes: *mut OBJECT_ATTRIBUTES,
IoStatusBlock: *mut IO_STATUS_BLOCK,
AllocationSize: *mut i64,
FileAttributes: u32,
ShareAccess: u32,
CreateDisposition: u32,
CreateOptions: u32,
EaBuffer: *mut (),
EaLength: u32
) -> NTSTATUS;
/// Runs an I/O control on a file handle.
///
/// Practically equivalent to `ioctl`.
#[allow(clippy::too_many_arguments)]
fn NtDeviceIoControlFile(
FileHandle: HANDLE,
Event: HANDLE,
ApcRoutine: *mut (),
ApcContext: *mut (),
IoStatusBlock: *mut IO_STATUS_BLOCK,
IoControlCode: u32,
InputBuffer: *mut (),
InputBufferLength: u32,
OutputBuffer: *mut (),
OutputBufferLength: u32
) -> NTSTATUS;
/// Converts `NTSTATUS` to a DOS error code.
fn RtlNtStatusToDosError(
Status: NTSTATUS
) -> u32;
}
impl NtdllImports {
fn get() -> io::Result<&'static Self> {
macro_rules! s {
($e:expr) => {{
$e as u16
}};
}
// ntdll.dll
static NTDLL_NAME: &[u16] = &[
s!('n'),
s!('t'),
s!('d'),
s!('l'),
s!('l'),
s!('.'),
s!('d'),
s!('l'),
s!('l'),
s!('\0'),
];
static NTDLL_IMPORTS: OnceCell<io::Result<NtdllImports>> = OnceCell::new();
NTDLL_IMPORTS
.get_or_init(|| unsafe {
let ntdll = GetModuleHandleW(NTDLL_NAME.as_ptr() as *const _);
if ntdll == 0 {
log::error!("Failed to load ntdll.dll");
return Err(io::Error::last_os_error());
}
NtdllImports::load(ntdll)
})
.as_ref()
.map_err(|e| io::Error::from(e.kind()))
}
pub(super) fn force_load() -> io::Result<()> {
Self::get()?;
Ok(())
}
}
/// The handle to the AFD device.
pub(super) struct Afd<T> {
/// The handle to the AFD device.
handle: HANDLE,
/// We own `T`.
_marker: PhantomData<T>,
}
impl<T> fmt::Debug for Afd<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
struct WriteAsHex(HANDLE);
impl fmt::Debug for WriteAsHex {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:010x}", self.0)
}
}
f.debug_struct("Afd")
.field("handle", &WriteAsHex(self.handle))
.finish()
}
}
impl<T> Drop for Afd<T> {
fn drop(&mut self) {
unsafe {
CloseHandle(self.handle);
}
}
}
impl<T> AsRawHandle for Afd<T> {
fn as_raw_handle(&self) -> RawHandle {
self.handle as _
}
}
impl<T: CompletionHandle> Afd<T>
where
T::Completion: AsIoStatusBlock + HasAfdInfo,
{
/// Create a new AFD device.
pub(super) fn new() -> io::Result<Self> {
macro_rules! s {
($e:expr) => {
($e) as u16
};
}
/// \Device\Afd\Smol
const AFD_NAME: &[u16] = &[
s!('\\'),
s!('D'),
s!('e'),
s!('v'),
s!('i'),
s!('c'),
s!('e'),
s!('\\'),
s!('A'),
s!('f'),
s!('d'),
s!('\\'),
s!('S'),
s!('m'),
s!('o'),
s!('l'),
s!('\0'),
];
// Set up device attributes.
let mut device_name = UNICODE_STRING {
Length: (AFD_NAME.len() * size_of::<u16>()) as u16,
MaximumLength: (AFD_NAME.len() * size_of::<u16>()) as u16,
Buffer: AFD_NAME.as_ptr() as *mut _,
};
let mut device_attributes = OBJECT_ATTRIBUTES {
Length: size_of::<OBJECT_ATTRIBUTES>() as u32,
RootDirectory: 0,
ObjectName: &mut device_name,
Attributes: 0,
SecurityDescriptor: ptr::null_mut(),
SecurityQualityOfService: ptr::null_mut(),
};
let mut handle = MaybeUninit::<HANDLE>::uninit();
let mut iosb = MaybeUninit::<IO_STATUS_BLOCK>::zeroed();
let ntdll = NtdllImports::get()?;
let result = unsafe {
ntdll.NtCreateFile(
handle.as_mut_ptr(),
SYNCHRONIZE,
&mut device_attributes,
iosb.as_mut_ptr(),
ptr::null_mut(),
0,
FILE_SHARE_READ | FILE_SHARE_WRITE,
FILE_OPEN,
0,
ptr::null_mut(),
0,
)
};
if result != STATUS_SUCCESS {
let real_code = unsafe { ntdll.RtlNtStatusToDosError(result) };
return Err(io::Error::from_raw_os_error(real_code as i32));
}
let handle = unsafe { handle.assume_init() };
Ok(Self {
handle,
_marker: PhantomData,
})
}
/// Begin polling with the provided handle.
pub(super) fn poll(
&self,
packet: T,
base_socket: RawSocket,
afd_events: AfdPollMask,
) -> io::Result<()> {
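// The IOCTL code for an AFD poll operation. It is not defined in any public
// Windows header; like the rest of the AFD interface, it is unstable.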
const IOCTL_AFD_POLL: u32 = 0x00012024;
// Lock the packet.
if !packet.get().try_lock() {
return Err(io::Error::new(
io::ErrorKind::WouldBlock,
"packet is already in use",
));
}
// Set up the AFD poll info.
let poll_info = unsafe {
let poll_info = Pin::into_inner_unchecked(packet.get().afd_info()).get();
// Initialize the AFD poll info.
(*poll_info).exclusive = false.into();
(*poll_info).handle_count = 1;
(*poll_info).timeout = std::i64::MAX;
(*poll_info).handles[0].handle = base_socket as HANDLE;
(*poll_info).handles[0].status = 0;
(*poll_info).handles[0].events = afd_events;
poll_info
};
let iosb = T::into_ptr(packet).cast::<IO_STATUS_BLOCK>();
// Set Status to pending
unsafe {
(*iosb).Anonymous.Status = STATUS_PENDING;
}
let ntdll = NtdllImports::get()?;
let result = unsafe {
ntdll.NtDeviceIoControlFile(
self.handle,
0,
ptr::null_mut(),
iosb.cast(),
iosb.cast(),
IOCTL_AFD_POLL,
poll_info.cast(),
size_of::<AfdPollInfo>() as u32,
poll_info.cast(),
size_of::<AfdPollInfo>() as u32,
)
};
match result {
STATUS_SUCCESS => Ok(()),
STATUS_PENDING => Err(io::ErrorKind::WouldBlock.into()),
status => {
let real_code = unsafe { ntdll.RtlNtStatusToDosError(status) };
Err(io::Error::from_raw_os_error(real_code as i32))
}
}
}
/// Cancel an ongoing poll operation.
///
/// # Safety
///
/// The poll operation must currently be in progress for this AFD.
pub(super) unsafe fn cancel(&self, packet: &T) -> io::Result<()> {
let ntdll = NtdllImports::get()?;
let result = {
// First, check if the packet is still in use.
let iosb = packet.as_ptr().cast::<IO_STATUS_BLOCK>();
if (*iosb).Anonymous.Status != STATUS_PENDING {
return Ok(());
}
// Cancel the packet.
let mut cancel_iosb = MaybeUninit::<IO_STATUS_BLOCK>::zeroed();
ntdll.NtCancelIoFileEx(self.handle, iosb, cancel_iosb.as_mut_ptr())
};
if result == STATUS_SUCCESS || result == STATUS_NOT_FOUND {
Ok(())
} else {
let real_code = ntdll.RtlNtStatusToDosError(result);
Err(io::Error::from_raw_os_error(real_code as i32))
}
}
}
/// A one-time initialization cell.
struct OnceCell<T> {
/// The value.
value: UnsafeCell<MaybeUninit<T>>,
/// The one-time initialization.
once: Once,
}
unsafe impl<T: Send + Sync> Send for OnceCell<T> {}
unsafe impl<T: Send + Sync> Sync for OnceCell<T> {}
impl<T> OnceCell<T> {
/// Creates a new `OnceCell`.
pub const fn new() -> Self {
OnceCell {
value: UnsafeCell::new(MaybeUninit::uninit()),
once: Once::new(),
}
}
/// Gets the value or initializes it.
pub fn get_or_init<F>(&self, f: F) -> &T
where
F: FnOnce() -> T,
{
self.once.call_once(|| unsafe {
let value = f();
*self.value.get() = MaybeUninit::new(value);
});
unsafe { &*self.value.get().cast() }
}
}
pin_project_lite::pin_project! {
/// An I/O status block paired with some auxiliary data.
#[repr(C)]
pub(super) struct IoStatusBlock<T> {
// The I/O status block.
iosb: UnsafeCell<IO_STATUS_BLOCK>,
// Whether or not the block is in use.
in_use: AtomicBool,
// The auxiliary data.
#[pin]
data: T,
// This block is not allowed to move.
#[pin]
_marker: PhantomPinned,
}
}
impl<T: fmt::Debug> fmt::Debug for IoStatusBlock<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("IoStatusBlock")
.field("iosb", &"..")
.field("in_use", &self.in_use)
.field("data", &self.data)
.finish()
}
}
impl<T> From<T> for IoStatusBlock<T> {
fn from(data: T) -> Self {
Self {
iosb: UnsafeCell::new(unsafe { std::mem::zeroed() }),
in_use: AtomicBool::new(false),
data,
_marker: PhantomPinned,
}
}
}
impl<T> IoStatusBlock<T> {
pub(super) fn iosb(self: Pin<&Self>) -> &UnsafeCell<IO_STATUS_BLOCK> {
self.project_ref().iosb
}
pub(super) fn data(self: Pin<&Self>) -> Pin<&T> {
self.project_ref().data
}
}
impl<T: HasAfdInfo> HasAfdInfo for IoStatusBlock<T> {
fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell<AfdPollInfo>> {
self.project_ref().data.afd_info()
}
}
/// Can be transmuted to an I/O status block.
///
/// # Safety
///
/// A pointer to `T` must be able to be converted to a pointer to `IO_STATUS_BLOCK`
/// without any issues.
pub(super) unsafe trait AsIoStatusBlock {}
unsafe impl<T> AsIoStatusBlock for IoStatusBlock<T> {}
unsafe impl<T> Completion for IoStatusBlock<T> {
fn try_lock(self: Pin<&Self>) -> bool {
!self.in_use.swap(true, Ordering::SeqCst)
}
unsafe fn unlock(self: Pin<&Self>) {
self.in_use.store(false, Ordering::SeqCst);
}
}
/// Get the base socket associated with a socket.
pub(super) fn base_socket(sock: RawSocket) -> io::Result<RawSocket> {
// First, try the SIO_BASE_HANDLE ioctl.
let result = unsafe { try_socket_ioctl(sock, SIO_BASE_HANDLE) };
match result {
Ok(sock) => return Ok(sock),
Err(e) if e.kind() == io::ErrorKind::InvalidInput => return Err(e),
Err(_) => {}
}
// Some poorly coded LSPs may not handle SIO_BASE_HANDLE properly, but in some cases may
// handle SIO_BSP_HANDLE_POLL better. Try that.
let result = unsafe { try_socket_ioctl(sock, SIO_BSP_HANDLE_POLL)? };
if result == sock {
return Err(io::Error::from(io::ErrorKind::InvalidInput));
}
// Try `SIO_BASE_HANDLE` again, in case the LSP fixed itself.
unsafe { try_socket_ioctl(result, SIO_BASE_HANDLE) }
}
/// Run an IOCTL on a socket and return a socket.
///
/// # Safety
///
/// The `ioctl` parameter must be a valid I/O control that returns a valid socket.
unsafe fn try_socket_ioctl(sock: RawSocket, ioctl: u32) -> io::Result<RawSocket> {
let mut out = MaybeUninit::<RawSocket>::uninit();
let mut bytes = 0u32;
let result = WSAIoctl(
sock as _,
ioctl,
ptr::null_mut(),
0,
out.as_mut_ptr().cast(),
size_of::<RawSocket>() as u32,
&mut bytes,
ptr::null_mut(),
None,
);
if result == SOCKET_ERROR {
return Err(io::Error::last_os_error());
}
Ok(out.assume_init())
}

834
src/iocp/mod.rs Normal file
View File

@@ -0,0 +1,834 @@
//! Bindings to Windows I/O Completion Ports.
//!
//! I/O Completion Ports is a completion-based API, unlike polling-based APIs such as
//! epoll or kqueue. Therefore, we have to adapt the IOCP API to this crate's
//! polling-based API.
//!
//! WinSock is powered by the Ancillary Function Driver (AFD) subsystem, which can be
//! accessed directly by using unstable `ntdll` functions. AFD exposes features that are not
//! available through the normal WinSock interface, such as IOCTL_AFD_POLL. This function is
//! similar to the exposed `WSAPoll` method. However, once the targeted socket is "ready",
//! a completion packet is queued to an I/O completion port.
//!
//! We take advantage of IOCTL_AFD_POLL to "translate" this crate's polling-based API
//! to the one Windows expects. When a device is added to the `Poller`, an IOCTL_AFD_POLL
//! operation is started and queued to the IOCP. To modify a currently registered device
//! (e.g. with `modify()` or `delete()`), the ongoing poll is cancelled and then restarted
//! with new parameters. When the poll eventually completes, the packet is posted to the IOCP.
//! From here it's a simple matter of using `GetQueuedCompletionStatusEx` to read the packets
//! from the IOCP and react accordingly. Notifying the poller is trivial, because we can
//! simply post a packet to the IOCP to wake it up.
//!
//! The main disadvantage of this strategy is that it relies on unstable Windows APIs.
//! However, as `libuv` (the I/O library backing Node.js) relies on the same unstable
//! AFD strategy, it is unlikely to be broken without plenty of advance warning.
//!
//! Previously, this crate used the `wepoll` library for polling. `wepoll` uses a similar
//! AFD-based strategy for polling.
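//!
//! Internally, each registered socket owns a pinned `Packet` that moves through a small
//! state machine (see `SocketStatus` below): `Idle` while no poll is in flight, `Polling`
//! once an IOCTL_AFD_POLL has been submitted, and `Cancelled` after `modify()`/`delete()`
//! cancels an in-flight poll, until the cancelled completion packet drains out of the port.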
mod afd;
mod port;
use afd::{base_socket, Afd, AfdPollInfo, AfdPollMask, HasAfdInfo, IoStatusBlock};
use port::{IoCompletionPort, OverlappedEntry};
use windows_sys::Win32::Foundation::{ERROR_INVALID_HANDLE, ERROR_IO_PENDING, STATUS_CANCELLED};
use crate::{Event, PollMode};
use concurrent_queue::ConcurrentQueue;
use pin_project_lite::pin_project;
use std::cell::UnsafeCell;
use std::collections::hash_map::{Entry, HashMap};
use std::fmt;
use std::io;
use std::marker::PhantomPinned;
use std::os::windows::io::{AsRawHandle, RawHandle, RawSocket};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak};
use std::time::{Duration, Instant};
#[cfg(not(polling_no_io_safety))]
use std::os::windows::io::{AsHandle, BorrowedHandle};
/// Macro to lock and ignore lock poisoning.
macro_rules! lock {
($lock_result:expr) => {{
$lock_result.unwrap_or_else(|e| e.into_inner())
}};
}
/// Interface to I/O completion ports.
#[derive(Debug)]
pub(super) struct Poller {
/// The I/O completion port.
port: IoCompletionPort<Packet>,
/// List of currently active AFD instances.
///
/// Weak references are kept here so that the AFD handle is automatically dropped
/// when the last associated socket is dropped.
afd: Mutex<Vec<Weak<Afd<Packet>>>>,
/// The state of the sources registered with this poller.
sources: RwLock<HashMap<RawSocket, Packet>>,
/// Sockets with pending updates.
pending_updates: ConcurrentQueue<Packet>,
/// Are we currently polling?
polling: AtomicBool,
/// A list of completion packets.
packets: Mutex<Vec<OverlappedEntry<Packet>>>,
/// The packet used to notify the poller.
notifier: Packet,
}
unsafe impl Send for Poller {}
unsafe impl Sync for Poller {}
impl Poller {
/// Creates a new poller.
pub(super) fn new() -> io::Result<Self> {
// Make sure AFD is able to be used.
if let Err(e) = afd::NtdllImports::force_load() {
return Err(crate::unsupported_error(format!(
"Failed to initialize unstable Windows functions: {}\nThis usually only happens for old Windows or Wine.",
e
)));
}
// Create and destroy a single AFD to test if we support it.
Afd::<Packet>::new().map_err(|e| crate::unsupported_error(format!(
"Failed to initialize \\Device\\Afd: {}\nThis usually only happens for old Windows or Wine.",
e,
)))?;
let port = IoCompletionPort::new(0)?;
log::trace!("new: handle={:?}", &port);
Ok(Poller {
port,
afd: Mutex::new(vec![]),
sources: RwLock::new(HashMap::new()),
pending_updates: ConcurrentQueue::bounded(1024),
polling: AtomicBool::new(false),
packets: Mutex::new(Vec::with_capacity(1024)),
notifier: Arc::pin(
PacketInner::Wakeup {
_pinned: PhantomPinned,
}
.into(),
),
})
}
/// Whether this poller supports level-triggered events.
pub(super) fn supports_level(&self) -> bool {
true
}
/// Whether this poller supports edge-triggered events.
pub(super) fn supports_edge(&self) -> bool {
false
}
/// Add a new source to the poller.
pub(super) fn add(&self, socket: RawSocket, interest: Event, mode: PollMode) -> io::Result<()> {
log::trace!(
"add: handle={:?}, sock={}, ev={:?}",
self.port,
socket,
interest
);
// We don't support edge-triggered events.
if matches!(mode, PollMode::Edge) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"edge-triggered events are not supported",
));
}
// Create a new packet.
let socket_state = {
let state = SocketState {
socket,
base_socket: base_socket(socket)?,
interest,
interest_error: true,
afd: self.afd_handle()?,
mode,
waiting_on_delete: false,
status: SocketStatus::Idle,
};
Arc::pin(IoStatusBlock::from(PacketInner::Socket {
packet: UnsafeCell::new(AfdPollInfo::default()),
socket: Mutex::new(state),
}))
};
// Keep track of the source in the poller.
{
let mut sources = lock!(self.sources.write());
match sources.entry(socket) {
Entry::Vacant(v) => {
v.insert(Pin::<Arc<_>>::clone(&socket_state));
}
Entry::Occupied(_) => {
return Err(io::Error::from(io::ErrorKind::AlreadyExists));
}
}
}
// Update the packet.
self.update_packet(socket_state)
}
/// Update a source in the poller.
pub(super) fn modify(
&self,
socket: RawSocket,
interest: Event,
mode: PollMode,
) -> io::Result<()> {
log::trace!(
"modify: handle={:?}, sock={}, ev={:?}",
self.port,
socket,
interest
);
// We don't support edge-triggered events.
if matches!(mode, PollMode::Edge) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"edge-triggered events are not supported",
));
}
// Get a reference to the source.
let source = {
let sources = lock!(self.sources.read());
sources
.get(&socket)
.cloned()
.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?
};
// Set the new event.
if source.as_ref().set_events(interest, mode) {
self.update_packet(source)?;
}
Ok(())
}
/// Delete a source from the poller.
pub(super) fn delete(&self, socket: RawSocket) -> io::Result<()> {
log::trace!("remove: handle={:?}, sock={}", self.port, socket);
// Get a reference to the source.
let source = {
let mut sources = lock!(self.sources.write());
match sources.remove(&socket) {
Some(s) => s,
None => {
// If the source has already been removed, then we can just return.
return Ok(());
}
}
};
// Indicate to the source that it is being deleted.
// This cancels any ongoing AFD_IOCTL_POLL operations.
source.begin_delete()
}
/// Wait for events.
pub(super) fn wait(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
log::trace!("wait: handle={:?}, timeout={:?}", self.port, timeout);
let deadline = timeout.and_then(|timeout| Instant::now().checked_add(timeout));
let mut packets = lock!(self.packets.lock());
let mut notified = false;
events.packets.clear();
loop {
let mut new_events = 0;
// Indicate that we are now polling.
let was_polling = self.polling.swap(true, Ordering::SeqCst);
debug_assert!(!was_polling);
let guard = CallOnDrop(|| {
let was_polling = self.polling.swap(false, Ordering::SeqCst);
debug_assert!(was_polling);
});
// Process every entry in the queue before we start polling.
self.drain_update_queue(false)?;
// Get the time to wait for.
let timeout = deadline.map(|t| t.saturating_duration_since(Instant::now()));
// Wait for I/O events.
let len = self.port.wait(&mut packets, timeout)?;
log::trace!("new events: handle={:?}, len={}", self.port, len);
// We are no longer polling.
drop(guard);
// Process all of the events.
for entry in packets.drain(..) {
let packet = entry.into_packet();
// Feed the event into the packet.
match packet.feed_event(self)? {
FeedEventResult::NoEvent => {}
FeedEventResult::Event(event) => {
events.packets.push(event);
new_events += 1;
}
FeedEventResult::Notified => {
notified = true;
}
}
}
// Break if there was a notification or at least one event, or if deadline is reached.
let timeout_is_empty =
timeout.map_or(false, |t| t.as_secs() == 0 && t.subsec_nanos() == 0);
if notified || new_events > 0 || timeout_is_empty {
break;
}
log::trace!("wait: no events found, re-entering polling loop");
}
Ok(())
}
/// Notify this poller.
pub(super) fn notify(&self) -> io::Result<()> {
// Push the notify packet into the IOCP.
self.port.post(0, 0, self.notifier.clone())
}
/// Run an update on a packet.
fn update_packet(&self, mut packet: Packet) -> io::Result<()> {
loop {
// If we are currently polling, we need to update the packet immediately.
if self.polling.load(Ordering::Acquire) {
packet.update()?;
return Ok(());
}
// Try to queue the update.
match self.pending_updates.push(packet) {
Ok(()) => return Ok(()),
Err(p) => packet = p.into_inner(),
}
// If we failed to queue the update, we need to drain the queue first.
self.drain_update_queue(true)?;
}
}
/// Drain the update queue.
fn drain_update_queue(&self, limit: bool) -> io::Result<()> {
let max = if limit {
self.pending_updates.capacity().unwrap()
} else {
std::usize::MAX
};
// Only drain the queue's capacity, since this could in theory run forever.
core::iter::from_fn(|| self.pending_updates.pop().ok())
.take(max)
.try_for_each(|packet| packet.update())
}
/// Get a handle to the AFD reference.
fn afd_handle(&self) -> io::Result<Arc<Afd<Packet>>> {
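// The maximum number of sockets that may be registered with a single AFD
// handle before a new handle is created.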
const AFD_MAX_SIZE: usize = 32;
// Crawl the list and see if there are any existing AFD instances that we can use.
// Remove any unused AFD pointers.
let mut afd_handles = lock!(self.afd.lock());
let mut i = 0;
while i < afd_handles.len() {
// Get the reference count of the AFD instance.
let refcount = Weak::strong_count(&afd_handles[i]);
match refcount {
0 => {
// Prune the AFD pointer if it has no references.
afd_handles.swap_remove(i);
}
refcount if refcount >= AFD_MAX_SIZE => {
// Skip this one, since it is already at the maximum size.
i += 1;
}
_ => {
// We can use this AFD instance.
match afd_handles[i].upgrade() {
Some(afd) => return Ok(afd),
None => {
// The last socket dropped the AFD before we could acquire it.
// Prune the AFD pointer and continue.
afd_handles.swap_remove(i);
}
}
}
}
}
// No available handles, create a new AFD instance.
let afd = Arc::new(Afd::new()?);
// Register the AFD instance with the I/O completion port.
self.port.register(&*afd, true)?;
// Insert a weak pointer to the AFD instance into the list.
afd_handles.push(Arc::downgrade(&afd));
Ok(afd)
}
}
impl AsRawHandle for Poller {
fn as_raw_handle(&self) -> RawHandle {
self.port.as_raw_handle()
}
}
#[cfg(not(polling_no_io_safety))]
impl AsHandle for Poller {
fn as_handle(&self) -> BorrowedHandle<'_> {
unsafe { BorrowedHandle::borrow_raw(self.as_raw_handle()) }
}
}
/// The container for events.
pub(super) struct Events {
/// List of IOCP packets.
packets: Vec<Event>,
}
unsafe impl Send for Events {}
impl Events {
/// Creates an empty list of events.
pub(super) fn new() -> Events {
Events {
packets: Vec::with_capacity(1024),
}
}
/// Iterate over I/O events.
pub(super) fn iter(&self) -> impl Iterator<Item = Event> + '_ {
self.packets.iter().copied()
}
}
/// The type of our completion packet.
type Packet = Pin<Arc<PacketUnwrapped>>;
type PacketUnwrapped = IoStatusBlock<PacketInner>;
pin_project! {
/// The inner type of the packet.
#[project_ref = PacketInnerProj]
#[project = PacketInnerProjMut]
enum PacketInner {
// A packet for a socket.
Socket {
// The AFD packet state.
#[pin]
packet: UnsafeCell<AfdPollInfo>,
// The socket state.
socket: Mutex<SocketState>
},
// A packet used to wake up the poller.
Wakeup { #[pin] _pinned: PhantomPinned },
}
}
impl fmt::Debug for PacketInner {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Wakeup { .. } => f.write_str("Wakeup { .. }"),
Self::Socket { socket, .. } => f
.debug_struct("Socket")
.field("packet", &"..")
.field("socket", socket)
.finish(),
}
}
}
impl HasAfdInfo for PacketInner {
fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell<AfdPollInfo>> {
match self.project_ref() {
PacketInnerProj::Socket { packet, .. } => packet,
PacketInnerProj::Wakeup { .. } => unreachable!(),
}
}
}
impl PacketUnwrapped {
/// Set the new events that this socket is waiting on.
///
/// Returns `true` if we need to be updated.
fn set_events(self: Pin<&Self>, interest: Event, mode: PollMode) -> bool {
let mut socket = match self.socket_state() {
Some(s) => s,
None => return false,
};
socket.interest = interest;
socket.mode = mode;
socket.interest_error = true;
match socket.status {
SocketStatus::Polling { readable, writable } => {
(interest.readable && !readable) || (interest.writable && !writable)
}
_ => true,
}
}
/// Update the socket and install the new status in AFD.
fn update(self: Pin<Arc<Self>>) -> io::Result<()> {
let mut socket = match self.as_ref().socket_state() {
Some(s) => s,
None => return Err(io::Error::new(io::ErrorKind::Other, "invalid socket state")),
};
// If we are waiting on a delete, just return, dropping the packet.
if socket.waiting_on_delete {
return Ok(());
}
// Check the current status.
match socket.status {
SocketStatus::Polling { readable, writable } => {
// If we need to poll for events aside from what we are currently polling, we need
// to update the packet. Cancel the ongoing poll.
if (socket.interest.readable && !readable)
|| (socket.interest.writable && !writable)
{
return self.cancel(socket);
}
// All events that we are currently waiting on are accounted for.
Ok(())
}
SocketStatus::Cancelled => {
// The ongoing operation was cancelled, and we're still waiting for it to return.
// For now, wait until the top-level loop calls feed_event().
Ok(())
}
SocketStatus::Idle => {
// Start a new poll.
let result = socket.afd.poll(
self.clone(),
socket.base_socket,
event_to_afd_mask(
socket.interest.readable,
socket.interest.writable,
socket.interest_error,
),
);
match result {
Ok(()) => {}
Err(err)
if err.raw_os_error() == Some(ERROR_IO_PENDING as i32)
|| err.kind() == io::ErrorKind::WouldBlock =>
{
// The operation is pending.
}
Err(err) if err.raw_os_error() == Some(ERROR_INVALID_HANDLE as i32) => {
// The socket was closed. We need to delete it.
// This should happen after we drop it here.
}
Err(err) => return Err(err),
}
// We are now polling for the current events.
socket.status = SocketStatus::Polling {
readable: socket.interest.readable,
writable: socket.interest.writable,
};
Ok(())
}
}
}
/// This socket state was notified; see if we need to update it.
fn feed_event(self: Pin<Arc<Self>>, poller: &Poller) -> io::Result<FeedEventResult> {
let inner = self.as_ref().data().project_ref();
let (afd_info, socket) = match inner {
PacketInnerProj::Socket { packet, socket } => (packet, socket),
PacketInnerProj::Wakeup { .. } => {
// The poller was notified.
return Ok(FeedEventResult::Notified);
}
};
let mut socket_state = lock!(socket.lock());
let mut event = Event::none(socket_state.interest.key);
// Put ourselves into the idle state.
socket_state.status = SocketStatus::Idle;
// If we are waiting to be deleted, just return and let the drop handler do its thing.
if socket_state.waiting_on_delete {
return Ok(FeedEventResult::NoEvent);
}
unsafe {
// SAFETY: The packet is not in transit.
let iosb = &mut *self.as_ref().iosb().get();
// Check the status.
match iosb.Anonymous.Status {
STATUS_CANCELLED => {
// Poll request was cancelled.
}
status if status < 0 => {
// There was an error, so we signal both ends.
event.readable = true;
event.writable = true;
}
_ => {
// Check in on the AFD data.
let afd_data = &*afd_info.get();
if afd_data.handle_count() >= 1 {
let events = afd_data.events();
// If we closed the socket, remove it from being polled.
if events.contains(AfdPollMask::LOCAL_CLOSE) {
let source = lock!(poller.sources.write())
.remove(&socket_state.socket)
.unwrap();
return source.begin_delete().map(|()| FeedEventResult::NoEvent);
}
// Report socket-related events.
let (readable, writable) = afd_mask_to_event(events);
event.readable = readable;
event.writable = writable;
}
}
}
}
// Filter out events that the user didn't ask for.
event.readable &= socket_state.interest.readable;
event.writable &= socket_state.interest.writable;
// If this event doesn't have anything that interests us, don't return or
// update the oneshot state.
let return_value = if event.readable || event.writable {
// If we are in oneshot mode, remove the interest.
if matches!(socket_state.mode, PollMode::Oneshot) {
socket_state.interest = Event::none(socket_state.interest.key);
socket_state.interest_error = false;
}
FeedEventResult::Event(event)
} else {
FeedEventResult::NoEvent
};
// Put ourselves in the update queue.
drop(socket_state);
poller.update_packet(self)?;
// Return the event.
Ok(return_value)
}
/// Begin deleting this socket.
fn begin_delete(self: Pin<Arc<Self>>) -> io::Result<()> {
// If we aren't already being deleted, start deleting.
let mut socket = self
.as_ref()
.socket_state()
.expect("can't delete packet that doesn't belong to a socket");
if !socket.waiting_on_delete {
socket.waiting_on_delete = true;
if matches!(socket.status, SocketStatus::Polling { .. }) {
// Cancel the ongoing poll.
self.cancel(socket)?;
}
}
// Either drop it now or wait for it to be dropped later.
Ok(())
}
fn cancel(self: &Pin<Arc<Self>>, mut socket: MutexGuard<'_, SocketState>) -> io::Result<()> {
assert!(matches!(socket.status, SocketStatus::Polling { .. }));
// Send the cancel request.
unsafe {
socket.afd.cancel(self)?;
}
// Move state to cancelled.
socket.status = SocketStatus::Cancelled;
Ok(())
}
fn socket_state(self: Pin<&Self>) -> Option<MutexGuard<'_, SocketState>> {
let inner = self.data().project_ref();
let state = match inner {
PacketInnerProj::Wakeup { .. } => return None,
PacketInnerProj::Socket { socket, .. } => socket,
};
Some(lock!(state.lock()))
}
}
/// Per-socket state.
#[derive(Debug)]
struct SocketState {
/// The raw socket handle.
socket: RawSocket,
/// The base socket handle.
base_socket: RawSocket,
/// The event that this socket is currently waiting on.
interest: Event,
/// Whether to listen for error events.
interest_error: bool,
/// The current poll mode.
mode: PollMode,
/// The AFD instance that this socket is registered with.
afd: Arc<Afd<Packet>>,
/// Whether this socket is waiting to be deleted.
waiting_on_delete: bool,
/// The current status of the socket.
status: SocketStatus,
}
/// The mode that a socket can be in.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum SocketStatus {
/// We are currently not polling.
Idle,
/// We are currently polling these events.
Polling {
/// We are currently polling for readable events.
readable: bool,
/// We are currently polling for writable events.
writable: bool,
},
/// The last poll operation was cancelled, and we're waiting for it to
/// complete.
Cancelled,
}
/// The result of calling `feed_event`.
#[derive(Debug)]
enum FeedEventResult {
/// No event was yielded.
NoEvent,
/// An event was yielded.
Event(Event),
/// The poller has been notified.
Notified,
}
fn event_to_afd_mask(readable: bool, writable: bool, error: bool) -> afd::AfdPollMask {
use afd::AfdPollMask as AfdPoll;
let mut mask = AfdPoll::empty();
if error || readable || writable {
mask |= AfdPoll::ABORT | AfdPoll::CONNECT_FAIL;
}
if readable {
mask |=
AfdPoll::RECEIVE | AfdPoll::ACCEPT | AfdPoll::DISCONNECT | AfdPoll::RECEIVE_EXPEDITED;
}
if writable {
mask |= AfdPoll::SEND;
}
mask
}
fn afd_mask_to_event(mask: afd::AfdPollMask) -> (bool, bool) {
use afd::AfdPollMask as AfdPoll;
let mut readable = false;
let mut writable = false;
if mask.intersects(
AfdPoll::RECEIVE | AfdPoll::ACCEPT | AfdPoll::DISCONNECT | AfdPoll::RECEIVE_EXPEDITED,
) {
readable = true;
}
if mask.intersects(AfdPoll::SEND) {
writable = true;
}
if mask.intersects(AfdPoll::ABORT | AfdPoll::CONNECT_FAIL) {
readable = true;
writable = true;
}
(readable, writable)
}
struct CallOnDrop<F: FnMut()>(F);
impl<F: FnMut()> Drop for CallOnDrop<F> {
fn drop(&mut self) {
(self.0)();
}
}

327
src/iocp/port.rs Normal file
View File

@@ -0,0 +1,327 @@
//! A safe wrapper around the Windows I/O completion port API.
use std::convert::{TryFrom, TryInto};
use std::fmt;
use std::io;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::ops::Deref;
use std::os::windows::io::{AsRawHandle, RawHandle};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use windows_sys::Win32::Foundation::{CloseHandle, HANDLE, INVALID_HANDLE_VALUE};
use windows_sys::Win32::Storage::FileSystem::SetFileCompletionNotificationModes;
use windows_sys::Win32::System::WindowsProgramming::{FILE_SKIP_SET_EVENT_ON_HANDLE, INFINITE};
use windows_sys::Win32::System::IO::{
CreateIoCompletionPort, GetQueuedCompletionStatusEx, PostQueuedCompletionStatus, OVERLAPPED,
OVERLAPPED_ENTRY,
};
/// A completion block which can be used with I/O completion ports.
///
/// # Safety
///
/// This must be a valid completion block.
pub(super) unsafe trait Completion {
/// Signal to the completion block that we are about to start an operation.
fn try_lock(self: Pin<&Self>) -> bool;
/// Unlock the completion block.
unsafe fn unlock(self: Pin<&Self>);
}
/// The pointer to a completion block.
///
/// # Safety
///
/// This must be a valid completion block.
pub(super) unsafe trait CompletionHandle: Deref + Sized {
/// Type of the completion block.
type Completion: Completion;
/// Get a pointer to the completion block.
///
/// The pointer is pinned since the underlying object should not be moved
/// after creation. This prevents it from being invalidated while it's
/// used in an overlapped operation.
fn get(&self) -> Pin<&Self::Completion>;
/// Convert this block into a pointer that can be passed as `*mut OVERLAPPED`.
fn into_ptr(this: Self) -> *mut OVERLAPPED;
/// Convert a pointer that was passed as `*mut OVERLAPPED` into a pointer to this block.
///
/// # Safety
///
/// This must be a valid pointer to a completion block.
unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self;
/// Convert to a pointer without losing ownership.
fn as_ptr(&self) -> *mut OVERLAPPED;
}
unsafe impl<'a, T: Completion> CompletionHandle for Pin<&'a T> {
type Completion = T;
fn get(&self) -> Pin<&Self::Completion> {
*self
}
fn into_ptr(this: Self) -> *mut OVERLAPPED {
unsafe { Pin::into_inner_unchecked(this) as *const T as *mut OVERLAPPED }
}
unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self {
Pin::new_unchecked(&*(ptr as *const T))
}
fn as_ptr(&self) -> *mut OVERLAPPED {
self.get_ref() as *const T as *mut OVERLAPPED
}
}
unsafe impl<T: Completion> CompletionHandle for Pin<Arc<T>> {
type Completion = T;
fn get(&self) -> Pin<&Self::Completion> {
self.as_ref()
}
fn into_ptr(this: Self) -> *mut OVERLAPPED {
unsafe { Arc::into_raw(Pin::into_inner_unchecked(this)) as *const T as *mut OVERLAPPED }
}
unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self {
Pin::new_unchecked(Arc::from_raw(ptr as *const T))
}
fn as_ptr(&self) -> *mut OVERLAPPED {
self.as_ref().get_ref() as *const T as *mut OVERLAPPED
}
}
/// A handle to the I/O completion port.
pub(super) struct IoCompletionPort<T> {
/// The underlying handle.
handle: HANDLE,
/// We own the status block.
_marker: PhantomData<T>,
}
impl<T> Drop for IoCompletionPort<T> {
fn drop(&mut self) {
unsafe {
CloseHandle(self.handle);
}
}
}
impl<T> AsRawHandle for IoCompletionPort<T> {
fn as_raw_handle(&self) -> RawHandle {
self.handle as _
}
}
impl<T> fmt::Debug for IoCompletionPort<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
struct WriteAsHex(HANDLE);
impl fmt::Debug for WriteAsHex {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:010x}", self.0)
}
}
f.debug_struct("IoCompletionPort")
.field("handle", &WriteAsHex(self.handle))
.finish()
}
}
impl<T: CompletionHandle> IoCompletionPort<T> {
/// Create a new I/O completion port.
pub(super) fn new(threads: usize) -> io::Result<Self> {
let handle = unsafe {
CreateIoCompletionPort(
INVALID_HANDLE_VALUE,
0,
0,
threads.try_into().expect("too many threads"),
)
};
if handle == 0 {
Err(io::Error::last_os_error())
} else {
Ok(Self {
handle,
_marker: PhantomData,
})
}
}
/// Register a handle with this I/O completion port.
pub(super) fn register(
&self,
handle: &impl AsRawHandle, // TODO change to AsHandle
skip_set_event_on_handle: bool,
) -> io::Result<()> {
let handle = handle.as_raw_handle();
let result =
unsafe { CreateIoCompletionPort(handle as _, self.handle, handle as usize, 0) };
if result == 0 {
return Err(io::Error::last_os_error());
}
if skip_set_event_on_handle {
// Set the skip event on handle.
let result = unsafe {
SetFileCompletionNotificationModes(handle as _, FILE_SKIP_SET_EVENT_ON_HANDLE as _)
};
if result == 0 {
return Err(io::Error::last_os_error());
}
}
Ok(())
}
/// Post a completion packet to this port.
pub(super) fn post(&self, bytes_transferred: usize, id: usize, packet: T) -> io::Result<()> {
let result = unsafe {
PostQueuedCompletionStatus(
self.handle,
bytes_transferred
.try_into()
.expect("too many bytes transferred"),
id,
T::into_ptr(packet),
)
};
if result == 0 {
Err(io::Error::last_os_error())
} else {
Ok(())
}
}
/// Wait for completion packets to arrive.
pub(super) fn wait(
&self,
packets: &mut Vec<OverlappedEntry<T>>,
timeout: Option<Duration>,
) -> io::Result<usize> {
// Drop the current packets.
packets.clear();
let mut count = MaybeUninit::<u32>::uninit();
let timeout = timeout.map_or(INFINITE, dur2timeout);
let result = unsafe {
GetQueuedCompletionStatusEx(
self.handle,
packets.as_mut_ptr() as _,
packets.capacity().try_into().expect("too many packets"),
count.as_mut_ptr(),
timeout,
0,
)
};
if result == 0 {
let io_error = io::Error::last_os_error();
if io_error.kind() == io::ErrorKind::TimedOut {
Ok(0)
} else {
Err(io_error)
}
} else {
let count = unsafe { count.assume_init() };
unsafe {
packets.set_len(count as _);
}
Ok(count as _)
}
}
}
/// An `OVERLAPPED_ENTRY` resulting from an I/O completion port.
#[repr(transparent)]
pub(super) struct OverlappedEntry<T: CompletionHandle> {
/// The underlying entry.
entry: OVERLAPPED_ENTRY,
/// We own the status block.
_marker: PhantomData<T>,
}
impl<T: CompletionHandle> fmt::Debug for OverlappedEntry<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("OverlappedEntry { .. }")
}
}
impl<T: CompletionHandle> OverlappedEntry<T> {
/// Convert into the completion packet.
pub(super) fn into_packet(self) -> T {
let packet = unsafe { self.packet() };
std::mem::forget(self);
packet
}
/// Get the packet reference that this entry refers to.
///
/// # Safety
///
/// This function should only be called once, since it moves
/// out the `T` from the `OVERLAPPED_ENTRY`.
unsafe fn packet(&self) -> T {
let packet = T::from_ptr(self.entry.lpOverlapped);
packet.get().unlock();
packet
}
}
impl<T: CompletionHandle> Drop for OverlappedEntry<T> {
fn drop(&mut self) {
drop(unsafe { self.packet() });
}
}
// Implementation taken from https://github.com/rust-lang/rust/blob/db5476571d9b27c862b95c1e64764b0ac8980e23/src/libstd/sys/windows/mod.rs
fn dur2timeout(dur: Duration) -> u32 {
// Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the
// timeouts in windows APIs are typically u32 milliseconds. To translate, we
// have two pieces to take care of:
//
// * Nanosecond precision is rounded up
// * Greater than u32::MAX milliseconds (50 days) is rounded up to INFINITE
// (never time out).
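// For example, 1_500_000 ns rounds up to 2 ms, while a 60-day duration
// exceeds u32::MAX ms and saturates to INFINITE.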
dur.as_secs()
.checked_mul(1000)
.and_then(|ms| ms.checked_add((dur.subsec_nanos() as u64) / 1_000_000))
.and_then(|ms| {
if dur.subsec_nanos() % 1_000_000 > 0 {
ms.checked_add(1)
} else {
Some(ms)
}
})
.and_then(|x| u32::try_from(x).ok())
.unwrap_or(INFINITE)
}
struct CallOnDrop<F: FnMut()>(F);
impl<F: FnMut()> Drop for CallOnDrop<F> {
fn drop(&mut self) {
(self.0)();
}
}

View File

@@ -1,4 +1,4 @@
//! Portable interface to epoll, kqueue, event ports, and wepoll.
//! Portable interface to epoll, kqueue, event ports, and IOCP.
//!
//! Supported platforms:
//! - [epoll](https://en.wikipedia.org/wiki/Epoll): Linux, Android
@@ -6,7 +6,7 @@
//! DragonFly BSD
//! - [event ports](https://illumos.org/man/port_create): illumos, Solaris
//! - [poll](https://en.wikipedia.org/wiki/Poll_(Unix)): VxWorks, Fuchsia, other Unix systems
//! - [wepoll](https://github.com/piscisaureus/wepoll): Windows, Wine (version 7.13+)
//! - [IOCP](https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports): Windows, Wine (version 7.13+)
//!
//! By default, polling is done in oneshot mode, which means interest in I/O events needs to
//! be re-enabled after an event is delivered if we're interested in the next event of the same
@@ -113,8 +113,8 @@ cfg_if! {
mod poll;
use poll as sys;
} else if #[cfg(target_os = "windows")] {
mod wepoll;
use wepoll as sys;
mod iocp;
use iocp as sys;
} else {
compile_error!("polling does not support this target OS");
}

View File

@@ -1,254 +0,0 @@
//! Bindings to wepoll (Windows).
use std::convert::TryInto;
use std::io;
use std::os::raw::c_int;
use std::os::windows::io::{AsRawHandle, RawHandle, RawSocket};
use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
#[cfg(not(polling_no_io_safety))]
use std::os::windows::io::{AsHandle, BorrowedHandle};
use wepoll_ffi as we;
use crate::{Event, PollMode};
/// Calls a wepoll function and converts the result into an `io::Result`.
macro_rules! wepoll {
($fn:ident $args:tt) => {{
let res = unsafe { we::$fn $args };
if res == -1 {
Err(std::io::Error::last_os_error())
} else {
Ok(res)
}
}};
}
/// Interface to wepoll.
#[derive(Debug)]
pub struct Poller {
handle: we::HANDLE,
notified: AtomicBool,
}
unsafe impl Send for Poller {}
unsafe impl Sync for Poller {}
impl Poller {
/// Creates a new poller.
pub fn new() -> io::Result<Poller> {
let handle = unsafe { we::epoll_create1(0) };
if handle.is_null() {
return Err(crate::unsupported_error(
format!(
"Failed to initialize Wepoll: {}\nThis usually only happens for old Windows or Wine.",
io::Error::last_os_error()
)
));
}
let notified = AtomicBool::new(false);
log::trace!("new: handle={:?}", handle);
Ok(Poller { handle, notified })
}
/// Whether this poller supports level-triggered events.
pub fn supports_level(&self) -> bool {
true
}
/// Whether this poller supports edge-triggered events.
pub fn supports_edge(&self) -> bool {
false
}
/// Adds a socket.
pub fn add(&self, sock: RawSocket, ev: Event, mode: PollMode) -> io::Result<()> {
log::trace!("add: handle={:?}, sock={}, ev={:?}", self.handle, sock, ev);
self.ctl(we::EPOLL_CTL_ADD, sock, Some((ev, mode)))
}
/// Modifies a socket.
pub fn modify(&self, sock: RawSocket, ev: Event, mode: PollMode) -> io::Result<()> {
log::trace!(
"modify: handle={:?}, sock={}, ev={:?}",
self.handle,
sock,
ev
);
self.ctl(we::EPOLL_CTL_MOD, sock, Some((ev, mode)))
}
/// Deletes a socket.
pub fn delete(&self, sock: RawSocket) -> io::Result<()> {
log::trace!("remove: handle={:?}, sock={}", self.handle, sock);
self.ctl(we::EPOLL_CTL_DEL, sock, None)
}
/// Waits for I/O events with an optional timeout.
///
/// Returns the number of processed I/O events.
///
/// If a notification occurs, this method will return but the notification event will not be
/// included in the `events` list nor contribute to the returned count.
pub fn wait(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
log::trace!("wait: handle={:?}, timeout={:?}", self.handle, timeout);
let deadline = timeout.and_then(|t| Instant::now().checked_add(t));
loop {
// Convert the timeout to milliseconds.
let timeout_ms = match deadline.map(|d| d.saturating_duration_since(Instant::now())) {
None => -1,
Some(t) => {
// Round up to a whole millisecond.
let mut ms = t.as_millis().try_into().unwrap_or(std::u64::MAX);
if Duration::from_millis(ms) < t {
ms = ms.saturating_add(1);
}
ms.try_into().unwrap_or(std::i32::MAX)
}
};
// Wait for I/O events.
events.len = wepoll!(epoll_wait(
self.handle,
events.list.as_mut_ptr(),
events.list.len() as c_int,
timeout_ms,
))? as usize;
log::trace!("new events: handle={:?}, len={}", self.handle, events.len);
// Break if there was a notification or at least one event, or if deadline is reached.
if self.notified.swap(false, Ordering::SeqCst) || events.len > 0 || timeout_ms == 0 {
break;
}
}
Ok(())
}
/// Sends a notification to wake up the current or next `wait()` call.
pub fn notify(&self) -> io::Result<()> {
log::trace!("notify: handle={:?}", self.handle);
if self
.notified
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_ok()
{
unsafe {
// This call errors if a notification has already been posted, but that's okay - we
// can just ignore the error.
//
// The original wepoll does not support notifications triggered this way, which is
// why wepoll-sys includes a small patch to support them.
windows_sys::Win32::System::IO::PostQueuedCompletionStatus(
self.handle as _,
0,
0,
ptr::null_mut(),
);
}
}
Ok(())
}
/// Passes arguments to `epoll_ctl`.
fn ctl(&self, op: u32, sock: RawSocket, ev: Option<(Event, PollMode)>) -> io::Result<()> {
let mut ev = ev
.map(|(ev, mode)| {
let mut flags = match mode {
PollMode::Level => 0,
PollMode::Oneshot => we::EPOLLONESHOT,
PollMode::Edge => {
return Err(crate::unsupported_error(
"edge-triggered events are not supported with wepoll",
));
}
};
if ev.readable {
flags |= READ_FLAGS;
}
if ev.writable {
flags |= WRITE_FLAGS;
}
Ok(we::epoll_event {
events: flags as u32,
data: we::epoll_data {
u64_: ev.key as u64,
},
})
})
.transpose()?;
wepoll!(epoll_ctl(
self.handle,
op as c_int,
sock as we::SOCKET,
ev.as_mut()
.map(|ev| ev as *mut we::epoll_event)
.unwrap_or(ptr::null_mut()),
))?;
Ok(())
}
}
impl AsRawHandle for Poller {
fn as_raw_handle(&self) -> RawHandle {
self.handle as RawHandle
}
}
#[cfg(not(polling_no_io_safety))]
impl AsHandle for Poller {
fn as_handle(&self) -> BorrowedHandle<'_> {
// SAFETY: lifetime is bound by "self"
unsafe { BorrowedHandle::borrow_raw(self.as_raw_handle()) }
}
}
impl Drop for Poller {
fn drop(&mut self) {
log::trace!("drop: handle={:?}", self.handle);
unsafe {
we::epoll_close(self.handle);
}
}
}
/// Wepoll flags for all possible readability events.
const READ_FLAGS: u32 = we::EPOLLIN | we::EPOLLRDHUP | we::EPOLLHUP | we::EPOLLERR | we::EPOLLPRI;
/// Wepoll flags for all possible writability events.
const WRITE_FLAGS: u32 = we::EPOLLOUT | we::EPOLLHUP | we::EPOLLERR;
/// A list of reported I/O events.
pub struct Events {
list: Box<[we::epoll_event; 1024]>,
len: usize,
}
unsafe impl Send for Events {}
impl Events {
/// Creates an empty list.
pub fn new() -> Events {
let ev = we::epoll_event {
events: 0,
data: we::epoll_data { u64_: 0 },
};
let list = Box::new([ev; 1024]);
Events { list, len: 0 }
}
/// Iterates over I/O events.
pub fn iter(&self) -> impl Iterator<Item = Event> + '_ {
self.list[..self.len].iter().map(|ev| Event {
key: unsafe { ev.data.u64_ } as usize,
readable: (ev.events & READ_FLAGS) != 0,
writable: (ev.events & WRITE_FLAGS) != 0,
})
}
}

View File

@@ -43,7 +43,7 @@ fn concurrent_modify() -> io::Result<()> {
Parallel::new()
.add(|| {
poller.wait(&mut events, None)?;
poller.wait(&mut events, Some(Duration::from_secs(10)))?;
Ok(())
})
.add(|| {

38
tests/io.rs Normal file
View File

@@ -0,0 +1,38 @@
use polling::{Event, Poller};
use std::io::{self, Write};
use std::net::{TcpListener, TcpStream};
use std::time::Duration;
#[test]
fn basic_io() {
let poller = Poller::new().unwrap();
let (read, mut write) = tcp_pair().unwrap();
poller.add(&read, Event::readable(1)).unwrap();
// Nothing should be available at first.
let mut events = vec![];
assert_eq!(
poller
.wait(&mut events, Some(Duration::from_secs(0)))
.unwrap(),
0
);
assert!(events.is_empty());
// After a write, the event should be available now.
write.write_all(&[1]).unwrap();
assert_eq!(
poller
.wait(&mut events, Some(Duration::from_secs(1)))
.unwrap(),
1
);
assert_eq!(&*events, &[Event::readable(1)]);
}
fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> {
let listener = TcpListener::bind("127.0.0.1:0")?;
let a = TcpStream::connect(listener.local_addr()?)?;
let (b, _) = listener.accept()?;
Ok((a, b))
}

View File

@@ -18,7 +18,7 @@ fn below_ms() -> io::Result<()> {
let elapsed = now.elapsed();
assert_eq!(n, 0);
assert!(elapsed >= dur);
assert!(elapsed >= dur, "{:?} < {:?}", elapsed, dur);
lowest = lowest.min(elapsed);
}
@@ -54,7 +54,7 @@ fn above_ms() -> io::Result<()> {
let elapsed = now.elapsed();
assert_eq!(n, 0);
assert!(elapsed >= dur);
assert!(elapsed >= dur, "{:?} < {:?}", elapsed, dur);
lowest = lowest.min(elapsed);
}