mirror of https://github.com/smol-rs/polling
m(windows): Reimplement Wepoll in Rust (#88)
Reimplements the C-based wepoll backend in Rust, using some handwritten code. This PR also implements bindings to the I/O Completion Ports and \Device\Afd APIs. For more information on the latter, see my blog post on the subject: https://notgull.github.io/device-afd/ Note that the IOCP API is wrapped using a `Pin`-oriented "CompletionHandle" system that is relatively brittle. This should be replaced with a better model when one becomes available.
This commit is contained in:
parent
e85331c437
commit
24900fb662
14
Cargo.toml
14
Cargo.toml
|
@ -7,10 +7,10 @@ version = "2.5.2"
|
|||
authors = ["Stjepan Glavina <stjepang@gmail.com>"]
|
||||
edition = "2018"
|
||||
rust-version = "1.47"
|
||||
description = "Portable interface to epoll, kqueue, event ports, and wepoll"
|
||||
description = "Portable interface to epoll, kqueue, event ports, and IOCP"
|
||||
license = "Apache-2.0 OR MIT"
|
||||
repository = "https://github.com/smol-rs/polling"
|
||||
keywords = ["mio", "epoll", "kqueue", "iocp", "wepoll"]
|
||||
keywords = ["mio", "epoll", "kqueue", "iocp"]
|
||||
categories = ["asynchronous", "network-programming", "os"]
|
||||
exclude = ["/.*"]
|
||||
|
||||
|
@ -32,13 +32,19 @@ autocfg = "1"
|
|||
libc = "0.2.77"
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
wepoll-ffi = { version = "0.1.2", features = ["null-overlapped-wakeups-patch"] }
|
||||
bitflags = "1.3.2"
|
||||
concurrent-queue = "2.1.0"
|
||||
pin-project-lite = "0.2.9"
|
||||
|
||||
[target.'cfg(windows)'.dependencies.windows-sys]
|
||||
version = "0.45"
|
||||
features = [
|
||||
"Win32_Networking_WinSock",
|
||||
"Win32_System_IO",
|
||||
"Win32_Foundation"
|
||||
"Win32_System_LibraryLoader",
|
||||
"Win32_System_WindowsProgramming",
|
||||
"Win32_Storage_FileSystem",
|
||||
"Win32_Foundation",
|
||||
]
|
||||
|
||||
[dev-dependencies]
|
||||
|
|
|
@ -9,7 +9,7 @@ https://crates.io/crates/polling)
|
|||
[![Documentation](https://docs.rs/polling/badge.svg)](
|
||||
https://docs.rs/polling)
|
||||
|
||||
Portable interface to epoll, kqueue, event ports, and wepoll.
|
||||
Portable interface to epoll, kqueue, event ports, and IOCP.
|
||||
|
||||
Supported platforms:
|
||||
- [epoll](https://en.wikipedia.org/wiki/Epoll): Linux, Android
|
||||
|
@ -17,7 +17,7 @@ Supported platforms:
|
|||
DragonFly BSD
|
||||
- [event ports](https://illumos.org/man/port_create): illumos, Solaris
|
||||
- [poll](https://en.wikipedia.org/wiki/Poll_(Unix)): VxWorks, Fuchsia, other Unix systems
|
||||
- [wepoll](https://github.com/piscisaureus/wepoll): Windows, Wine (version 7.13+)
|
||||
- [IOCP](https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports): Windows, Wine (version 7.13+)
|
||||
|
||||
Polling is done in oneshot mode, which means interest in I/O events needs to be reset after
|
||||
an event is delivered if we're interested in the next event of the same kind.
|
||||
|
|
|
@ -0,0 +1,608 @@
|
|||
//! Safe wrapper around \Device\Afd
|
||||
|
||||
use super::port::{Completion, CompletionHandle};
|
||||
|
||||
use std::cell::UnsafeCell;
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::marker::{PhantomData, PhantomPinned};
|
||||
use std::mem::{size_of, transmute, MaybeUninit};
|
||||
use std::os::windows::prelude::{AsRawHandle, RawHandle, RawSocket};
|
||||
use std::pin::Pin;
|
||||
use std::ptr;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Once;
|
||||
|
||||
use windows_sys::Win32::Foundation::{
|
||||
CloseHandle, HANDLE, HINSTANCE, NTSTATUS, STATUS_NOT_FOUND, STATUS_PENDING, STATUS_SUCCESS,
|
||||
UNICODE_STRING,
|
||||
};
|
||||
use windows_sys::Win32::Networking::WinSock::{
|
||||
WSAIoctl, SIO_BASE_HANDLE, SIO_BSP_HANDLE_POLL, SOCKET_ERROR,
|
||||
};
|
||||
use windows_sys::Win32::Storage::FileSystem::{
|
||||
FILE_OPEN, FILE_SHARE_READ, FILE_SHARE_WRITE, SYNCHRONIZE,
|
||||
};
|
||||
use windows_sys::Win32::System::LibraryLoader::{GetModuleHandleW, GetProcAddress};
|
||||
use windows_sys::Win32::System::WindowsProgramming::{IO_STATUS_BLOCK, OBJECT_ATTRIBUTES};
|
||||
|
||||
#[derive(Default)]
|
||||
#[repr(C)]
|
||||
pub(super) struct AfdPollInfo {
|
||||
/// The timeout for this poll.
|
||||
timeout: i64,
|
||||
|
||||
/// The number of handles being polled.
|
||||
handle_count: u32,
|
||||
|
||||
/// Whether or not this poll is exclusive for this handle.
|
||||
exclusive: u32,
|
||||
|
||||
/// The handles to poll.
|
||||
handles: [AfdPollHandleInfo; 1],
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
#[repr(C)]
|
||||
struct AfdPollHandleInfo {
|
||||
/// The handle to poll.
|
||||
handle: HANDLE,
|
||||
|
||||
/// The events to poll for.
|
||||
events: AfdPollMask,
|
||||
|
||||
/// The status of the poll.
|
||||
status: NTSTATUS,
|
||||
}
|
||||
|
||||
impl AfdPollInfo {
|
||||
pub(super) fn handle_count(&self) -> u32 {
|
||||
self.handle_count
|
||||
}
|
||||
|
||||
pub(super) fn events(&self) -> AfdPollMask {
|
||||
self.handles[0].events
|
||||
}
|
||||
}
|
||||
|
||||
bitflags::bitflags! {
|
||||
#[derive(Default)]
|
||||
#[repr(transparent)]
|
||||
pub(super) struct AfdPollMask: u32 {
|
||||
const RECEIVE = 0x001;
|
||||
const RECEIVE_EXPEDITED = 0x002;
|
||||
const SEND = 0x004;
|
||||
const DISCONNECT = 0x008;
|
||||
const ABORT = 0x010;
|
||||
const LOCAL_CLOSE = 0x020;
|
||||
const ACCEPT = 0x080;
|
||||
const CONNECT_FAIL = 0x100;
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) trait HasAfdInfo {
|
||||
fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell<AfdPollInfo>>;
|
||||
}
|
||||
|
||||
macro_rules! define_ntdll_import {
|
||||
(
|
||||
$(
|
||||
$(#[$attr:meta])*
|
||||
fn $name:ident($($arg:ident: $arg_ty:ty),*) -> $ret:ty;
|
||||
)*
|
||||
) => {
|
||||
/// Imported functions from ntdll.dll.
|
||||
#[allow(non_snake_case)]
|
||||
pub(super) struct NtdllImports {
|
||||
$(
|
||||
$(#[$attr])*
|
||||
$name: unsafe extern "system" fn($($arg_ty),*) -> $ret,
|
||||
)*
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
impl NtdllImports {
|
||||
unsafe fn load(ntdll: HINSTANCE) -> io::Result<Self> {
|
||||
$(
|
||||
let $name = {
|
||||
const NAME: &str = concat!(stringify!($name), "\0");
|
||||
let addr = GetProcAddress(ntdll, NAME.as_ptr() as *const _);
|
||||
|
||||
let addr = match addr {
|
||||
Some(addr) => addr,
|
||||
None => {
|
||||
log::error!("Failed to load ntdll function {}", NAME);
|
||||
return Err(io::Error::last_os_error());
|
||||
},
|
||||
};
|
||||
|
||||
transmute::<_, unsafe extern "system" fn($($arg_ty),*) -> $ret>(addr)
|
||||
};
|
||||
)*
|
||||
|
||||
Ok(Self {
|
||||
$(
|
||||
$name,
|
||||
)*
|
||||
})
|
||||
}
|
||||
|
||||
$(
|
||||
$(#[$attr])*
|
||||
unsafe fn $name(&self, $($arg: $arg_ty),*) -> $ret {
|
||||
(self.$name)($($arg),*)
|
||||
}
|
||||
)*
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
define_ntdll_import! {
|
||||
/// Cancels an ongoing I/O operation.
|
||||
fn NtCancelIoFileEx(
|
||||
FileHandle: HANDLE,
|
||||
IoRequestToCancel: *mut IO_STATUS_BLOCK,
|
||||
IoStatusBlock: *mut IO_STATUS_BLOCK
|
||||
) -> NTSTATUS;
|
||||
|
||||
/// Opens or creates a file handle.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn NtCreateFile(
|
||||
FileHandle: *mut HANDLE,
|
||||
DesiredAccess: u32,
|
||||
ObjectAttributes: *mut OBJECT_ATTRIBUTES,
|
||||
IoStatusBlock: *mut IO_STATUS_BLOCK,
|
||||
AllocationSize: *mut i64,
|
||||
FileAttributes: u32,
|
||||
ShareAccess: u32,
|
||||
CreateDisposition: u32,
|
||||
CreateOptions: u32,
|
||||
EaBuffer: *mut (),
|
||||
EaLength: u32
|
||||
) -> NTSTATUS;
|
||||
|
||||
/// Runs an I/O control on a file handle.
|
||||
///
|
||||
/// Practically equivalent to `ioctl`.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn NtDeviceIoControlFile(
|
||||
FileHandle: HANDLE,
|
||||
Event: HANDLE,
|
||||
ApcRoutine: *mut (),
|
||||
ApcContext: *mut (),
|
||||
IoStatusBlock: *mut IO_STATUS_BLOCK,
|
||||
IoControlCode: u32,
|
||||
InputBuffer: *mut (),
|
||||
InputBufferLength: u32,
|
||||
OutputBuffer: *mut (),
|
||||
OutputBufferLength: u32
|
||||
) -> NTSTATUS;
|
||||
|
||||
/// Converts `NTSTATUS` to a DOS error code.
|
||||
fn RtlNtStatusToDosError(
|
||||
Status: NTSTATUS
|
||||
) -> u32;
|
||||
}
|
||||
|
||||
impl NtdllImports {
|
||||
fn get() -> io::Result<&'static Self> {
|
||||
macro_rules! s {
|
||||
($e:expr) => {{
|
||||
$e as u16
|
||||
}};
|
||||
}
|
||||
|
||||
// ntdll.dll
|
||||
static NTDLL_NAME: &[u16] = &[
|
||||
s!('n'),
|
||||
s!('t'),
|
||||
s!('d'),
|
||||
s!('l'),
|
||||
s!('l'),
|
||||
s!('.'),
|
||||
s!('d'),
|
||||
s!('l'),
|
||||
s!('l'),
|
||||
s!('\0'),
|
||||
];
|
||||
static NTDLL_IMPORTS: OnceCell<io::Result<NtdllImports>> = OnceCell::new();
|
||||
|
||||
NTDLL_IMPORTS
|
||||
.get_or_init(|| unsafe {
|
||||
let ntdll = GetModuleHandleW(NTDLL_NAME.as_ptr() as *const _);
|
||||
|
||||
if ntdll == 0 {
|
||||
log::error!("Failed to load ntdll.dll");
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
|
||||
NtdllImports::load(ntdll)
|
||||
})
|
||||
.as_ref()
|
||||
.map_err(|e| io::Error::from(e.kind()))
|
||||
}
|
||||
|
||||
pub(super) fn force_load() -> io::Result<()> {
|
||||
Self::get()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// The handle to the AFD device.
|
||||
pub(super) struct Afd<T> {
|
||||
/// The handle to the AFD device.
|
||||
handle: HANDLE,
|
||||
|
||||
/// We own `T`.
|
||||
_marker: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> fmt::Debug for Afd<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
struct WriteAsHex(HANDLE);
|
||||
|
||||
impl fmt::Debug for WriteAsHex {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{:010x}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
f.debug_struct("Afd")
|
||||
.field("handle", &WriteAsHex(self.handle))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for Afd<T> {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
CloseHandle(self.handle);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AsRawHandle for Afd<T> {
|
||||
fn as_raw_handle(&self) -> RawHandle {
|
||||
self.handle as _
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: CompletionHandle> Afd<T>
|
||||
where
|
||||
T::Completion: AsIoStatusBlock + HasAfdInfo,
|
||||
{
|
||||
/// Create a new AFD device.
|
||||
pub(super) fn new() -> io::Result<Self> {
|
||||
macro_rules! s {
|
||||
($e:expr) => {
|
||||
($e) as u16
|
||||
};
|
||||
}
|
||||
|
||||
/// \Device\Afd\Smol
|
||||
const AFD_NAME: &[u16] = &[
|
||||
s!('\\'),
|
||||
s!('D'),
|
||||
s!('e'),
|
||||
s!('v'),
|
||||
s!('i'),
|
||||
s!('c'),
|
||||
s!('e'),
|
||||
s!('\\'),
|
||||
s!('A'),
|
||||
s!('f'),
|
||||
s!('d'),
|
||||
s!('\\'),
|
||||
s!('S'),
|
||||
s!('m'),
|
||||
s!('o'),
|
||||
s!('l'),
|
||||
s!('\0'),
|
||||
];
|
||||
|
||||
// Set up device attributes.
|
||||
let mut device_name = UNICODE_STRING {
|
||||
Length: (AFD_NAME.len() * size_of::<u16>()) as u16,
|
||||
MaximumLength: (AFD_NAME.len() * size_of::<u16>()) as u16,
|
||||
Buffer: AFD_NAME.as_ptr() as *mut _,
|
||||
};
|
||||
let mut device_attributes = OBJECT_ATTRIBUTES {
|
||||
Length: size_of::<OBJECT_ATTRIBUTES>() as u32,
|
||||
RootDirectory: 0,
|
||||
ObjectName: &mut device_name,
|
||||
Attributes: 0,
|
||||
SecurityDescriptor: ptr::null_mut(),
|
||||
SecurityQualityOfService: ptr::null_mut(),
|
||||
};
|
||||
|
||||
let mut handle = MaybeUninit::<HANDLE>::uninit();
|
||||
let mut iosb = MaybeUninit::<IO_STATUS_BLOCK>::zeroed();
|
||||
let ntdll = NtdllImports::get()?;
|
||||
|
||||
let result = unsafe {
|
||||
ntdll.NtCreateFile(
|
||||
handle.as_mut_ptr(),
|
||||
SYNCHRONIZE,
|
||||
&mut device_attributes,
|
||||
iosb.as_mut_ptr(),
|
||||
ptr::null_mut(),
|
||||
0,
|
||||
FILE_SHARE_READ | FILE_SHARE_WRITE,
|
||||
FILE_OPEN,
|
||||
0,
|
||||
ptr::null_mut(),
|
||||
0,
|
||||
)
|
||||
};
|
||||
|
||||
if result != STATUS_SUCCESS {
|
||||
let real_code = unsafe { ntdll.RtlNtStatusToDosError(result) };
|
||||
|
||||
return Err(io::Error::from_raw_os_error(real_code as i32));
|
||||
}
|
||||
|
||||
let handle = unsafe { handle.assume_init() };
|
||||
|
||||
Ok(Self {
|
||||
handle,
|
||||
_marker: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// Begin polling with the provided handle.
|
||||
pub(super) fn poll(
|
||||
&self,
|
||||
packet: T,
|
||||
base_socket: RawSocket,
|
||||
afd_events: AfdPollMask,
|
||||
) -> io::Result<()> {
|
||||
const IOCTL_AFD_POLL: u32 = 0x00012024;
|
||||
|
||||
// Lock the packet.
|
||||
if !packet.get().try_lock() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::WouldBlock,
|
||||
"packet is already in use",
|
||||
));
|
||||
}
|
||||
|
||||
// Set up the AFD poll info.
|
||||
let poll_info = unsafe {
|
||||
let poll_info = Pin::into_inner_unchecked(packet.get().afd_info()).get();
|
||||
|
||||
// Initialize the AFD poll info.
|
||||
(*poll_info).exclusive = false.into();
|
||||
(*poll_info).handle_count = 1;
|
||||
(*poll_info).timeout = std::i64::MAX;
|
||||
(*poll_info).handles[0].handle = base_socket as HANDLE;
|
||||
(*poll_info).handles[0].status = 0;
|
||||
(*poll_info).handles[0].events = afd_events;
|
||||
|
||||
poll_info
|
||||
};
|
||||
|
||||
let iosb = T::into_ptr(packet).cast::<IO_STATUS_BLOCK>();
|
||||
// Set Status to pending
|
||||
unsafe {
|
||||
(*iosb).Anonymous.Status = STATUS_PENDING;
|
||||
}
|
||||
|
||||
let ntdll = NtdllImports::get()?;
|
||||
let result = unsafe {
|
||||
ntdll.NtDeviceIoControlFile(
|
||||
self.handle,
|
||||
0,
|
||||
ptr::null_mut(),
|
||||
iosb.cast(),
|
||||
iosb.cast(),
|
||||
IOCTL_AFD_POLL,
|
||||
poll_info.cast(),
|
||||
size_of::<AfdPollInfo>() as u32,
|
||||
poll_info.cast(),
|
||||
size_of::<AfdPollInfo>() as u32,
|
||||
)
|
||||
};
|
||||
|
||||
match result {
|
||||
STATUS_SUCCESS => Ok(()),
|
||||
STATUS_PENDING => Err(io::ErrorKind::WouldBlock.into()),
|
||||
status => {
|
||||
let real_code = unsafe { ntdll.RtlNtStatusToDosError(status) };
|
||||
|
||||
Err(io::Error::from_raw_os_error(real_code as i32))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancel an ongoing poll operation.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The poll operation must currently be in progress for this AFD.
|
||||
pub(super) unsafe fn cancel(&self, packet: &T) -> io::Result<()> {
|
||||
let ntdll = NtdllImports::get()?;
|
||||
|
||||
let result = {
|
||||
// First, check if the packet is still in use.
|
||||
let iosb = packet.as_ptr().cast::<IO_STATUS_BLOCK>();
|
||||
|
||||
if (*iosb).Anonymous.Status != STATUS_PENDING {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Cancel the packet.
|
||||
let mut cancel_iosb = MaybeUninit::<IO_STATUS_BLOCK>::zeroed();
|
||||
|
||||
ntdll.NtCancelIoFileEx(self.handle, iosb, cancel_iosb.as_mut_ptr())
|
||||
};
|
||||
|
||||
if result == STATUS_SUCCESS || result == STATUS_NOT_FOUND {
|
||||
Ok(())
|
||||
} else {
|
||||
let real_code = ntdll.RtlNtStatusToDosError(result);
|
||||
|
||||
Err(io::Error::from_raw_os_error(real_code as i32))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A one-time initialization cell.
|
||||
struct OnceCell<T> {
|
||||
/// The value.
|
||||
value: UnsafeCell<MaybeUninit<T>>,
|
||||
|
||||
/// The one-time initialization.
|
||||
once: Once,
|
||||
}
|
||||
|
||||
unsafe impl<T: Send + Sync> Send for OnceCell<T> {}
|
||||
unsafe impl<T: Send + Sync> Sync for OnceCell<T> {}
|
||||
|
||||
impl<T> OnceCell<T> {
|
||||
/// Creates a new `OnceCell`.
|
||||
pub const fn new() -> Self {
|
||||
OnceCell {
|
||||
value: UnsafeCell::new(MaybeUninit::uninit()),
|
||||
once: Once::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the value or initializes it.
|
||||
pub fn get_or_init<F>(&self, f: F) -> &T
|
||||
where
|
||||
F: FnOnce() -> T,
|
||||
{
|
||||
self.once.call_once(|| unsafe {
|
||||
let value = f();
|
||||
*self.value.get() = MaybeUninit::new(value);
|
||||
});
|
||||
|
||||
unsafe { &*self.value.get().cast() }
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
/// An I/O status block paired with some auxillary data.
|
||||
#[repr(C)]
|
||||
pub(super) struct IoStatusBlock<T> {
|
||||
// The I/O status block.
|
||||
iosb: UnsafeCell<IO_STATUS_BLOCK>,
|
||||
|
||||
// Whether or not the block is in use.
|
||||
in_use: AtomicBool,
|
||||
|
||||
// The auxillary data.
|
||||
#[pin]
|
||||
data: T,
|
||||
|
||||
// This block is not allowed to move.
|
||||
#[pin]
|
||||
_marker: PhantomPinned,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: fmt::Debug> fmt::Debug for IoStatusBlock<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("IoStatusBlock")
|
||||
.field("iosb", &"..")
|
||||
.field("in_use", &self.in_use)
|
||||
.field("data", &self.data)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for IoStatusBlock<T> {
|
||||
fn from(data: T) -> Self {
|
||||
Self {
|
||||
iosb: UnsafeCell::new(unsafe { std::mem::zeroed() }),
|
||||
in_use: AtomicBool::new(false),
|
||||
data,
|
||||
_marker: PhantomPinned,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IoStatusBlock<T> {
|
||||
pub(super) fn iosb(self: Pin<&Self>) -> &UnsafeCell<IO_STATUS_BLOCK> {
|
||||
self.project_ref().iosb
|
||||
}
|
||||
|
||||
pub(super) fn data(self: Pin<&Self>) -> Pin<&T> {
|
||||
self.project_ref().data
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: HasAfdInfo> HasAfdInfo for IoStatusBlock<T> {
|
||||
fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell<AfdPollInfo>> {
|
||||
self.project_ref().data.afd_info()
|
||||
}
|
||||
}
|
||||
|
||||
/// Can be transmuted to an I/O status block.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// A pointer to `T` must be able to be converted to a pointer to `IO_STATUS_BLOCK`
|
||||
/// without any issues.
|
||||
pub(super) unsafe trait AsIoStatusBlock {}
|
||||
|
||||
unsafe impl<T> AsIoStatusBlock for IoStatusBlock<T> {}
|
||||
unsafe impl<T> Completion for IoStatusBlock<T> {
|
||||
fn try_lock(self: Pin<&Self>) -> bool {
|
||||
!self.in_use.swap(true, Ordering::SeqCst)
|
||||
}
|
||||
|
||||
unsafe fn unlock(self: Pin<&Self>) {
|
||||
self.in_use.store(false, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the base socket associated with a socket.
|
||||
pub(super) fn base_socket(sock: RawSocket) -> io::Result<RawSocket> {
|
||||
// First, try the SIO_BASE_HANDLE ioctl.
|
||||
let result = unsafe { try_socket_ioctl(sock, SIO_BASE_HANDLE) };
|
||||
|
||||
match result {
|
||||
Ok(sock) => return Ok(sock),
|
||||
Err(e) if e.kind() == io::ErrorKind::InvalidInput => return Err(e),
|
||||
Err(_) => {}
|
||||
}
|
||||
|
||||
// Some poorly coded LSPs may not handle SIO_BASE_HANDLE properly, but in some cases may
|
||||
// handle SIO_BSP_HANDLE_POLL better. Try that.
|
||||
let result = unsafe { try_socket_ioctl(sock, SIO_BSP_HANDLE_POLL)? };
|
||||
if result == sock {
|
||||
return Err(io::Error::from(io::ErrorKind::InvalidInput));
|
||||
}
|
||||
|
||||
// Try `SIO_BASE_HANDLE` again, in case the LSP fixed itself.
|
||||
unsafe { try_socket_ioctl(result, SIO_BASE_HANDLE) }
|
||||
}
|
||||
|
||||
/// Run an IOCTL on a socket and return a socket.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The `ioctl` parameter must be a valid I/O control that returns a valid socket.
|
||||
unsafe fn try_socket_ioctl(sock: RawSocket, ioctl: u32) -> io::Result<RawSocket> {
|
||||
let mut out = MaybeUninit::<RawSocket>::uninit();
|
||||
let mut bytes = 0u32;
|
||||
|
||||
let result = WSAIoctl(
|
||||
sock as _,
|
||||
ioctl,
|
||||
ptr::null_mut(),
|
||||
0,
|
||||
out.as_mut_ptr().cast(),
|
||||
size_of::<RawSocket>() as u32,
|
||||
&mut bytes,
|
||||
ptr::null_mut(),
|
||||
None,
|
||||
);
|
||||
|
||||
if result == SOCKET_ERROR {
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
|
||||
Ok(out.assume_init())
|
||||
}
|
|
@ -0,0 +1,834 @@
|
|||
//! Bindings to Windows I/O Completion Ports.
|
||||
//!
|
||||
//! I/O Completion Ports is a completion-based API rather than a polling-based API, like
|
||||
//! epoll or kqueue. Therefore, we have to adapt the IOCP API to the crate's API.
|
||||
//!
|
||||
//! WinSock is powered by the Auxillary Function Driver (AFD) subsystem, which can be
|
||||
//! accessed directly by using unstable `ntdll` functions. AFD exposes features that are not
|
||||
//! available through the normal WinSock interface, such as IOCTL_AFD_POLL. This function is
|
||||
//! similar to the exposed `WSAPoll` method. However, once the targeted socket is "ready",
|
||||
//! a completion packet is queued to an I/O completion port.
|
||||
//!
|
||||
//! We take advantage of IOCTL_AFD_POLL to "translate" this crate's polling-based API
|
||||
//! to the one Windows expects. When a device is added to the `Poller`, an IOCTL_AFD_POLL
|
||||
//! operation is started and queued to the IOCP. To modify a currently registered device
|
||||
//! (e.g. with `modify()` or `delete()`), the ongoing POLL is cancelled and then restarted
|
||||
//! with new parameters. Whn the POLL eventually completes, the packet is posted to the IOCP.
|
||||
//! From here it's a simple matter of using `GetQueuedCompletionStatusEx` to read the packets
|
||||
//! from the IOCP and react accordingly. Notifying the poller is trivial, because we can
|
||||
//! simply post a packet to the IOCP to wake it up.
|
||||
//!
|
||||
//! The main disadvantage of this strategy is that it relies on unstable Windows APIs.
|
||||
//! However, as `libuv` (the backing I/O library for Node.JS) relies on the same unstable
|
||||
//! AFD strategy, it is unlikely to be broken without plenty of advanced warning.
|
||||
//!
|
||||
//! Previously, this crate used the `wepoll` library for polling. `wepoll` uses a similar
|
||||
//! AFD-based strategy for polling.
|
||||
|
||||
mod afd;
|
||||
mod port;
|
||||
|
||||
use afd::{base_socket, Afd, AfdPollInfo, AfdPollMask, HasAfdInfo, IoStatusBlock};
|
||||
use port::{IoCompletionPort, OverlappedEntry};
|
||||
use windows_sys::Win32::Foundation::{ERROR_INVALID_HANDLE, ERROR_IO_PENDING, STATUS_CANCELLED};
|
||||
|
||||
use crate::{Event, PollMode};
|
||||
|
||||
use concurrent_queue::ConcurrentQueue;
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
use std::cell::UnsafeCell;
|
||||
use std::collections::hash_map::{Entry, HashMap};
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::marker::PhantomPinned;
|
||||
use std::os::windows::io::{AsRawHandle, RawHandle, RawSocket};
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
#[cfg(not(polling_no_io_safety))]
|
||||
use std::os::windows::io::{AsHandle, BorrowedHandle};
|
||||
|
||||
/// Macro to lock and ignore lock poisoning.
|
||||
macro_rules! lock {
|
||||
($lock_result:expr) => {{
|
||||
$lock_result.unwrap_or_else(|e| e.into_inner())
|
||||
}};
|
||||
}
|
||||
|
||||
/// Interface to I/O completion ports.
|
||||
#[derive(Debug)]
|
||||
pub(super) struct Poller {
|
||||
/// The I/O completion port.
|
||||
port: IoCompletionPort<Packet>,
|
||||
|
||||
/// List of currently active AFD instances.
|
||||
///
|
||||
/// Weak references are kept here so that the AFD handle is automatically dropped
|
||||
/// when the last associated socket is dropped.
|
||||
afd: Mutex<Vec<Weak<Afd<Packet>>>>,
|
||||
|
||||
/// The state of the sources registered with this poller.
|
||||
sources: RwLock<HashMap<RawSocket, Packet>>,
|
||||
|
||||
/// Sockets with pending updates.
|
||||
pending_updates: ConcurrentQueue<Packet>,
|
||||
|
||||
/// Are we currently polling?
|
||||
polling: AtomicBool,
|
||||
|
||||
/// A list of completion packets.
|
||||
packets: Mutex<Vec<OverlappedEntry<Packet>>>,
|
||||
|
||||
/// The packet used to notify the poller.
|
||||
notifier: Packet,
|
||||
}
|
||||
|
||||
unsafe impl Send for Poller {}
|
||||
unsafe impl Sync for Poller {}
|
||||
|
||||
impl Poller {
|
||||
/// Creates a new poller.
|
||||
pub(super) fn new() -> io::Result<Self> {
|
||||
// Make sure AFD is able to be used.
|
||||
if let Err(e) = afd::NtdllImports::force_load() {
|
||||
return Err(crate::unsupported_error(format!(
|
||||
"Failed to initialize unstable Windows functions: {}\nThis usually only happens for old Windows or Wine.",
|
||||
e
|
||||
)));
|
||||
}
|
||||
|
||||
// Create and destroy a single AFD to test if we support it.
|
||||
Afd::<Packet>::new().map_err(|e| crate::unsupported_error(format!(
|
||||
"Failed to initialize \\Device\\Afd: {}\nThis usually only happens for old Windows or Wine.",
|
||||
e,
|
||||
)))?;
|
||||
|
||||
let port = IoCompletionPort::new(0)?;
|
||||
|
||||
log::trace!("new: handle={:?}", &port);
|
||||
|
||||
Ok(Poller {
|
||||
port,
|
||||
afd: Mutex::new(vec![]),
|
||||
sources: RwLock::new(HashMap::new()),
|
||||
pending_updates: ConcurrentQueue::bounded(1024),
|
||||
polling: AtomicBool::new(false),
|
||||
packets: Mutex::new(Vec::with_capacity(1024)),
|
||||
notifier: Arc::pin(
|
||||
PacketInner::Wakeup {
|
||||
_pinned: PhantomPinned,
|
||||
}
|
||||
.into(),
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
/// Whether this poller supports level-triggered events.
|
||||
pub(super) fn supports_level(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Whether this poller supports edge-triggered events.
|
||||
pub(super) fn supports_edge(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Add a new source to the poller.
|
||||
pub(super) fn add(&self, socket: RawSocket, interest: Event, mode: PollMode) -> io::Result<()> {
|
||||
log::trace!(
|
||||
"add: handle={:?}, sock={}, ev={:?}",
|
||||
self.port,
|
||||
socket,
|
||||
interest
|
||||
);
|
||||
|
||||
// We don't support edge-triggered events.
|
||||
if matches!(mode, PollMode::Edge) {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"edge-triggered events are not supported",
|
||||
));
|
||||
}
|
||||
|
||||
// Create a new packet.
|
||||
let socket_state = {
|
||||
let state = SocketState {
|
||||
socket,
|
||||
base_socket: base_socket(socket)?,
|
||||
interest,
|
||||
interest_error: true,
|
||||
afd: self.afd_handle()?,
|
||||
mode,
|
||||
waiting_on_delete: false,
|
||||
status: SocketStatus::Idle,
|
||||
};
|
||||
|
||||
Arc::pin(IoStatusBlock::from(PacketInner::Socket {
|
||||
packet: UnsafeCell::new(AfdPollInfo::default()),
|
||||
socket: Mutex::new(state),
|
||||
}))
|
||||
};
|
||||
|
||||
// Keep track of the source in the poller.
|
||||
{
|
||||
let mut sources = lock!(self.sources.write());
|
||||
|
||||
match sources.entry(socket) {
|
||||
Entry::Vacant(v) => {
|
||||
v.insert(Pin::<Arc<_>>::clone(&socket_state));
|
||||
}
|
||||
|
||||
Entry::Occupied(_) => {
|
||||
return Err(io::Error::from(io::ErrorKind::AlreadyExists));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update the packet.
|
||||
self.update_packet(socket_state)
|
||||
}
|
||||
|
||||
/// Update a source in the poller.
|
||||
pub(super) fn modify(
|
||||
&self,
|
||||
socket: RawSocket,
|
||||
interest: Event,
|
||||
mode: PollMode,
|
||||
) -> io::Result<()> {
|
||||
log::trace!(
|
||||
"modify: handle={:?}, sock={}, ev={:?}",
|
||||
self.port,
|
||||
socket,
|
||||
interest
|
||||
);
|
||||
|
||||
// We don't support edge-triggered events.
|
||||
if matches!(mode, PollMode::Edge) {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"edge-triggered events are not supported",
|
||||
));
|
||||
}
|
||||
|
||||
// Get a reference to the source.
|
||||
let source = {
|
||||
let sources = lock!(self.sources.read());
|
||||
|
||||
sources
|
||||
.get(&socket)
|
||||
.cloned()
|
||||
.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?
|
||||
};
|
||||
|
||||
// Set the new event.
|
||||
if source.as_ref().set_events(interest, mode) {
|
||||
self.update_packet(source)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete a source from the poller.
|
||||
pub(super) fn delete(&self, socket: RawSocket) -> io::Result<()> {
|
||||
log::trace!("remove: handle={:?}, sock={}", self.port, socket);
|
||||
|
||||
// Get a reference to the source.
|
||||
let source = {
|
||||
let mut sources = lock!(self.sources.write());
|
||||
|
||||
match sources.remove(&socket) {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
// If the source has already been removed, then we can just return.
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Indicate to the source that it is being deleted.
|
||||
// This cancels any ongoing AFD_IOCTL_POLL operations.
|
||||
source.begin_delete()
|
||||
}
|
||||
|
||||
/// Wait for events.
|
||||
pub(super) fn wait(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
|
||||
log::trace!("wait: handle={:?}, timeout={:?}", self.port, timeout);
|
||||
|
||||
let deadline = timeout.and_then(|timeout| Instant::now().checked_add(timeout));
|
||||
let mut packets = lock!(self.packets.lock());
|
||||
let mut notified = false;
|
||||
events.packets.clear();
|
||||
|
||||
loop {
|
||||
let mut new_events = 0;
|
||||
|
||||
// Indicate that we are now polling.
|
||||
let was_polling = self.polling.swap(true, Ordering::SeqCst);
|
||||
debug_assert!(!was_polling);
|
||||
|
||||
let guard = CallOnDrop(|| {
|
||||
let was_polling = self.polling.swap(false, Ordering::SeqCst);
|
||||
debug_assert!(was_polling);
|
||||
});
|
||||
|
||||
// Process every entry in the queue before we start polling.
|
||||
self.drain_update_queue(false)?;
|
||||
|
||||
// Get the time to wait for.
|
||||
let timeout = deadline.map(|t| t.saturating_duration_since(Instant::now()));
|
||||
|
||||
// Wait for I/O events.
|
||||
let len = self.port.wait(&mut packets, timeout)?;
|
||||
log::trace!("new events: handle={:?}, len={}", self.port, len);
|
||||
|
||||
// We are no longer polling.
|
||||
drop(guard);
|
||||
|
||||
// Process all of the events.
|
||||
for entry in packets.drain(..) {
|
||||
let packet = entry.into_packet();
|
||||
|
||||
// Feed the event into the packet.
|
||||
match packet.feed_event(self)? {
|
||||
FeedEventResult::NoEvent => {}
|
||||
FeedEventResult::Event(event) => {
|
||||
events.packets.push(event);
|
||||
new_events += 1;
|
||||
}
|
||||
FeedEventResult::Notified => {
|
||||
notified = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Break if there was a notification or at least one event, or if deadline is reached.
|
||||
let timeout_is_empty =
|
||||
timeout.map_or(false, |t| t.as_secs() == 0 && t.subsec_nanos() == 0);
|
||||
if notified || new_events > 0 || timeout_is_empty {
|
||||
break;
|
||||
}
|
||||
|
||||
log::trace!("wait: no events found, re-entering polling loop");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Notify this poller.
|
||||
pub(super) fn notify(&self) -> io::Result<()> {
|
||||
// Push the notify packet into the IOCP.
|
||||
self.port.post(0, 0, self.notifier.clone())
|
||||
}
|
||||
|
||||
/// Run an update on a packet.
|
||||
fn update_packet(&self, mut packet: Packet) -> io::Result<()> {
|
||||
loop {
|
||||
// If we are currently polling, we need to update the packet immediately.
|
||||
if self.polling.load(Ordering::Acquire) {
|
||||
packet.update()?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Try to queue the update.
|
||||
match self.pending_updates.push(packet) {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(p) => packet = p.into_inner(),
|
||||
}
|
||||
|
||||
// If we failed to queue the update, we need to drain the queue first.
|
||||
self.drain_update_queue(true)?;
|
||||
}
|
||||
}
|
||||
|
||||
/// Drain the update queue.
|
||||
fn drain_update_queue(&self, limit: bool) -> io::Result<()> {
|
||||
let max = if limit {
|
||||
self.pending_updates.capacity().unwrap()
|
||||
} else {
|
||||
std::usize::MAX
|
||||
};
|
||||
|
||||
// Only drain the queue's capacity, since this could in theory run forever.
|
||||
core::iter::from_fn(|| self.pending_updates.pop().ok())
|
||||
.take(max)
|
||||
.try_for_each(|packet| packet.update())
|
||||
}
|
||||
|
||||
/// Get a handle to the AFD reference.
|
||||
fn afd_handle(&self) -> io::Result<Arc<Afd<Packet>>> {
|
||||
const AFD_MAX_SIZE: usize = 32;
|
||||
|
||||
// Crawl the list and see if there are any existing AFD instances that we can use.
|
||||
// Remove any unused AFD pointers.
|
||||
let mut afd_handles = lock!(self.afd.lock());
|
||||
let mut i = 0;
|
||||
while i < afd_handles.len() {
|
||||
// Get the reference count of the AFD instance.
|
||||
let refcount = Weak::strong_count(&afd_handles[i]);
|
||||
|
||||
match refcount {
|
||||
0 => {
|
||||
// Prune the AFD pointer if it has no references.
|
||||
afd_handles.swap_remove(i);
|
||||
}
|
||||
|
||||
refcount if refcount >= AFD_MAX_SIZE => {
|
||||
// Skip this one, since it is already at the maximum size.
|
||||
i += 1;
|
||||
}
|
||||
|
||||
_ => {
|
||||
// We can use this AFD instance.
|
||||
match afd_handles[i].upgrade() {
|
||||
Some(afd) => return Ok(afd),
|
||||
None => {
|
||||
// The last socket dropped the AFD before we could acquire it.
|
||||
// Prune the AFD pointer and continue.
|
||||
afd_handles.swap_remove(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No available handles, create a new AFD instance.
|
||||
let afd = Arc::new(Afd::new()?);
|
||||
|
||||
// Register the AFD instance with the I/O completion port.
|
||||
self.port.register(&*afd, true)?;
|
||||
|
||||
// Insert a weak pointer to the AFD instance into the list.
|
||||
afd_handles.push(Arc::downgrade(&afd));
|
||||
|
||||
Ok(afd)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRawHandle for Poller {
|
||||
fn as_raw_handle(&self) -> RawHandle {
|
||||
self.port.as_raw_handle()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(polling_no_io_safety))]
|
||||
impl AsHandle for Poller {
|
||||
fn as_handle(&self) -> BorrowedHandle<'_> {
|
||||
unsafe { BorrowedHandle::borrow_raw(self.as_raw_handle()) }
|
||||
}
|
||||
}
|
||||
|
||||
/// The container for events.
|
||||
pub(super) struct Events {
|
||||
/// List of IOCP packets.
|
||||
packets: Vec<Event>,
|
||||
}
|
||||
|
||||
unsafe impl Send for Events {}
|
||||
|
||||
impl Events {
|
||||
/// Creates an empty list of events.
|
||||
pub(super) fn new() -> Events {
|
||||
Events {
|
||||
packets: Vec::with_capacity(1024),
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterate over I/O events.
|
||||
pub(super) fn iter(&self) -> impl Iterator<Item = Event> + '_ {
|
||||
self.packets.iter().copied()
|
||||
}
|
||||
}
|
||||
|
||||
/// The type of our completion packet.
|
||||
type Packet = Pin<Arc<PacketUnwrapped>>;
|
||||
type PacketUnwrapped = IoStatusBlock<PacketInner>;
|
||||
|
||||
pin_project! {
|
||||
/// The inner type of the packet.
|
||||
#[project_ref = PacketInnerProj]
|
||||
#[project = PacketInnerProjMut]
|
||||
enum PacketInner {
|
||||
// A packet for a socket.
|
||||
Socket {
|
||||
// The AFD packet state.
|
||||
#[pin]
|
||||
packet: UnsafeCell<AfdPollInfo>,
|
||||
|
||||
// The socket state.
|
||||
socket: Mutex<SocketState>
|
||||
},
|
||||
|
||||
// A packet used to wake up the poller.
|
||||
Wakeup { #[pin] _pinned: PhantomPinned },
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for PacketInner {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Wakeup { .. } => f.write_str("Wakeup { .. }"),
|
||||
Self::Socket { socket, .. } => f
|
||||
.debug_struct("Socket")
|
||||
.field("packet", &"..")
|
||||
.field("socket", socket)
|
||||
.finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HasAfdInfo for PacketInner {
|
||||
fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell<AfdPollInfo>> {
|
||||
match self.project_ref() {
|
||||
PacketInnerProj::Socket { packet, .. } => packet,
|
||||
PacketInnerProj::Wakeup { .. } => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PacketUnwrapped {
|
||||
/// Set the new events that this socket is waiting on.
|
||||
///
|
||||
/// Returns `true` if we need to be updated.
|
||||
fn set_events(self: Pin<&Self>, interest: Event, mode: PollMode) -> bool {
|
||||
let mut socket = match self.socket_state() {
|
||||
Some(s) => s,
|
||||
None => return false,
|
||||
};
|
||||
|
||||
socket.interest = interest;
|
||||
socket.mode = mode;
|
||||
socket.interest_error = true;
|
||||
|
||||
match socket.status {
|
||||
SocketStatus::Polling { readable, writable } => {
|
||||
(interest.readable && !readable) || (interest.writable && !writable)
|
||||
}
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the socket and install the new status in AFD.
|
||||
fn update(self: Pin<Arc<Self>>) -> io::Result<()> {
|
||||
let mut socket = match self.as_ref().socket_state() {
|
||||
Some(s) => s,
|
||||
None => return Err(io::Error::new(io::ErrorKind::Other, "invalid socket state")),
|
||||
};
|
||||
|
||||
// If we are waiting on a delete, just return, dropping the packet.
|
||||
if socket.waiting_on_delete {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Check the current status.
|
||||
match socket.status {
|
||||
SocketStatus::Polling { readable, writable } => {
|
||||
// If we need to poll for events aside from what we are currently polling, we need
|
||||
// to update the packet. Cancel the ongoing poll.
|
||||
if (socket.interest.readable && !readable)
|
||||
|| (socket.interest.writable && !writable)
|
||||
{
|
||||
return self.cancel(socket);
|
||||
}
|
||||
|
||||
// All events that we are currently waiting on are accounted for.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
SocketStatus::Cancelled => {
|
||||
// The ongoing operation was cancelled, and we're still waiting for it to return.
|
||||
// For now, wait until the top-level loop calls feed_event().
|
||||
Ok(())
|
||||
}
|
||||
|
||||
SocketStatus::Idle => {
|
||||
// Start a new poll.
|
||||
let result = socket.afd.poll(
|
||||
self.clone(),
|
||||
socket.base_socket,
|
||||
event_to_afd_mask(
|
||||
socket.interest.readable,
|
||||
socket.interest.writable,
|
||||
socket.interest_error,
|
||||
),
|
||||
);
|
||||
|
||||
match result {
|
||||
Ok(()) => {}
|
||||
|
||||
Err(err)
|
||||
if err.raw_os_error() == Some(ERROR_IO_PENDING as i32)
|
||||
|| err.kind() == io::ErrorKind::WouldBlock =>
|
||||
{
|
||||
// The operation is pending.
|
||||
}
|
||||
|
||||
Err(err) if err.raw_os_error() == Some(ERROR_INVALID_HANDLE as i32) => {
|
||||
// The socket was closed. We need to delete it.
|
||||
// This should happen after we drop it here.
|
||||
}
|
||||
|
||||
Err(err) => return Err(err),
|
||||
}
|
||||
|
||||
// We are now polling for the current events.
|
||||
socket.status = SocketStatus::Polling {
|
||||
readable: socket.interest.readable,
|
||||
writable: socket.interest.writable,
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This socket state was notified; see if we need to update it.
|
||||
fn feed_event(self: Pin<Arc<Self>>, poller: &Poller) -> io::Result<FeedEventResult> {
|
||||
let inner = self.as_ref().data().project_ref();
|
||||
|
||||
let (afd_info, socket) = match inner {
|
||||
PacketInnerProj::Socket { packet, socket } => (packet, socket),
|
||||
PacketInnerProj::Wakeup { .. } => {
|
||||
// The poller was notified.
|
||||
return Ok(FeedEventResult::Notified);
|
||||
}
|
||||
};
|
||||
|
||||
let mut socket_state = lock!(socket.lock());
|
||||
let mut event = Event::none(socket_state.interest.key);
|
||||
|
||||
// Put ourselves into the idle state.
|
||||
socket_state.status = SocketStatus::Idle;
|
||||
|
||||
// If we are waiting to be deleted, just return and let the drop handler do their thing.
|
||||
if socket_state.waiting_on_delete {
|
||||
return Ok(FeedEventResult::NoEvent);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
// SAFETY: The packet is not in transit.
|
||||
let iosb = &mut *self.as_ref().iosb().get();
|
||||
|
||||
// Check the status.
|
||||
match iosb.Anonymous.Status {
|
||||
STATUS_CANCELLED => {
|
||||
// Poll request was cancelled.
|
||||
}
|
||||
|
||||
status if status < 0 => {
|
||||
// There was an error, so we signal both ends.
|
||||
event.readable = true;
|
||||
event.writable = true;
|
||||
}
|
||||
|
||||
_ => {
|
||||
// Check in on the AFD data.
|
||||
let afd_data = &*afd_info.get();
|
||||
|
||||
if afd_data.handle_count() >= 1 {
|
||||
let events = afd_data.events();
|
||||
|
||||
// If we closed the socket, remove it from being polled.
|
||||
if events.contains(AfdPollMask::LOCAL_CLOSE) {
|
||||
let source = lock!(poller.sources.write())
|
||||
.remove(&socket_state.socket)
|
||||
.unwrap();
|
||||
return source.begin_delete().map(|()| FeedEventResult::NoEvent);
|
||||
}
|
||||
|
||||
// Report socket-related events.
|
||||
let (readable, writable) = afd_mask_to_event(events);
|
||||
event.readable = readable;
|
||||
event.writable = writable;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter out events that the user didn't ask for.
|
||||
event.readable &= socket_state.interest.readable;
|
||||
event.writable &= socket_state.interest.writable;
|
||||
|
||||
// If this event doesn't have anything that interests us, don't return or
|
||||
// update the oneshot state.
|
||||
let return_value = if event.readable || event.writable {
|
||||
// If we are in oneshot mode, remove the interest.
|
||||
if matches!(socket_state.mode, PollMode::Oneshot) {
|
||||
socket_state.interest = Event::none(socket_state.interest.key);
|
||||
socket_state.interest_error = false;
|
||||
}
|
||||
|
||||
FeedEventResult::Event(event)
|
||||
} else {
|
||||
FeedEventResult::NoEvent
|
||||
};
|
||||
|
||||
// Put ourselves in the update queue.
|
||||
drop(socket_state);
|
||||
poller.update_packet(self)?;
|
||||
|
||||
// Return the event.
|
||||
Ok(return_value)
|
||||
}
|
||||
|
||||
/// Begin deleting this socket.
|
||||
fn begin_delete(self: Pin<Arc<Self>>) -> io::Result<()> {
|
||||
// If we aren't already being deleted, start deleting.
|
||||
let mut socket = self
|
||||
.as_ref()
|
||||
.socket_state()
|
||||
.expect("can't delete packet that doesn't belong to a socket");
|
||||
if !socket.waiting_on_delete {
|
||||
socket.waiting_on_delete = true;
|
||||
|
||||
if matches!(socket.status, SocketStatus::Polling { .. }) {
|
||||
// Cancel the ongoing poll.
|
||||
self.cancel(socket)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Either drop it now or wait for it to be dropped later.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cancel(self: &Pin<Arc<Self>>, mut socket: MutexGuard<'_, SocketState>) -> io::Result<()> {
|
||||
assert!(matches!(socket.status, SocketStatus::Polling { .. }));
|
||||
|
||||
// Send the cancel request.
|
||||
unsafe {
|
||||
socket.afd.cancel(self)?;
|
||||
}
|
||||
|
||||
// Move state to cancelled.
|
||||
socket.status = SocketStatus::Cancelled;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn socket_state(self: Pin<&Self>) -> Option<MutexGuard<'_, SocketState>> {
|
||||
let inner = self.data().project_ref();
|
||||
|
||||
let state = match inner {
|
||||
PacketInnerProj::Wakeup { .. } => return None,
|
||||
PacketInnerProj::Socket { socket, .. } => socket,
|
||||
};
|
||||
|
||||
Some(lock!(state.lock()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-socket state.
|
||||
#[derive(Debug)]
|
||||
struct SocketState {
|
||||
/// The raw socket handle.
|
||||
socket: RawSocket,
|
||||
|
||||
/// The base socket handle.
|
||||
base_socket: RawSocket,
|
||||
|
||||
/// The event that this socket is currently waiting on.
|
||||
interest: Event,
|
||||
|
||||
/// Whether to listen for error events.
|
||||
interest_error: bool,
|
||||
|
||||
/// The current poll mode.
|
||||
mode: PollMode,
|
||||
|
||||
/// The AFD instance that this socket is registered with.
|
||||
afd: Arc<Afd<Packet>>,
|
||||
|
||||
/// Whether this socket is waiting to be deleted.
|
||||
waiting_on_delete: bool,
|
||||
|
||||
/// The current status of the socket.
|
||||
status: SocketStatus,
|
||||
}
|
||||
|
||||
/// The mode that a socket can be in.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
enum SocketStatus {
|
||||
/// We are currently not polling.
|
||||
Idle,
|
||||
|
||||
/// We are currently polling these events.
|
||||
Polling {
|
||||
/// We are currently polling for readable events.
|
||||
readable: bool,
|
||||
|
||||
/// We are currently polling for writable events.
|
||||
writable: bool,
|
||||
},
|
||||
|
||||
/// The last poll operation was cancelled, and we're waiting for it to
|
||||
/// complete.
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
/// The result of calling `feed_event`.
|
||||
#[derive(Debug)]
|
||||
enum FeedEventResult {
|
||||
/// No event was yielded.
|
||||
NoEvent,
|
||||
|
||||
/// An event was yielded.
|
||||
Event(Event),
|
||||
|
||||
/// The poller has been notified.
|
||||
Notified,
|
||||
}
|
||||
|
||||
fn event_to_afd_mask(readable: bool, writable: bool, error: bool) -> afd::AfdPollMask {
|
||||
use afd::AfdPollMask as AfdPoll;
|
||||
|
||||
let mut mask = AfdPoll::empty();
|
||||
|
||||
if error || readable || writable {
|
||||
mask |= AfdPoll::ABORT | AfdPoll::CONNECT_FAIL;
|
||||
}
|
||||
|
||||
if readable {
|
||||
mask |=
|
||||
AfdPoll::RECEIVE | AfdPoll::ACCEPT | AfdPoll::DISCONNECT | AfdPoll::RECEIVE_EXPEDITED;
|
||||
}
|
||||
|
||||
if writable {
|
||||
mask |= AfdPoll::SEND;
|
||||
}
|
||||
|
||||
mask
|
||||
}
|
||||
|
||||
fn afd_mask_to_event(mask: afd::AfdPollMask) -> (bool, bool) {
|
||||
use afd::AfdPollMask as AfdPoll;
|
||||
|
||||
let mut readable = false;
|
||||
let mut writable = false;
|
||||
|
||||
if mask.intersects(
|
||||
AfdPoll::RECEIVE | AfdPoll::ACCEPT | AfdPoll::DISCONNECT | AfdPoll::RECEIVE_EXPEDITED,
|
||||
) {
|
||||
readable = true;
|
||||
}
|
||||
|
||||
if mask.intersects(AfdPoll::SEND) {
|
||||
writable = true;
|
||||
}
|
||||
|
||||
if mask.intersects(AfdPoll::ABORT | AfdPoll::CONNECT_FAIL) {
|
||||
readable = true;
|
||||
writable = true;
|
||||
}
|
||||
|
||||
(readable, writable)
|
||||
}
|
||||
|
||||
struct CallOnDrop<F: FnMut()>(F);
|
||||
|
||||
impl<F: FnMut()> Drop for CallOnDrop<F> {
|
||||
fn drop(&mut self) {
|
||||
(self.0)();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,327 @@
|
|||
//! A safe wrapper around the Windows I/O API.
|
||||
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::marker::PhantomData;
|
||||
use std::mem::MaybeUninit;
|
||||
use std::ops::Deref;
|
||||
use std::os::windows::io::{AsRawHandle, RawHandle};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use windows_sys::Win32::Foundation::{CloseHandle, HANDLE, INVALID_HANDLE_VALUE};
|
||||
use windows_sys::Win32::Storage::FileSystem::SetFileCompletionNotificationModes;
|
||||
use windows_sys::Win32::System::WindowsProgramming::{FILE_SKIP_SET_EVENT_ON_HANDLE, INFINITE};
|
||||
use windows_sys::Win32::System::IO::{
|
||||
CreateIoCompletionPort, GetQueuedCompletionStatusEx, PostQueuedCompletionStatus, OVERLAPPED,
|
||||
OVERLAPPED_ENTRY,
|
||||
};
|
||||
|
||||
/// A completion block which can be used with I/O completion ports.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// This must be a valid completion block.
|
||||
pub(super) unsafe trait Completion {
|
||||
/// Signal to the completion block that we are about to start an operation.
|
||||
fn try_lock(self: Pin<&Self>) -> bool;
|
||||
|
||||
/// Unlock the completion block.
|
||||
unsafe fn unlock(self: Pin<&Self>);
|
||||
}
|
||||
|
||||
/// The pointer to a completion block.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// This must be a valid completion block.
|
||||
pub(super) unsafe trait CompletionHandle: Deref + Sized {
|
||||
/// Type of the completion block.
|
||||
type Completion: Completion;
|
||||
|
||||
/// Get a pointer to the completion block.
|
||||
///
|
||||
/// The pointer is pinned since the underlying object should not be moved
|
||||
/// after creation. This prevents it from being invalidated while it's
|
||||
/// used in an overlapped operation.
|
||||
fn get(&self) -> Pin<&Self::Completion>;
|
||||
|
||||
/// Convert this block into a pointer that can be passed as `*mut OVERLAPPED`.
|
||||
fn into_ptr(this: Self) -> *mut OVERLAPPED;
|
||||
|
||||
/// Convert a pointer that was passed as `*mut OVERLAPPED` into a pointer to this block.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// This must be a valid pointer to a completion block.
|
||||
unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self;
|
||||
|
||||
/// Convert to a pointer without losing ownership.
|
||||
fn as_ptr(&self) -> *mut OVERLAPPED;
|
||||
}
|
||||
|
||||
unsafe impl<'a, T: Completion> CompletionHandle for Pin<&'a T> {
|
||||
type Completion = T;
|
||||
|
||||
fn get(&self) -> Pin<&Self::Completion> {
|
||||
*self
|
||||
}
|
||||
|
||||
fn into_ptr(this: Self) -> *mut OVERLAPPED {
|
||||
unsafe { Pin::into_inner_unchecked(this) as *const T as *mut OVERLAPPED }
|
||||
}
|
||||
|
||||
unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self {
|
||||
Pin::new_unchecked(&*(ptr as *const T))
|
||||
}
|
||||
|
||||
fn as_ptr(&self) -> *mut OVERLAPPED {
|
||||
self.get_ref() as *const T as *mut OVERLAPPED
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<T: Completion> CompletionHandle for Pin<Arc<T>> {
|
||||
type Completion = T;
|
||||
|
||||
fn get(&self) -> Pin<&Self::Completion> {
|
||||
self.as_ref()
|
||||
}
|
||||
|
||||
fn into_ptr(this: Self) -> *mut OVERLAPPED {
|
||||
unsafe { Arc::into_raw(Pin::into_inner_unchecked(this)) as *const T as *mut OVERLAPPED }
|
||||
}
|
||||
|
||||
unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self {
|
||||
Pin::new_unchecked(Arc::from_raw(ptr as *const T))
|
||||
}
|
||||
|
||||
fn as_ptr(&self) -> *mut OVERLAPPED {
|
||||
self.as_ref().get_ref() as *const T as *mut OVERLAPPED
|
||||
}
|
||||
}
|
||||
|
||||
/// A handle to the I/O completion port.
|
||||
pub(super) struct IoCompletionPort<T> {
|
||||
/// The underlying handle.
|
||||
handle: HANDLE,
|
||||
|
||||
/// We own the status block.
|
||||
_marker: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> Drop for IoCompletionPort<T> {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
CloseHandle(self.handle);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AsRawHandle for IoCompletionPort<T> {
|
||||
fn as_raw_handle(&self) -> RawHandle {
|
||||
self.handle as _
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> fmt::Debug for IoCompletionPort<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
struct WriteAsHex(HANDLE);
|
||||
|
||||
impl fmt::Debug for WriteAsHex {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{:010x}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
f.debug_struct("IoCompletionPort")
|
||||
.field("handle", &WriteAsHex(self.handle))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: CompletionHandle> IoCompletionPort<T> {
|
||||
/// Create a new I/O completion port.
|
||||
pub(super) fn new(threads: usize) -> io::Result<Self> {
|
||||
let handle = unsafe {
|
||||
CreateIoCompletionPort(
|
||||
INVALID_HANDLE_VALUE,
|
||||
0,
|
||||
0,
|
||||
threads.try_into().expect("too many threads"),
|
||||
)
|
||||
};
|
||||
|
||||
if handle == 0 {
|
||||
Err(io::Error::last_os_error())
|
||||
} else {
|
||||
Ok(Self {
|
||||
handle,
|
||||
_marker: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a handle with this I/O completion port.
|
||||
pub(super) fn register(
|
||||
&self,
|
||||
handle: &impl AsRawHandle, // TODO change to AsHandle
|
||||
skip_set_event_on_handle: bool,
|
||||
) -> io::Result<()> {
|
||||
let handle = handle.as_raw_handle();
|
||||
|
||||
let result =
|
||||
unsafe { CreateIoCompletionPort(handle as _, self.handle, handle as usize, 0) };
|
||||
|
||||
if result == 0 {
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
|
||||
if skip_set_event_on_handle {
|
||||
// Set the skip event on handle.
|
||||
let result = unsafe {
|
||||
SetFileCompletionNotificationModes(handle as _, FILE_SKIP_SET_EVENT_ON_HANDLE as _)
|
||||
};
|
||||
|
||||
if result == 0 {
|
||||
return Err(io::Error::last_os_error());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Post a completion packet to this port.
|
||||
pub(super) fn post(&self, bytes_transferred: usize, id: usize, packet: T) -> io::Result<()> {
|
||||
let result = unsafe {
|
||||
PostQueuedCompletionStatus(
|
||||
self.handle,
|
||||
bytes_transferred
|
||||
.try_into()
|
||||
.expect("too many bytes transferred"),
|
||||
id,
|
||||
T::into_ptr(packet),
|
||||
)
|
||||
};
|
||||
|
||||
if result == 0 {
|
||||
Err(io::Error::last_os_error())
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait for completion packets to arrive.
|
||||
pub(super) fn wait(
|
||||
&self,
|
||||
packets: &mut Vec<OverlappedEntry<T>>,
|
||||
timeout: Option<Duration>,
|
||||
) -> io::Result<usize> {
|
||||
// Drop the current packets.
|
||||
packets.clear();
|
||||
|
||||
let mut count = MaybeUninit::<u32>::uninit();
|
||||
let timeout = timeout.map_or(INFINITE, dur2timeout);
|
||||
|
||||
let result = unsafe {
|
||||
GetQueuedCompletionStatusEx(
|
||||
self.handle,
|
||||
packets.as_mut_ptr() as _,
|
||||
packets.capacity().try_into().expect("too many packets"),
|
||||
count.as_mut_ptr(),
|
||||
timeout,
|
||||
0,
|
||||
)
|
||||
};
|
||||
|
||||
if result == 0 {
|
||||
let io_error = io::Error::last_os_error();
|
||||
if io_error.kind() == io::ErrorKind::TimedOut {
|
||||
Ok(0)
|
||||
} else {
|
||||
Err(io_error)
|
||||
}
|
||||
} else {
|
||||
let count = unsafe { count.assume_init() };
|
||||
unsafe {
|
||||
packets.set_len(count as _);
|
||||
}
|
||||
Ok(count as _)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An `OVERLAPPED_ENTRY` resulting from an I/O completion port.
|
||||
#[repr(transparent)]
|
||||
pub(super) struct OverlappedEntry<T: CompletionHandle> {
|
||||
/// The underlying entry.
|
||||
entry: OVERLAPPED_ENTRY,
|
||||
|
||||
/// We own the status block.
|
||||
_marker: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: CompletionHandle> fmt::Debug for OverlappedEntry<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("OverlappedEntry { .. }")
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: CompletionHandle> OverlappedEntry<T> {
|
||||
/// Convert into the completion packet.
|
||||
pub(super) fn into_packet(self) -> T {
|
||||
let packet = unsafe { self.packet() };
|
||||
std::mem::forget(self);
|
||||
packet
|
||||
}
|
||||
|
||||
/// Get the packet reference that this entry refers to.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// This function should only be called once, since it moves
|
||||
/// out the `T` from the `OVERLAPPED_ENTRY`.
|
||||
unsafe fn packet(&self) -> T {
|
||||
let packet = T::from_ptr(self.entry.lpOverlapped);
|
||||
packet.get().unlock();
|
||||
packet
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: CompletionHandle> Drop for OverlappedEntry<T> {
|
||||
fn drop(&mut self) {
|
||||
drop(unsafe { self.packet() });
|
||||
}
|
||||
}
|
||||
|
||||
// Implementation taken from https://github.com/rust-lang/rust/blob/db5476571d9b27c862b95c1e64764b0ac8980e23/src/libstd/sys/windows/mod.rs
|
||||
fn dur2timeout(dur: Duration) -> u32 {
|
||||
// Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the
|
||||
// timeouts in windows APIs are typically u32 milliseconds. To translate, we
|
||||
// have two pieces to take care of:
|
||||
//
|
||||
// * Nanosecond precision is rounded up
|
||||
// * Greater than u32::MAX milliseconds (50 days) is rounded up to INFINITE
|
||||
// (never time out).
|
||||
dur.as_secs()
|
||||
.checked_mul(1000)
|
||||
.and_then(|ms| ms.checked_add((dur.subsec_nanos() as u64) / 1_000_000))
|
||||
.and_then(|ms| {
|
||||
if dur.subsec_nanos() % 1_000_000 > 0 {
|
||||
ms.checked_add(1)
|
||||
} else {
|
||||
Some(ms)
|
||||
}
|
||||
})
|
||||
.and_then(|x| u32::try_from(x).ok())
|
||||
.unwrap_or(INFINITE)
|
||||
}
|
||||
|
||||
struct CallOnDrop<F: FnMut()>(F);
|
||||
|
||||
impl<F: FnMut()> Drop for CallOnDrop<F> {
|
||||
fn drop(&mut self) {
|
||||
(self.0)();
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
//! Portable interface to epoll, kqueue, event ports, and wepoll.
|
||||
//! Portable interface to epoll, kqueue, event ports, and IOCP.
|
||||
//!
|
||||
//! Supported platforms:
|
||||
//! - [epoll](https://en.wikipedia.org/wiki/Epoll): Linux, Android
|
||||
|
@ -6,7 +6,7 @@
|
|||
//! DragonFly BSD
|
||||
//! - [event ports](https://illumos.org/man/port_create): illumos, Solaris
|
||||
//! - [poll](https://en.wikipedia.org/wiki/Poll_(Unix)): VxWorks, Fuchsia, other Unix systems
|
||||
//! - [wepoll](https://github.com/piscisaureus/wepoll): Windows, Wine (version 7.13+)
|
||||
//! - [IOCP](https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports): Windows, Wine (version 7.13+)
|
||||
//!
|
||||
//! By default, polling is done in oneshot mode, which means interest in I/O events needs to
|
||||
//! be re-enabled after an event is delivered if we're interested in the next event of the same
|
||||
|
@ -113,8 +113,8 @@ cfg_if! {
|
|||
mod poll;
|
||||
use poll as sys;
|
||||
} else if #[cfg(target_os = "windows")] {
|
||||
mod wepoll;
|
||||
use wepoll as sys;
|
||||
mod iocp;
|
||||
use iocp as sys;
|
||||
} else {
|
||||
compile_error!("polling does not support this target OS");
|
||||
}
|
||||
|
|
254
src/wepoll.rs
254
src/wepoll.rs
|
@ -1,254 +0,0 @@
|
|||
//! Bindings to wepoll (Windows).
|
||||
|
||||
use std::convert::TryInto;
|
||||
use std::io;
|
||||
use std::os::raw::c_int;
|
||||
use std::os::windows::io::{AsRawHandle, RawHandle, RawSocket};
|
||||
use std::ptr;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
#[cfg(not(polling_no_io_safety))]
|
||||
use std::os::windows::io::{AsHandle, BorrowedHandle};
|
||||
|
||||
use wepoll_ffi as we;
|
||||
|
||||
use crate::{Event, PollMode};
|
||||
|
||||
/// Calls a wepoll function and results in `io::Result`.
|
||||
macro_rules! wepoll {
|
||||
($fn:ident $args:tt) => {{
|
||||
let res = unsafe { we::$fn $args };
|
||||
if res == -1 {
|
||||
Err(std::io::Error::last_os_error())
|
||||
} else {
|
||||
Ok(res)
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
/// Interface to wepoll.
|
||||
#[derive(Debug)]
|
||||
pub struct Poller {
|
||||
handle: we::HANDLE,
|
||||
notified: AtomicBool,
|
||||
}
|
||||
|
||||
unsafe impl Send for Poller {}
|
||||
unsafe impl Sync for Poller {}
|
||||
|
||||
impl Poller {
|
||||
/// Creates a new poller.
|
||||
pub fn new() -> io::Result<Poller> {
|
||||
let handle = unsafe { we::epoll_create1(0) };
|
||||
if handle.is_null() {
|
||||
return Err(crate::unsupported_error(
|
||||
format!(
|
||||
"Failed to initialize Wepoll: {}\nThis usually only happens for old Windows or Wine.",
|
||||
io::Error::last_os_error()
|
||||
)
|
||||
));
|
||||
}
|
||||
let notified = AtomicBool::new(false);
|
||||
log::trace!("new: handle={:?}", handle);
|
||||
Ok(Poller { handle, notified })
|
||||
}
|
||||
|
||||
/// Whether this poller supports level-triggered events.
|
||||
pub fn supports_level(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Whether this poller supports edge-triggered events.
|
||||
pub fn supports_edge(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Adds a socket.
|
||||
pub fn add(&self, sock: RawSocket, ev: Event, mode: PollMode) -> io::Result<()> {
|
||||
log::trace!("add: handle={:?}, sock={}, ev={:?}", self.handle, sock, ev);
|
||||
self.ctl(we::EPOLL_CTL_ADD, sock, Some((ev, mode)))
|
||||
}
|
||||
|
||||
/// Modifies a socket.
|
||||
pub fn modify(&self, sock: RawSocket, ev: Event, mode: PollMode) -> io::Result<()> {
|
||||
log::trace!(
|
||||
"modify: handle={:?}, sock={}, ev={:?}",
|
||||
self.handle,
|
||||
sock,
|
||||
ev
|
||||
);
|
||||
self.ctl(we::EPOLL_CTL_MOD, sock, Some((ev, mode)))
|
||||
}
|
||||
|
||||
/// Deletes a socket.
|
||||
pub fn delete(&self, sock: RawSocket) -> io::Result<()> {
|
||||
log::trace!("remove: handle={:?}, sock={}", self.handle, sock);
|
||||
self.ctl(we::EPOLL_CTL_DEL, sock, None)
|
||||
}
|
||||
|
||||
/// Waits for I/O events with an optional timeout.
|
||||
///
|
||||
/// Returns the number of processed I/O events.
|
||||
///
|
||||
/// If a notification occurs, this method will return but the notification event will not be
|
||||
/// included in the `events` list nor contribute to the returned count.
|
||||
pub fn wait(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
|
||||
log::trace!("wait: handle={:?}, timeout={:?}", self.handle, timeout);
|
||||
let deadline = timeout.and_then(|t| Instant::now().checked_add(t));
|
||||
|
||||
loop {
|
||||
// Convert the timeout to milliseconds.
|
||||
let timeout_ms = match deadline.map(|d| d.saturating_duration_since(Instant::now())) {
|
||||
None => -1,
|
||||
Some(t) => {
|
||||
// Round up to a whole millisecond.
|
||||
let mut ms = t.as_millis().try_into().unwrap_or(std::u64::MAX);
|
||||
if Duration::from_millis(ms) < t {
|
||||
ms = ms.saturating_add(1);
|
||||
}
|
||||
ms.try_into().unwrap_or(std::i32::MAX)
|
||||
}
|
||||
};
|
||||
|
||||
// Wait for I/O events.
|
||||
events.len = wepoll!(epoll_wait(
|
||||
self.handle,
|
||||
events.list.as_mut_ptr(),
|
||||
events.list.len() as c_int,
|
||||
timeout_ms,
|
||||
))? as usize;
|
||||
log::trace!("new events: handle={:?}, len={}", self.handle, events.len);
|
||||
|
||||
// Break if there was a notification or at least one event, or if deadline is reached.
|
||||
if self.notified.swap(false, Ordering::SeqCst) || events.len > 0 || timeout_ms == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sends a notification to wake up the current or next `wait()` call.
|
||||
pub fn notify(&self) -> io::Result<()> {
|
||||
log::trace!("notify: handle={:?}", self.handle);
|
||||
|
||||
if self
|
||||
.notified
|
||||
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
|
||||
.is_ok()
|
||||
{
|
||||
unsafe {
|
||||
// This call errors if a notification has already been posted, but that's okay - we
|
||||
// can just ignore the error.
|
||||
//
|
||||
// The original wepoll does not support notifications triggered this way, which is
|
||||
// why wepoll-sys includes a small patch to support them.
|
||||
windows_sys::Win32::System::IO::PostQueuedCompletionStatus(
|
||||
self.handle as _,
|
||||
0,
|
||||
0,
|
||||
ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Passes arguments to `epoll_ctl`.
|
||||
fn ctl(&self, op: u32, sock: RawSocket, ev: Option<(Event, PollMode)>) -> io::Result<()> {
|
||||
let mut ev = ev
|
||||
.map(|(ev, mode)| {
|
||||
let mut flags = match mode {
|
||||
PollMode::Level => 0,
|
||||
PollMode::Oneshot => we::EPOLLONESHOT,
|
||||
PollMode::Edge => {
|
||||
return Err(crate::unsupported_error(
|
||||
"edge-triggered events are not supported with wepoll",
|
||||
));
|
||||
}
|
||||
};
|
||||
if ev.readable {
|
||||
flags |= READ_FLAGS;
|
||||
}
|
||||
if ev.writable {
|
||||
flags |= WRITE_FLAGS;
|
||||
}
|
||||
|
||||
Ok(we::epoll_event {
|
||||
events: flags as u32,
|
||||
data: we::epoll_data {
|
||||
u64_: ev.key as u64,
|
||||
},
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
wepoll!(epoll_ctl(
|
||||
self.handle,
|
||||
op as c_int,
|
||||
sock as we::SOCKET,
|
||||
ev.as_mut()
|
||||
.map(|ev| ev as *mut we::epoll_event)
|
||||
.unwrap_or(ptr::null_mut()),
|
||||
))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRawHandle for Poller {
|
||||
fn as_raw_handle(&self) -> RawHandle {
|
||||
self.handle as RawHandle
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(polling_no_io_safety))]
|
||||
impl AsHandle for Poller {
|
||||
fn as_handle(&self) -> BorrowedHandle<'_> {
|
||||
// SAFETY: lifetime is bound by "self"
|
||||
unsafe { BorrowedHandle::borrow_raw(self.as_raw_handle()) }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Poller {
|
||||
fn drop(&mut self) {
|
||||
log::trace!("drop: handle={:?}", self.handle);
|
||||
unsafe {
|
||||
we::epoll_close(self.handle);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wepoll flags for all possible readability events.
|
||||
const READ_FLAGS: u32 = we::EPOLLIN | we::EPOLLRDHUP | we::EPOLLHUP | we::EPOLLERR | we::EPOLLPRI;
|
||||
|
||||
/// Wepoll flags for all possible writability events.
|
||||
const WRITE_FLAGS: u32 = we::EPOLLOUT | we::EPOLLHUP | we::EPOLLERR;
|
||||
|
||||
/// A list of reported I/O events.
|
||||
pub struct Events {
|
||||
list: Box<[we::epoll_event; 1024]>,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
unsafe impl Send for Events {}
|
||||
|
||||
impl Events {
|
||||
/// Creates an empty list.
|
||||
pub fn new() -> Events {
|
||||
let ev = we::epoll_event {
|
||||
events: 0,
|
||||
data: we::epoll_data { u64_: 0 },
|
||||
};
|
||||
let list = Box::new([ev; 1024]);
|
||||
Events { list, len: 0 }
|
||||
}
|
||||
|
||||
/// Iterates over I/O events.
|
||||
pub fn iter(&self) -> impl Iterator<Item = Event> + '_ {
|
||||
self.list[..self.len].iter().map(|ev| Event {
|
||||
key: unsafe { ev.data.u64_ } as usize,
|
||||
readable: (ev.events & READ_FLAGS) != 0,
|
||||
writable: (ev.events & WRITE_FLAGS) != 0,
|
||||
})
|
||||
}
|
||||
}
|
|
@ -43,7 +43,7 @@ fn concurrent_modify() -> io::Result<()> {
|
|||
|
||||
Parallel::new()
|
||||
.add(|| {
|
||||
poller.wait(&mut events, None)?;
|
||||
poller.wait(&mut events, Some(Duration::from_secs(10)))?;
|
||||
Ok(())
|
||||
})
|
||||
.add(|| {
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
use polling::{Event, Poller};
|
||||
use std::io::{self, Write};
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::time::Duration;
|
||||
|
||||
#[test]
|
||||
fn basic_io() {
|
||||
let poller = Poller::new().unwrap();
|
||||
let (read, mut write) = tcp_pair().unwrap();
|
||||
poller.add(&read, Event::readable(1)).unwrap();
|
||||
|
||||
// Nothing should be available at first.
|
||||
let mut events = vec![];
|
||||
assert_eq!(
|
||||
poller
|
||||
.wait(&mut events, Some(Duration::from_secs(0)))
|
||||
.unwrap(),
|
||||
0
|
||||
);
|
||||
assert!(events.is_empty());
|
||||
|
||||
// After a write, the event should be available now.
|
||||
write.write_all(&[1]).unwrap();
|
||||
assert_eq!(
|
||||
poller
|
||||
.wait(&mut events, Some(Duration::from_secs(1)))
|
||||
.unwrap(),
|
||||
1
|
||||
);
|
||||
assert_eq!(&*events, &[Event::readable(1)]);
|
||||
}
|
||||
|
||||
fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")?;
|
||||
let a = TcpStream::connect(listener.local_addr()?)?;
|
||||
let (b, _) = listener.accept()?;
|
||||
Ok((a, b))
|
||||
}
|
|
@ -18,7 +18,7 @@ fn below_ms() -> io::Result<()> {
|
|||
let elapsed = now.elapsed();
|
||||
|
||||
assert_eq!(n, 0);
|
||||
assert!(elapsed >= dur);
|
||||
assert!(elapsed >= dur, "{:?} < {:?}", elapsed, dur);
|
||||
lowest = lowest.min(elapsed);
|
||||
}
|
||||
|
||||
|
@ -54,7 +54,7 @@ fn above_ms() -> io::Result<()> {
|
|||
let elapsed = now.elapsed();
|
||||
|
||||
assert_eq!(n, 0);
|
||||
assert!(elapsed >= dur);
|
||||
assert!(elapsed >= dur, "{:?} < {:?}", elapsed, dur);
|
||||
lowest = lowest.min(elapsed);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue