diff --git a/Cargo.toml b/Cargo.toml index ac0b0ac..3e86495 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,10 +7,10 @@ version = "2.5.2" authors = ["Stjepan Glavina "] edition = "2018" rust-version = "1.47" -description = "Portable interface to epoll, kqueue, event ports, and wepoll" +description = "Portable interface to epoll, kqueue, event ports, and IOCP" license = "Apache-2.0 OR MIT" repository = "https://github.com/smol-rs/polling" -keywords = ["mio", "epoll", "kqueue", "iocp", "wepoll"] +keywords = ["mio", "epoll", "kqueue", "iocp"] categories = ["asynchronous", "network-programming", "os"] exclude = ["/.*"] @@ -32,13 +32,19 @@ autocfg = "1" libc = "0.2.77" [target.'cfg(windows)'.dependencies] -wepoll-ffi = { version = "0.1.2", features = ["null-overlapped-wakeups-patch"] } +bitflags = "1.3.2" +concurrent-queue = "2.1.0" +pin-project-lite = "0.2.9" [target.'cfg(windows)'.dependencies.windows-sys] version = "0.45" features = [ + "Win32_Networking_WinSock", "Win32_System_IO", - "Win32_Foundation" + "Win32_System_LibraryLoader", + "Win32_System_WindowsProgramming", + "Win32_Storage_FileSystem", + "Win32_Foundation", ] [dev-dependencies] diff --git a/README.md b/README.md index 9d54a5c..67093bf 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ https://crates.io/crates/polling) [![Documentation](https://docs.rs/polling/badge.svg)]( https://docs.rs/polling) -Portable interface to epoll, kqueue, event ports, and wepoll. +Portable interface to epoll, kqueue, event ports, and IOCP. Supported platforms: - [epoll](https://en.wikipedia.org/wiki/Epoll): Linux, Android @@ -17,7 +17,7 @@ Supported platforms: DragonFly BSD - [event ports](https://illumos.org/man/port_create): illumos, Solaris - [poll](https://en.wikipedia.org/wiki/Poll_(Unix)): VxWorks, Fuchsia, other Unix systems -- [wepoll](https://github.com/piscisaureus/wepoll): Windows, Wine (version 7.13+) +- [IOCP](https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports): Windows, Wine (version 7.13+) Polling is done in oneshot mode, which means interest in I/O events needs to be reset after an event is delivered if we're interested in the next event of the same kind. diff --git a/src/iocp/afd.rs b/src/iocp/afd.rs new file mode 100644 index 0000000..4a6600f --- /dev/null +++ b/src/iocp/afd.rs @@ -0,0 +1,608 @@ +//! Safe wrapper around \Device\Afd + +use super::port::{Completion, CompletionHandle}; + +use std::cell::UnsafeCell; +use std::fmt; +use std::io; +use std::marker::{PhantomData, PhantomPinned}; +use std::mem::{size_of, transmute, MaybeUninit}; +use std::os::windows::prelude::{AsRawHandle, RawHandle, RawSocket}; +use std::pin::Pin; +use std::ptr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Once; + +use windows_sys::Win32::Foundation::{ + CloseHandle, HANDLE, HINSTANCE, NTSTATUS, STATUS_NOT_FOUND, STATUS_PENDING, STATUS_SUCCESS, + UNICODE_STRING, +}; +use windows_sys::Win32::Networking::WinSock::{ + WSAIoctl, SIO_BASE_HANDLE, SIO_BSP_HANDLE_POLL, SOCKET_ERROR, +}; +use windows_sys::Win32::Storage::FileSystem::{ + FILE_OPEN, FILE_SHARE_READ, FILE_SHARE_WRITE, SYNCHRONIZE, +}; +use windows_sys::Win32::System::LibraryLoader::{GetModuleHandleW, GetProcAddress}; +use windows_sys::Win32::System::WindowsProgramming::{IO_STATUS_BLOCK, OBJECT_ATTRIBUTES}; + +#[derive(Default)] +#[repr(C)] +pub(super) struct AfdPollInfo { + /// The timeout for this poll. + timeout: i64, + + /// The number of handles being polled. + handle_count: u32, + + /// Whether or not this poll is exclusive for this handle. + exclusive: u32, + + /// The handles to poll. + handles: [AfdPollHandleInfo; 1], +} + +#[derive(Default)] +#[repr(C)] +struct AfdPollHandleInfo { + /// The handle to poll. + handle: HANDLE, + + /// The events to poll for. + events: AfdPollMask, + + /// The status of the poll. + status: NTSTATUS, +} + +impl AfdPollInfo { + pub(super) fn handle_count(&self) -> u32 { + self.handle_count + } + + pub(super) fn events(&self) -> AfdPollMask { + self.handles[0].events + } +} + +bitflags::bitflags! { + #[derive(Default)] + #[repr(transparent)] + pub(super) struct AfdPollMask: u32 { + const RECEIVE = 0x001; + const RECEIVE_EXPEDITED = 0x002; + const SEND = 0x004; + const DISCONNECT = 0x008; + const ABORT = 0x010; + const LOCAL_CLOSE = 0x020; + const ACCEPT = 0x080; + const CONNECT_FAIL = 0x100; + } +} + +pub(super) trait HasAfdInfo { + fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell>; +} + +macro_rules! define_ntdll_import { + ( + $( + $(#[$attr:meta])* + fn $name:ident($($arg:ident: $arg_ty:ty),*) -> $ret:ty; + )* + ) => { + /// Imported functions from ntdll.dll. + #[allow(non_snake_case)] + pub(super) struct NtdllImports { + $( + $(#[$attr])* + $name: unsafe extern "system" fn($($arg_ty),*) -> $ret, + )* + } + + #[allow(non_snake_case)] + impl NtdllImports { + unsafe fn load(ntdll: HINSTANCE) -> io::Result { + $( + let $name = { + const NAME: &str = concat!(stringify!($name), "\0"); + let addr = GetProcAddress(ntdll, NAME.as_ptr() as *const _); + + let addr = match addr { + Some(addr) => addr, + None => { + log::error!("Failed to load ntdll function {}", NAME); + return Err(io::Error::last_os_error()); + }, + }; + + transmute::<_, unsafe extern "system" fn($($arg_ty),*) -> $ret>(addr) + }; + )* + + Ok(Self { + $( + $name, + )* + }) + } + + $( + $(#[$attr])* + unsafe fn $name(&self, $($arg: $arg_ty),*) -> $ret { + (self.$name)($($arg),*) + } + )* + } + }; +} + +define_ntdll_import! { + /// Cancels an ongoing I/O operation. + fn NtCancelIoFileEx( + FileHandle: HANDLE, + IoRequestToCancel: *mut IO_STATUS_BLOCK, + IoStatusBlock: *mut IO_STATUS_BLOCK + ) -> NTSTATUS; + + /// Opens or creates a file handle. + #[allow(clippy::too_many_arguments)] + fn NtCreateFile( + FileHandle: *mut HANDLE, + DesiredAccess: u32, + ObjectAttributes: *mut OBJECT_ATTRIBUTES, + IoStatusBlock: *mut IO_STATUS_BLOCK, + AllocationSize: *mut i64, + FileAttributes: u32, + ShareAccess: u32, + CreateDisposition: u32, + CreateOptions: u32, + EaBuffer: *mut (), + EaLength: u32 + ) -> NTSTATUS; + + /// Runs an I/O control on a file handle. + /// + /// Practically equivalent to `ioctl`. + #[allow(clippy::too_many_arguments)] + fn NtDeviceIoControlFile( + FileHandle: HANDLE, + Event: HANDLE, + ApcRoutine: *mut (), + ApcContext: *mut (), + IoStatusBlock: *mut IO_STATUS_BLOCK, + IoControlCode: u32, + InputBuffer: *mut (), + InputBufferLength: u32, + OutputBuffer: *mut (), + OutputBufferLength: u32 + ) -> NTSTATUS; + + /// Converts `NTSTATUS` to a DOS error code. + fn RtlNtStatusToDosError( + Status: NTSTATUS + ) -> u32; +} + +impl NtdllImports { + fn get() -> io::Result<&'static Self> { + macro_rules! s { + ($e:expr) => {{ + $e as u16 + }}; + } + + // ntdll.dll + static NTDLL_NAME: &[u16] = &[ + s!('n'), + s!('t'), + s!('d'), + s!('l'), + s!('l'), + s!('.'), + s!('d'), + s!('l'), + s!('l'), + s!('\0'), + ]; + static NTDLL_IMPORTS: OnceCell> = OnceCell::new(); + + NTDLL_IMPORTS + .get_or_init(|| unsafe { + let ntdll = GetModuleHandleW(NTDLL_NAME.as_ptr() as *const _); + + if ntdll == 0 { + log::error!("Failed to load ntdll.dll"); + return Err(io::Error::last_os_error()); + } + + NtdllImports::load(ntdll) + }) + .as_ref() + .map_err(|e| io::Error::from(e.kind())) + } + + pub(super) fn force_load() -> io::Result<()> { + Self::get()?; + Ok(()) + } +} + +/// The handle to the AFD device. +pub(super) struct Afd { + /// The handle to the AFD device. + handle: HANDLE, + + /// We own `T`. + _marker: PhantomData, +} + +impl fmt::Debug for Afd { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct WriteAsHex(HANDLE); + + impl fmt::Debug for WriteAsHex { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:010x}", self.0) + } + } + + f.debug_struct("Afd") + .field("handle", &WriteAsHex(self.handle)) + .finish() + } +} + +impl Drop for Afd { + fn drop(&mut self) { + unsafe { + CloseHandle(self.handle); + } + } +} + +impl AsRawHandle for Afd { + fn as_raw_handle(&self) -> RawHandle { + self.handle as _ + } +} + +impl Afd +where + T::Completion: AsIoStatusBlock + HasAfdInfo, +{ + /// Create a new AFD device. + pub(super) fn new() -> io::Result { + macro_rules! s { + ($e:expr) => { + ($e) as u16 + }; + } + + /// \Device\Afd\Smol + const AFD_NAME: &[u16] = &[ + s!('\\'), + s!('D'), + s!('e'), + s!('v'), + s!('i'), + s!('c'), + s!('e'), + s!('\\'), + s!('A'), + s!('f'), + s!('d'), + s!('\\'), + s!('S'), + s!('m'), + s!('o'), + s!('l'), + s!('\0'), + ]; + + // Set up device attributes. + let mut device_name = UNICODE_STRING { + Length: (AFD_NAME.len() * size_of::()) as u16, + MaximumLength: (AFD_NAME.len() * size_of::()) as u16, + Buffer: AFD_NAME.as_ptr() as *mut _, + }; + let mut device_attributes = OBJECT_ATTRIBUTES { + Length: size_of::() as u32, + RootDirectory: 0, + ObjectName: &mut device_name, + Attributes: 0, + SecurityDescriptor: ptr::null_mut(), + SecurityQualityOfService: ptr::null_mut(), + }; + + let mut handle = MaybeUninit::::uninit(); + let mut iosb = MaybeUninit::::zeroed(); + let ntdll = NtdllImports::get()?; + + let result = unsafe { + ntdll.NtCreateFile( + handle.as_mut_ptr(), + SYNCHRONIZE, + &mut device_attributes, + iosb.as_mut_ptr(), + ptr::null_mut(), + 0, + FILE_SHARE_READ | FILE_SHARE_WRITE, + FILE_OPEN, + 0, + ptr::null_mut(), + 0, + ) + }; + + if result != STATUS_SUCCESS { + let real_code = unsafe { ntdll.RtlNtStatusToDosError(result) }; + + return Err(io::Error::from_raw_os_error(real_code as i32)); + } + + let handle = unsafe { handle.assume_init() }; + + Ok(Self { + handle, + _marker: PhantomData, + }) + } + + /// Begin polling with the provided handle. + pub(super) fn poll( + &self, + packet: T, + base_socket: RawSocket, + afd_events: AfdPollMask, + ) -> io::Result<()> { + const IOCTL_AFD_POLL: u32 = 0x00012024; + + // Lock the packet. + if !packet.get().try_lock() { + return Err(io::Error::new( + io::ErrorKind::WouldBlock, + "packet is already in use", + )); + } + + // Set up the AFD poll info. + let poll_info = unsafe { + let poll_info = Pin::into_inner_unchecked(packet.get().afd_info()).get(); + + // Initialize the AFD poll info. + (*poll_info).exclusive = false.into(); + (*poll_info).handle_count = 1; + (*poll_info).timeout = std::i64::MAX; + (*poll_info).handles[0].handle = base_socket as HANDLE; + (*poll_info).handles[0].status = 0; + (*poll_info).handles[0].events = afd_events; + + poll_info + }; + + let iosb = T::into_ptr(packet).cast::(); + // Set Status to pending + unsafe { + (*iosb).Anonymous.Status = STATUS_PENDING; + } + + let ntdll = NtdllImports::get()?; + let result = unsafe { + ntdll.NtDeviceIoControlFile( + self.handle, + 0, + ptr::null_mut(), + iosb.cast(), + iosb.cast(), + IOCTL_AFD_POLL, + poll_info.cast(), + size_of::() as u32, + poll_info.cast(), + size_of::() as u32, + ) + }; + + match result { + STATUS_SUCCESS => Ok(()), + STATUS_PENDING => Err(io::ErrorKind::WouldBlock.into()), + status => { + let real_code = unsafe { ntdll.RtlNtStatusToDosError(status) }; + + Err(io::Error::from_raw_os_error(real_code as i32)) + } + } + } + + /// Cancel an ongoing poll operation. + /// + /// # Safety + /// + /// The poll operation must currently be in progress for this AFD. + pub(super) unsafe fn cancel(&self, packet: &T) -> io::Result<()> { + let ntdll = NtdllImports::get()?; + + let result = { + // First, check if the packet is still in use. + let iosb = packet.as_ptr().cast::(); + + if (*iosb).Anonymous.Status != STATUS_PENDING { + return Ok(()); + } + + // Cancel the packet. + let mut cancel_iosb = MaybeUninit::::zeroed(); + + ntdll.NtCancelIoFileEx(self.handle, iosb, cancel_iosb.as_mut_ptr()) + }; + + if result == STATUS_SUCCESS || result == STATUS_NOT_FOUND { + Ok(()) + } else { + let real_code = ntdll.RtlNtStatusToDosError(result); + + Err(io::Error::from_raw_os_error(real_code as i32)) + } + } +} + +/// A one-time initialization cell. +struct OnceCell { + /// The value. + value: UnsafeCell>, + + /// The one-time initialization. + once: Once, +} + +unsafe impl Send for OnceCell {} +unsafe impl Sync for OnceCell {} + +impl OnceCell { + /// Creates a new `OnceCell`. + pub const fn new() -> Self { + OnceCell { + value: UnsafeCell::new(MaybeUninit::uninit()), + once: Once::new(), + } + } + + /// Gets the value or initializes it. + pub fn get_or_init(&self, f: F) -> &T + where + F: FnOnce() -> T, + { + self.once.call_once(|| unsafe { + let value = f(); + *self.value.get() = MaybeUninit::new(value); + }); + + unsafe { &*self.value.get().cast() } + } +} + +pin_project_lite::pin_project! { + /// An I/O status block paired with some auxillary data. + #[repr(C)] + pub(super) struct IoStatusBlock { + // The I/O status block. + iosb: UnsafeCell, + + // Whether or not the block is in use. + in_use: AtomicBool, + + // The auxillary data. + #[pin] + data: T, + + // This block is not allowed to move. + #[pin] + _marker: PhantomPinned, + } +} + +impl fmt::Debug for IoStatusBlock { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("IoStatusBlock") + .field("iosb", &"..") + .field("in_use", &self.in_use) + .field("data", &self.data) + .finish() + } +} + +impl From for IoStatusBlock { + fn from(data: T) -> Self { + Self { + iosb: UnsafeCell::new(unsafe { std::mem::zeroed() }), + in_use: AtomicBool::new(false), + data, + _marker: PhantomPinned, + } + } +} + +impl IoStatusBlock { + pub(super) fn iosb(self: Pin<&Self>) -> &UnsafeCell { + self.project_ref().iosb + } + + pub(super) fn data(self: Pin<&Self>) -> Pin<&T> { + self.project_ref().data + } +} + +impl HasAfdInfo for IoStatusBlock { + fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell> { + self.project_ref().data.afd_info() + } +} + +/// Can be transmuted to an I/O status block. +/// +/// # Safety +/// +/// A pointer to `T` must be able to be converted to a pointer to `IO_STATUS_BLOCK` +/// without any issues. +pub(super) unsafe trait AsIoStatusBlock {} + +unsafe impl AsIoStatusBlock for IoStatusBlock {} +unsafe impl Completion for IoStatusBlock { + fn try_lock(self: Pin<&Self>) -> bool { + !self.in_use.swap(true, Ordering::SeqCst) + } + + unsafe fn unlock(self: Pin<&Self>) { + self.in_use.store(false, Ordering::SeqCst); + } +} + +/// Get the base socket associated with a socket. +pub(super) fn base_socket(sock: RawSocket) -> io::Result { + // First, try the SIO_BASE_HANDLE ioctl. + let result = unsafe { try_socket_ioctl(sock, SIO_BASE_HANDLE) }; + + match result { + Ok(sock) => return Ok(sock), + Err(e) if e.kind() == io::ErrorKind::InvalidInput => return Err(e), + Err(_) => {} + } + + // Some poorly coded LSPs may not handle SIO_BASE_HANDLE properly, but in some cases may + // handle SIO_BSP_HANDLE_POLL better. Try that. + let result = unsafe { try_socket_ioctl(sock, SIO_BSP_HANDLE_POLL)? }; + if result == sock { + return Err(io::Error::from(io::ErrorKind::InvalidInput)); + } + + // Try `SIO_BASE_HANDLE` again, in case the LSP fixed itself. + unsafe { try_socket_ioctl(result, SIO_BASE_HANDLE) } +} + +/// Run an IOCTL on a socket and return a socket. +/// +/// # Safety +/// +/// The `ioctl` parameter must be a valid I/O control that returns a valid socket. +unsafe fn try_socket_ioctl(sock: RawSocket, ioctl: u32) -> io::Result { + let mut out = MaybeUninit::::uninit(); + let mut bytes = 0u32; + + let result = WSAIoctl( + sock as _, + ioctl, + ptr::null_mut(), + 0, + out.as_mut_ptr().cast(), + size_of::() as u32, + &mut bytes, + ptr::null_mut(), + None, + ); + + if result == SOCKET_ERROR { + return Err(io::Error::last_os_error()); + } + + Ok(out.assume_init()) +} diff --git a/src/iocp/mod.rs b/src/iocp/mod.rs new file mode 100644 index 0000000..cc618b1 --- /dev/null +++ b/src/iocp/mod.rs @@ -0,0 +1,834 @@ +//! Bindings to Windows I/O Completion Ports. +//! +//! I/O Completion Ports is a completion-based API rather than a polling-based API, like +//! epoll or kqueue. Therefore, we have to adapt the IOCP API to the crate's API. +//! +//! WinSock is powered by the Auxillary Function Driver (AFD) subsystem, which can be +//! accessed directly by using unstable `ntdll` functions. AFD exposes features that are not +//! available through the normal WinSock interface, such as IOCTL_AFD_POLL. This function is +//! similar to the exposed `WSAPoll` method. However, once the targeted socket is "ready", +//! a completion packet is queued to an I/O completion port. +//! +//! We take advantage of IOCTL_AFD_POLL to "translate" this crate's polling-based API +//! to the one Windows expects. When a device is added to the `Poller`, an IOCTL_AFD_POLL +//! operation is started and queued to the IOCP. To modify a currently registered device +//! (e.g. with `modify()` or `delete()`), the ongoing POLL is cancelled and then restarted +//! with new parameters. Whn the POLL eventually completes, the packet is posted to the IOCP. +//! From here it's a simple matter of using `GetQueuedCompletionStatusEx` to read the packets +//! from the IOCP and react accordingly. Notifying the poller is trivial, because we can +//! simply post a packet to the IOCP to wake it up. +//! +//! The main disadvantage of this strategy is that it relies on unstable Windows APIs. +//! However, as `libuv` (the backing I/O library for Node.JS) relies on the same unstable +//! AFD strategy, it is unlikely to be broken without plenty of advanced warning. +//! +//! Previously, this crate used the `wepoll` library for polling. `wepoll` uses a similar +//! AFD-based strategy for polling. + +mod afd; +mod port; + +use afd::{base_socket, Afd, AfdPollInfo, AfdPollMask, HasAfdInfo, IoStatusBlock}; +use port::{IoCompletionPort, OverlappedEntry}; +use windows_sys::Win32::Foundation::{ERROR_INVALID_HANDLE, ERROR_IO_PENDING, STATUS_CANCELLED}; + +use crate::{Event, PollMode}; + +use concurrent_queue::ConcurrentQueue; +use pin_project_lite::pin_project; + +use std::cell::UnsafeCell; +use std::collections::hash_map::{Entry, HashMap}; +use std::fmt; +use std::io; +use std::marker::PhantomPinned; +use std::os::windows::io::{AsRawHandle, RawHandle, RawSocket}; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak}; +use std::time::{Duration, Instant}; + +#[cfg(not(polling_no_io_safety))] +use std::os::windows::io::{AsHandle, BorrowedHandle}; + +/// Macro to lock and ignore lock poisoning. +macro_rules! lock { + ($lock_result:expr) => {{ + $lock_result.unwrap_or_else(|e| e.into_inner()) + }}; +} + +/// Interface to I/O completion ports. +#[derive(Debug)] +pub(super) struct Poller { + /// The I/O completion port. + port: IoCompletionPort, + + /// List of currently active AFD instances. + /// + /// Weak references are kept here so that the AFD handle is automatically dropped + /// when the last associated socket is dropped. + afd: Mutex>>>, + + /// The state of the sources registered with this poller. + sources: RwLock>, + + /// Sockets with pending updates. + pending_updates: ConcurrentQueue, + + /// Are we currently polling? + polling: AtomicBool, + + /// A list of completion packets. + packets: Mutex>>, + + /// The packet used to notify the poller. + notifier: Packet, +} + +unsafe impl Send for Poller {} +unsafe impl Sync for Poller {} + +impl Poller { + /// Creates a new poller. + pub(super) fn new() -> io::Result { + // Make sure AFD is able to be used. + if let Err(e) = afd::NtdllImports::force_load() { + return Err(crate::unsupported_error(format!( + "Failed to initialize unstable Windows functions: {}\nThis usually only happens for old Windows or Wine.", + e + ))); + } + + // Create and destroy a single AFD to test if we support it. + Afd::::new().map_err(|e| crate::unsupported_error(format!( + "Failed to initialize \\Device\\Afd: {}\nThis usually only happens for old Windows or Wine.", + e, + )))?; + + let port = IoCompletionPort::new(0)?; + + log::trace!("new: handle={:?}", &port); + + Ok(Poller { + port, + afd: Mutex::new(vec![]), + sources: RwLock::new(HashMap::new()), + pending_updates: ConcurrentQueue::bounded(1024), + polling: AtomicBool::new(false), + packets: Mutex::new(Vec::with_capacity(1024)), + notifier: Arc::pin( + PacketInner::Wakeup { + _pinned: PhantomPinned, + } + .into(), + ), + }) + } + + /// Whether this poller supports level-triggered events. + pub(super) fn supports_level(&self) -> bool { + true + } + + /// Whether this poller supports edge-triggered events. + pub(super) fn supports_edge(&self) -> bool { + false + } + + /// Add a new source to the poller. + pub(super) fn add(&self, socket: RawSocket, interest: Event, mode: PollMode) -> io::Result<()> { + log::trace!( + "add: handle={:?}, sock={}, ev={:?}", + self.port, + socket, + interest + ); + + // We don't support edge-triggered events. + if matches!(mode, PollMode::Edge) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "edge-triggered events are not supported", + )); + } + + // Create a new packet. + let socket_state = { + let state = SocketState { + socket, + base_socket: base_socket(socket)?, + interest, + interest_error: true, + afd: self.afd_handle()?, + mode, + waiting_on_delete: false, + status: SocketStatus::Idle, + }; + + Arc::pin(IoStatusBlock::from(PacketInner::Socket { + packet: UnsafeCell::new(AfdPollInfo::default()), + socket: Mutex::new(state), + })) + }; + + // Keep track of the source in the poller. + { + let mut sources = lock!(self.sources.write()); + + match sources.entry(socket) { + Entry::Vacant(v) => { + v.insert(Pin::>::clone(&socket_state)); + } + + Entry::Occupied(_) => { + return Err(io::Error::from(io::ErrorKind::AlreadyExists)); + } + } + } + + // Update the packet. + self.update_packet(socket_state) + } + + /// Update a source in the poller. + pub(super) fn modify( + &self, + socket: RawSocket, + interest: Event, + mode: PollMode, + ) -> io::Result<()> { + log::trace!( + "modify: handle={:?}, sock={}, ev={:?}", + self.port, + socket, + interest + ); + + // We don't support edge-triggered events. + if matches!(mode, PollMode::Edge) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "edge-triggered events are not supported", + )); + } + + // Get a reference to the source. + let source = { + let sources = lock!(self.sources.read()); + + sources + .get(&socket) + .cloned() + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))? + }; + + // Set the new event. + if source.as_ref().set_events(interest, mode) { + self.update_packet(source)?; + } + + Ok(()) + } + + /// Delete a source from the poller. + pub(super) fn delete(&self, socket: RawSocket) -> io::Result<()> { + log::trace!("remove: handle={:?}, sock={}", self.port, socket); + + // Get a reference to the source. + let source = { + let mut sources = lock!(self.sources.write()); + + match sources.remove(&socket) { + Some(s) => s, + None => { + // If the source has already been removed, then we can just return. + return Ok(()); + } + } + }; + + // Indicate to the source that it is being deleted. + // This cancels any ongoing AFD_IOCTL_POLL operations. + source.begin_delete() + } + + /// Wait for events. + pub(super) fn wait(&self, events: &mut Events, timeout: Option) -> io::Result<()> { + log::trace!("wait: handle={:?}, timeout={:?}", self.port, timeout); + + let deadline = timeout.and_then(|timeout| Instant::now().checked_add(timeout)); + let mut packets = lock!(self.packets.lock()); + let mut notified = false; + events.packets.clear(); + + loop { + let mut new_events = 0; + + // Indicate that we are now polling. + let was_polling = self.polling.swap(true, Ordering::SeqCst); + debug_assert!(!was_polling); + + let guard = CallOnDrop(|| { + let was_polling = self.polling.swap(false, Ordering::SeqCst); + debug_assert!(was_polling); + }); + + // Process every entry in the queue before we start polling. + self.drain_update_queue(false)?; + + // Get the time to wait for. + let timeout = deadline.map(|t| t.saturating_duration_since(Instant::now())); + + // Wait for I/O events. + let len = self.port.wait(&mut packets, timeout)?; + log::trace!("new events: handle={:?}, len={}", self.port, len); + + // We are no longer polling. + drop(guard); + + // Process all of the events. + for entry in packets.drain(..) { + let packet = entry.into_packet(); + + // Feed the event into the packet. + match packet.feed_event(self)? { + FeedEventResult::NoEvent => {} + FeedEventResult::Event(event) => { + events.packets.push(event); + new_events += 1; + } + FeedEventResult::Notified => { + notified = true; + } + } + } + + // Break if there was a notification or at least one event, or if deadline is reached. + let timeout_is_empty = + timeout.map_or(false, |t| t.as_secs() == 0 && t.subsec_nanos() == 0); + if notified || new_events > 0 || timeout_is_empty { + break; + } + + log::trace!("wait: no events found, re-entering polling loop"); + } + + Ok(()) + } + + /// Notify this poller. + pub(super) fn notify(&self) -> io::Result<()> { + // Push the notify packet into the IOCP. + self.port.post(0, 0, self.notifier.clone()) + } + + /// Run an update on a packet. + fn update_packet(&self, mut packet: Packet) -> io::Result<()> { + loop { + // If we are currently polling, we need to update the packet immediately. + if self.polling.load(Ordering::Acquire) { + packet.update()?; + return Ok(()); + } + + // Try to queue the update. + match self.pending_updates.push(packet) { + Ok(()) => return Ok(()), + Err(p) => packet = p.into_inner(), + } + + // If we failed to queue the update, we need to drain the queue first. + self.drain_update_queue(true)?; + } + } + + /// Drain the update queue. + fn drain_update_queue(&self, limit: bool) -> io::Result<()> { + let max = if limit { + self.pending_updates.capacity().unwrap() + } else { + std::usize::MAX + }; + + // Only drain the queue's capacity, since this could in theory run forever. + core::iter::from_fn(|| self.pending_updates.pop().ok()) + .take(max) + .try_for_each(|packet| packet.update()) + } + + /// Get a handle to the AFD reference. + fn afd_handle(&self) -> io::Result>> { + const AFD_MAX_SIZE: usize = 32; + + // Crawl the list and see if there are any existing AFD instances that we can use. + // Remove any unused AFD pointers. + let mut afd_handles = lock!(self.afd.lock()); + let mut i = 0; + while i < afd_handles.len() { + // Get the reference count of the AFD instance. + let refcount = Weak::strong_count(&afd_handles[i]); + + match refcount { + 0 => { + // Prune the AFD pointer if it has no references. + afd_handles.swap_remove(i); + } + + refcount if refcount >= AFD_MAX_SIZE => { + // Skip this one, since it is already at the maximum size. + i += 1; + } + + _ => { + // We can use this AFD instance. + match afd_handles[i].upgrade() { + Some(afd) => return Ok(afd), + None => { + // The last socket dropped the AFD before we could acquire it. + // Prune the AFD pointer and continue. + afd_handles.swap_remove(i); + } + } + } + } + } + + // No available handles, create a new AFD instance. + let afd = Arc::new(Afd::new()?); + + // Register the AFD instance with the I/O completion port. + self.port.register(&*afd, true)?; + + // Insert a weak pointer to the AFD instance into the list. + afd_handles.push(Arc::downgrade(&afd)); + + Ok(afd) + } +} + +impl AsRawHandle for Poller { + fn as_raw_handle(&self) -> RawHandle { + self.port.as_raw_handle() + } +} + +#[cfg(not(polling_no_io_safety))] +impl AsHandle for Poller { + fn as_handle(&self) -> BorrowedHandle<'_> { + unsafe { BorrowedHandle::borrow_raw(self.as_raw_handle()) } + } +} + +/// The container for events. +pub(super) struct Events { + /// List of IOCP packets. + packets: Vec, +} + +unsafe impl Send for Events {} + +impl Events { + /// Creates an empty list of events. + pub(super) fn new() -> Events { + Events { + packets: Vec::with_capacity(1024), + } + } + + /// Iterate over I/O events. + pub(super) fn iter(&self) -> impl Iterator + '_ { + self.packets.iter().copied() + } +} + +/// The type of our completion packet. +type Packet = Pin>; +type PacketUnwrapped = IoStatusBlock; + +pin_project! { + /// The inner type of the packet. + #[project_ref = PacketInnerProj] + #[project = PacketInnerProjMut] + enum PacketInner { + // A packet for a socket. + Socket { + // The AFD packet state. + #[pin] + packet: UnsafeCell, + + // The socket state. + socket: Mutex + }, + + // A packet used to wake up the poller. + Wakeup { #[pin] _pinned: PhantomPinned }, + } +} + +impl fmt::Debug for PacketInner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Wakeup { .. } => f.write_str("Wakeup { .. }"), + Self::Socket { socket, .. } => f + .debug_struct("Socket") + .field("packet", &"..") + .field("socket", socket) + .finish(), + } + } +} + +impl HasAfdInfo for PacketInner { + fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell> { + match self.project_ref() { + PacketInnerProj::Socket { packet, .. } => packet, + PacketInnerProj::Wakeup { .. } => unreachable!(), + } + } +} + +impl PacketUnwrapped { + /// Set the new events that this socket is waiting on. + /// + /// Returns `true` if we need to be updated. + fn set_events(self: Pin<&Self>, interest: Event, mode: PollMode) -> bool { + let mut socket = match self.socket_state() { + Some(s) => s, + None => return false, + }; + + socket.interest = interest; + socket.mode = mode; + socket.interest_error = true; + + match socket.status { + SocketStatus::Polling { readable, writable } => { + (interest.readable && !readable) || (interest.writable && !writable) + } + _ => true, + } + } + + /// Update the socket and install the new status in AFD. + fn update(self: Pin>) -> io::Result<()> { + let mut socket = match self.as_ref().socket_state() { + Some(s) => s, + None => return Err(io::Error::new(io::ErrorKind::Other, "invalid socket state")), + }; + + // If we are waiting on a delete, just return, dropping the packet. + if socket.waiting_on_delete { + return Ok(()); + } + + // Check the current status. + match socket.status { + SocketStatus::Polling { readable, writable } => { + // If we need to poll for events aside from what we are currently polling, we need + // to update the packet. Cancel the ongoing poll. + if (socket.interest.readable && !readable) + || (socket.interest.writable && !writable) + { + return self.cancel(socket); + } + + // All events that we are currently waiting on are accounted for. + Ok(()) + } + + SocketStatus::Cancelled => { + // The ongoing operation was cancelled, and we're still waiting for it to return. + // For now, wait until the top-level loop calls feed_event(). + Ok(()) + } + + SocketStatus::Idle => { + // Start a new poll. + let result = socket.afd.poll( + self.clone(), + socket.base_socket, + event_to_afd_mask( + socket.interest.readable, + socket.interest.writable, + socket.interest_error, + ), + ); + + match result { + Ok(()) => {} + + Err(err) + if err.raw_os_error() == Some(ERROR_IO_PENDING as i32) + || err.kind() == io::ErrorKind::WouldBlock => + { + // The operation is pending. + } + + Err(err) if err.raw_os_error() == Some(ERROR_INVALID_HANDLE as i32) => { + // The socket was closed. We need to delete it. + // This should happen after we drop it here. + } + + Err(err) => return Err(err), + } + + // We are now polling for the current events. + socket.status = SocketStatus::Polling { + readable: socket.interest.readable, + writable: socket.interest.writable, + }; + + Ok(()) + } + } + } + + /// This socket state was notified; see if we need to update it. + fn feed_event(self: Pin>, poller: &Poller) -> io::Result { + let inner = self.as_ref().data().project_ref(); + + let (afd_info, socket) = match inner { + PacketInnerProj::Socket { packet, socket } => (packet, socket), + PacketInnerProj::Wakeup { .. } => { + // The poller was notified. + return Ok(FeedEventResult::Notified); + } + }; + + let mut socket_state = lock!(socket.lock()); + let mut event = Event::none(socket_state.interest.key); + + // Put ourselves into the idle state. + socket_state.status = SocketStatus::Idle; + + // If we are waiting to be deleted, just return and let the drop handler do their thing. + if socket_state.waiting_on_delete { + return Ok(FeedEventResult::NoEvent); + } + + unsafe { + // SAFETY: The packet is not in transit. + let iosb = &mut *self.as_ref().iosb().get(); + + // Check the status. + match iosb.Anonymous.Status { + STATUS_CANCELLED => { + // Poll request was cancelled. + } + + status if status < 0 => { + // There was an error, so we signal both ends. + event.readable = true; + event.writable = true; + } + + _ => { + // Check in on the AFD data. + let afd_data = &*afd_info.get(); + + if afd_data.handle_count() >= 1 { + let events = afd_data.events(); + + // If we closed the socket, remove it from being polled. + if events.contains(AfdPollMask::LOCAL_CLOSE) { + let source = lock!(poller.sources.write()) + .remove(&socket_state.socket) + .unwrap(); + return source.begin_delete().map(|()| FeedEventResult::NoEvent); + } + + // Report socket-related events. + let (readable, writable) = afd_mask_to_event(events); + event.readable = readable; + event.writable = writable; + } + } + } + } + + // Filter out events that the user didn't ask for. + event.readable &= socket_state.interest.readable; + event.writable &= socket_state.interest.writable; + + // If this event doesn't have anything that interests us, don't return or + // update the oneshot state. + let return_value = if event.readable || event.writable { + // If we are in oneshot mode, remove the interest. + if matches!(socket_state.mode, PollMode::Oneshot) { + socket_state.interest = Event::none(socket_state.interest.key); + socket_state.interest_error = false; + } + + FeedEventResult::Event(event) + } else { + FeedEventResult::NoEvent + }; + + // Put ourselves in the update queue. + drop(socket_state); + poller.update_packet(self)?; + + // Return the event. + Ok(return_value) + } + + /// Begin deleting this socket. + fn begin_delete(self: Pin>) -> io::Result<()> { + // If we aren't already being deleted, start deleting. + let mut socket = self + .as_ref() + .socket_state() + .expect("can't delete packet that doesn't belong to a socket"); + if !socket.waiting_on_delete { + socket.waiting_on_delete = true; + + if matches!(socket.status, SocketStatus::Polling { .. }) { + // Cancel the ongoing poll. + self.cancel(socket)?; + } + } + + // Either drop it now or wait for it to be dropped later. + Ok(()) + } + + fn cancel(self: &Pin>, mut socket: MutexGuard<'_, SocketState>) -> io::Result<()> { + assert!(matches!(socket.status, SocketStatus::Polling { .. })); + + // Send the cancel request. + unsafe { + socket.afd.cancel(self)?; + } + + // Move state to cancelled. + socket.status = SocketStatus::Cancelled; + + Ok(()) + } + + fn socket_state(self: Pin<&Self>) -> Option> { + let inner = self.data().project_ref(); + + let state = match inner { + PacketInnerProj::Wakeup { .. } => return None, + PacketInnerProj::Socket { socket, .. } => socket, + }; + + Some(lock!(state.lock())) + } +} + +/// Per-socket state. +#[derive(Debug)] +struct SocketState { + /// The raw socket handle. + socket: RawSocket, + + /// The base socket handle. + base_socket: RawSocket, + + /// The event that this socket is currently waiting on. + interest: Event, + + /// Whether to listen for error events. + interest_error: bool, + + /// The current poll mode. + mode: PollMode, + + /// The AFD instance that this socket is registered with. + afd: Arc>, + + /// Whether this socket is waiting to be deleted. + waiting_on_delete: bool, + + /// The current status of the socket. + status: SocketStatus, +} + +/// The mode that a socket can be in. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum SocketStatus { + /// We are currently not polling. + Idle, + + /// We are currently polling these events. + Polling { + /// We are currently polling for readable events. + readable: bool, + + /// We are currently polling for writable events. + writable: bool, + }, + + /// The last poll operation was cancelled, and we're waiting for it to + /// complete. + Cancelled, +} + +/// The result of calling `feed_event`. +#[derive(Debug)] +enum FeedEventResult { + /// No event was yielded. + NoEvent, + + /// An event was yielded. + Event(Event), + + /// The poller has been notified. + Notified, +} + +fn event_to_afd_mask(readable: bool, writable: bool, error: bool) -> afd::AfdPollMask { + use afd::AfdPollMask as AfdPoll; + + let mut mask = AfdPoll::empty(); + + if error || readable || writable { + mask |= AfdPoll::ABORT | AfdPoll::CONNECT_FAIL; + } + + if readable { + mask |= + AfdPoll::RECEIVE | AfdPoll::ACCEPT | AfdPoll::DISCONNECT | AfdPoll::RECEIVE_EXPEDITED; + } + + if writable { + mask |= AfdPoll::SEND; + } + + mask +} + +fn afd_mask_to_event(mask: afd::AfdPollMask) -> (bool, bool) { + use afd::AfdPollMask as AfdPoll; + + let mut readable = false; + let mut writable = false; + + if mask.intersects( + AfdPoll::RECEIVE | AfdPoll::ACCEPT | AfdPoll::DISCONNECT | AfdPoll::RECEIVE_EXPEDITED, + ) { + readable = true; + } + + if mask.intersects(AfdPoll::SEND) { + writable = true; + } + + if mask.intersects(AfdPoll::ABORT | AfdPoll::CONNECT_FAIL) { + readable = true; + writable = true; + } + + (readable, writable) +} + +struct CallOnDrop(F); + +impl Drop for CallOnDrop { + fn drop(&mut self) { + (self.0)(); + } +} diff --git a/src/iocp/port.rs b/src/iocp/port.rs new file mode 100644 index 0000000..3feae07 --- /dev/null +++ b/src/iocp/port.rs @@ -0,0 +1,327 @@ +//! A safe wrapper around the Windows I/O API. + +use std::convert::{TryFrom, TryInto}; +use std::fmt; +use std::io; +use std::marker::PhantomData; +use std::mem::MaybeUninit; +use std::ops::Deref; +use std::os::windows::io::{AsRawHandle, RawHandle}; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; + +use windows_sys::Win32::Foundation::{CloseHandle, HANDLE, INVALID_HANDLE_VALUE}; +use windows_sys::Win32::Storage::FileSystem::SetFileCompletionNotificationModes; +use windows_sys::Win32::System::WindowsProgramming::{FILE_SKIP_SET_EVENT_ON_HANDLE, INFINITE}; +use windows_sys::Win32::System::IO::{ + CreateIoCompletionPort, GetQueuedCompletionStatusEx, PostQueuedCompletionStatus, OVERLAPPED, + OVERLAPPED_ENTRY, +}; + +/// A completion block which can be used with I/O completion ports. +/// +/// # Safety +/// +/// This must be a valid completion block. +pub(super) unsafe trait Completion { + /// Signal to the completion block that we are about to start an operation. + fn try_lock(self: Pin<&Self>) -> bool; + + /// Unlock the completion block. + unsafe fn unlock(self: Pin<&Self>); +} + +/// The pointer to a completion block. +/// +/// # Safety +/// +/// This must be a valid completion block. +pub(super) unsafe trait CompletionHandle: Deref + Sized { + /// Type of the completion block. + type Completion: Completion; + + /// Get a pointer to the completion block. + /// + /// The pointer is pinned since the underlying object should not be moved + /// after creation. This prevents it from being invalidated while it's + /// used in an overlapped operation. + fn get(&self) -> Pin<&Self::Completion>; + + /// Convert this block into a pointer that can be passed as `*mut OVERLAPPED`. + fn into_ptr(this: Self) -> *mut OVERLAPPED; + + /// Convert a pointer that was passed as `*mut OVERLAPPED` into a pointer to this block. + /// + /// # Safety + /// + /// This must be a valid pointer to a completion block. + unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self; + + /// Convert to a pointer without losing ownership. + fn as_ptr(&self) -> *mut OVERLAPPED; +} + +unsafe impl<'a, T: Completion> CompletionHandle for Pin<&'a T> { + type Completion = T; + + fn get(&self) -> Pin<&Self::Completion> { + *self + } + + fn into_ptr(this: Self) -> *mut OVERLAPPED { + unsafe { Pin::into_inner_unchecked(this) as *const T as *mut OVERLAPPED } + } + + unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self { + Pin::new_unchecked(&*(ptr as *const T)) + } + + fn as_ptr(&self) -> *mut OVERLAPPED { + self.get_ref() as *const T as *mut OVERLAPPED + } +} + +unsafe impl CompletionHandle for Pin> { + type Completion = T; + + fn get(&self) -> Pin<&Self::Completion> { + self.as_ref() + } + + fn into_ptr(this: Self) -> *mut OVERLAPPED { + unsafe { Arc::into_raw(Pin::into_inner_unchecked(this)) as *const T as *mut OVERLAPPED } + } + + unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self { + Pin::new_unchecked(Arc::from_raw(ptr as *const T)) + } + + fn as_ptr(&self) -> *mut OVERLAPPED { + self.as_ref().get_ref() as *const T as *mut OVERLAPPED + } +} + +/// A handle to the I/O completion port. +pub(super) struct IoCompletionPort { + /// The underlying handle. + handle: HANDLE, + + /// We own the status block. + _marker: PhantomData, +} + +impl Drop for IoCompletionPort { + fn drop(&mut self) { + unsafe { + CloseHandle(self.handle); + } + } +} + +impl AsRawHandle for IoCompletionPort { + fn as_raw_handle(&self) -> RawHandle { + self.handle as _ + } +} + +impl fmt::Debug for IoCompletionPort { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct WriteAsHex(HANDLE); + + impl fmt::Debug for WriteAsHex { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:010x}", self.0) + } + } + + f.debug_struct("IoCompletionPort") + .field("handle", &WriteAsHex(self.handle)) + .finish() + } +} + +impl IoCompletionPort { + /// Create a new I/O completion port. + pub(super) fn new(threads: usize) -> io::Result { + let handle = unsafe { + CreateIoCompletionPort( + INVALID_HANDLE_VALUE, + 0, + 0, + threads.try_into().expect("too many threads"), + ) + }; + + if handle == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(Self { + handle, + _marker: PhantomData, + }) + } + } + + /// Register a handle with this I/O completion port. + pub(super) fn register( + &self, + handle: &impl AsRawHandle, // TODO change to AsHandle + skip_set_event_on_handle: bool, + ) -> io::Result<()> { + let handle = handle.as_raw_handle(); + + let result = + unsafe { CreateIoCompletionPort(handle as _, self.handle, handle as usize, 0) }; + + if result == 0 { + return Err(io::Error::last_os_error()); + } + + if skip_set_event_on_handle { + // Set the skip event on handle. + let result = unsafe { + SetFileCompletionNotificationModes(handle as _, FILE_SKIP_SET_EVENT_ON_HANDLE as _) + }; + + if result == 0 { + return Err(io::Error::last_os_error()); + } + } + + Ok(()) + } + + /// Post a completion packet to this port. + pub(super) fn post(&self, bytes_transferred: usize, id: usize, packet: T) -> io::Result<()> { + let result = unsafe { + PostQueuedCompletionStatus( + self.handle, + bytes_transferred + .try_into() + .expect("too many bytes transferred"), + id, + T::into_ptr(packet), + ) + }; + + if result == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(()) + } + } + + /// Wait for completion packets to arrive. + pub(super) fn wait( + &self, + packets: &mut Vec>, + timeout: Option, + ) -> io::Result { + // Drop the current packets. + packets.clear(); + + let mut count = MaybeUninit::::uninit(); + let timeout = timeout.map_or(INFINITE, dur2timeout); + + let result = unsafe { + GetQueuedCompletionStatusEx( + self.handle, + packets.as_mut_ptr() as _, + packets.capacity().try_into().expect("too many packets"), + count.as_mut_ptr(), + timeout, + 0, + ) + }; + + if result == 0 { + let io_error = io::Error::last_os_error(); + if io_error.kind() == io::ErrorKind::TimedOut { + Ok(0) + } else { + Err(io_error) + } + } else { + let count = unsafe { count.assume_init() }; + unsafe { + packets.set_len(count as _); + } + Ok(count as _) + } + } +} + +/// An `OVERLAPPED_ENTRY` resulting from an I/O completion port. +#[repr(transparent)] +pub(super) struct OverlappedEntry { + /// The underlying entry. + entry: OVERLAPPED_ENTRY, + + /// We own the status block. + _marker: PhantomData, +} + +impl fmt::Debug for OverlappedEntry { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("OverlappedEntry { .. }") + } +} + +impl OverlappedEntry { + /// Convert into the completion packet. + pub(super) fn into_packet(self) -> T { + let packet = unsafe { self.packet() }; + std::mem::forget(self); + packet + } + + /// Get the packet reference that this entry refers to. + /// + /// # Safety + /// + /// This function should only be called once, since it moves + /// out the `T` from the `OVERLAPPED_ENTRY`. + unsafe fn packet(&self) -> T { + let packet = T::from_ptr(self.entry.lpOverlapped); + packet.get().unlock(); + packet + } +} + +impl Drop for OverlappedEntry { + fn drop(&mut self) { + drop(unsafe { self.packet() }); + } +} + +// Implementation taken from https://github.com/rust-lang/rust/blob/db5476571d9b27c862b95c1e64764b0ac8980e23/src/libstd/sys/windows/mod.rs +fn dur2timeout(dur: Duration) -> u32 { + // Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the + // timeouts in windows APIs are typically u32 milliseconds. To translate, we + // have two pieces to take care of: + // + // * Nanosecond precision is rounded up + // * Greater than u32::MAX milliseconds (50 days) is rounded up to INFINITE + // (never time out). + dur.as_secs() + .checked_mul(1000) + .and_then(|ms| ms.checked_add((dur.subsec_nanos() as u64) / 1_000_000)) + .and_then(|ms| { + if dur.subsec_nanos() % 1_000_000 > 0 { + ms.checked_add(1) + } else { + Some(ms) + } + }) + .and_then(|x| u32::try_from(x).ok()) + .unwrap_or(INFINITE) +} + +struct CallOnDrop(F); + +impl Drop for CallOnDrop { + fn drop(&mut self) { + (self.0)(); + } +} diff --git a/src/lib.rs b/src/lib.rs index a6795c7..6cbce52 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -//! Portable interface to epoll, kqueue, event ports, and wepoll. +//! Portable interface to epoll, kqueue, event ports, and IOCP. //! //! Supported platforms: //! - [epoll](https://en.wikipedia.org/wiki/Epoll): Linux, Android @@ -6,7 +6,7 @@ //! DragonFly BSD //! - [event ports](https://illumos.org/man/port_create): illumos, Solaris //! - [poll](https://en.wikipedia.org/wiki/Poll_(Unix)): VxWorks, Fuchsia, other Unix systems -//! - [wepoll](https://github.com/piscisaureus/wepoll): Windows, Wine (version 7.13+) +//! - [IOCP](https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports): Windows, Wine (version 7.13+) //! //! By default, polling is done in oneshot mode, which means interest in I/O events needs to //! be re-enabled after an event is delivered if we're interested in the next event of the same @@ -113,8 +113,8 @@ cfg_if! { mod poll; use poll as sys; } else if #[cfg(target_os = "windows")] { - mod wepoll; - use wepoll as sys; + mod iocp; + use iocp as sys; } else { compile_error!("polling does not support this target OS"); } diff --git a/src/wepoll.rs b/src/wepoll.rs deleted file mode 100644 index 6c65266..0000000 --- a/src/wepoll.rs +++ /dev/null @@ -1,254 +0,0 @@ -//! Bindings to wepoll (Windows). - -use std::convert::TryInto; -use std::io; -use std::os::raw::c_int; -use std::os::windows::io::{AsRawHandle, RawHandle, RawSocket}; -use std::ptr; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::time::{Duration, Instant}; - -#[cfg(not(polling_no_io_safety))] -use std::os::windows::io::{AsHandle, BorrowedHandle}; - -use wepoll_ffi as we; - -use crate::{Event, PollMode}; - -/// Calls a wepoll function and results in `io::Result`. -macro_rules! wepoll { - ($fn:ident $args:tt) => {{ - let res = unsafe { we::$fn $args }; - if res == -1 { - Err(std::io::Error::last_os_error()) - } else { - Ok(res) - } - }}; -} - -/// Interface to wepoll. -#[derive(Debug)] -pub struct Poller { - handle: we::HANDLE, - notified: AtomicBool, -} - -unsafe impl Send for Poller {} -unsafe impl Sync for Poller {} - -impl Poller { - /// Creates a new poller. - pub fn new() -> io::Result { - let handle = unsafe { we::epoll_create1(0) }; - if handle.is_null() { - return Err(crate::unsupported_error( - format!( - "Failed to initialize Wepoll: {}\nThis usually only happens for old Windows or Wine.", - io::Error::last_os_error() - ) - )); - } - let notified = AtomicBool::new(false); - log::trace!("new: handle={:?}", handle); - Ok(Poller { handle, notified }) - } - - /// Whether this poller supports level-triggered events. - pub fn supports_level(&self) -> bool { - true - } - - /// Whether this poller supports edge-triggered events. - pub fn supports_edge(&self) -> bool { - false - } - - /// Adds a socket. - pub fn add(&self, sock: RawSocket, ev: Event, mode: PollMode) -> io::Result<()> { - log::trace!("add: handle={:?}, sock={}, ev={:?}", self.handle, sock, ev); - self.ctl(we::EPOLL_CTL_ADD, sock, Some((ev, mode))) - } - - /// Modifies a socket. - pub fn modify(&self, sock: RawSocket, ev: Event, mode: PollMode) -> io::Result<()> { - log::trace!( - "modify: handle={:?}, sock={}, ev={:?}", - self.handle, - sock, - ev - ); - self.ctl(we::EPOLL_CTL_MOD, sock, Some((ev, mode))) - } - - /// Deletes a socket. - pub fn delete(&self, sock: RawSocket) -> io::Result<()> { - log::trace!("remove: handle={:?}, sock={}", self.handle, sock); - self.ctl(we::EPOLL_CTL_DEL, sock, None) - } - - /// Waits for I/O events with an optional timeout. - /// - /// Returns the number of processed I/O events. - /// - /// If a notification occurs, this method will return but the notification event will not be - /// included in the `events` list nor contribute to the returned count. - pub fn wait(&self, events: &mut Events, timeout: Option) -> io::Result<()> { - log::trace!("wait: handle={:?}, timeout={:?}", self.handle, timeout); - let deadline = timeout.and_then(|t| Instant::now().checked_add(t)); - - loop { - // Convert the timeout to milliseconds. - let timeout_ms = match deadline.map(|d| d.saturating_duration_since(Instant::now())) { - None => -1, - Some(t) => { - // Round up to a whole millisecond. - let mut ms = t.as_millis().try_into().unwrap_or(std::u64::MAX); - if Duration::from_millis(ms) < t { - ms = ms.saturating_add(1); - } - ms.try_into().unwrap_or(std::i32::MAX) - } - }; - - // Wait for I/O events. - events.len = wepoll!(epoll_wait( - self.handle, - events.list.as_mut_ptr(), - events.list.len() as c_int, - timeout_ms, - ))? as usize; - log::trace!("new events: handle={:?}, len={}", self.handle, events.len); - - // Break if there was a notification or at least one event, or if deadline is reached. - if self.notified.swap(false, Ordering::SeqCst) || events.len > 0 || timeout_ms == 0 { - break; - } - } - - Ok(()) - } - - /// Sends a notification to wake up the current or next `wait()` call. - pub fn notify(&self) -> io::Result<()> { - log::trace!("notify: handle={:?}", self.handle); - - if self - .notified - .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) - .is_ok() - { - unsafe { - // This call errors if a notification has already been posted, but that's okay - we - // can just ignore the error. - // - // The original wepoll does not support notifications triggered this way, which is - // why wepoll-sys includes a small patch to support them. - windows_sys::Win32::System::IO::PostQueuedCompletionStatus( - self.handle as _, - 0, - 0, - ptr::null_mut(), - ); - } - } - Ok(()) - } - - /// Passes arguments to `epoll_ctl`. - fn ctl(&self, op: u32, sock: RawSocket, ev: Option<(Event, PollMode)>) -> io::Result<()> { - let mut ev = ev - .map(|(ev, mode)| { - let mut flags = match mode { - PollMode::Level => 0, - PollMode::Oneshot => we::EPOLLONESHOT, - PollMode::Edge => { - return Err(crate::unsupported_error( - "edge-triggered events are not supported with wepoll", - )); - } - }; - if ev.readable { - flags |= READ_FLAGS; - } - if ev.writable { - flags |= WRITE_FLAGS; - } - - Ok(we::epoll_event { - events: flags as u32, - data: we::epoll_data { - u64_: ev.key as u64, - }, - }) - }) - .transpose()?; - wepoll!(epoll_ctl( - self.handle, - op as c_int, - sock as we::SOCKET, - ev.as_mut() - .map(|ev| ev as *mut we::epoll_event) - .unwrap_or(ptr::null_mut()), - ))?; - Ok(()) - } -} - -impl AsRawHandle for Poller { - fn as_raw_handle(&self) -> RawHandle { - self.handle as RawHandle - } -} - -#[cfg(not(polling_no_io_safety))] -impl AsHandle for Poller { - fn as_handle(&self) -> BorrowedHandle<'_> { - // SAFETY: lifetime is bound by "self" - unsafe { BorrowedHandle::borrow_raw(self.as_raw_handle()) } - } -} - -impl Drop for Poller { - fn drop(&mut self) { - log::trace!("drop: handle={:?}", self.handle); - unsafe { - we::epoll_close(self.handle); - } - } -} - -/// Wepoll flags for all possible readability events. -const READ_FLAGS: u32 = we::EPOLLIN | we::EPOLLRDHUP | we::EPOLLHUP | we::EPOLLERR | we::EPOLLPRI; - -/// Wepoll flags for all possible writability events. -const WRITE_FLAGS: u32 = we::EPOLLOUT | we::EPOLLHUP | we::EPOLLERR; - -/// A list of reported I/O events. -pub struct Events { - list: Box<[we::epoll_event; 1024]>, - len: usize, -} - -unsafe impl Send for Events {} - -impl Events { - /// Creates an empty list. - pub fn new() -> Events { - let ev = we::epoll_event { - events: 0, - data: we::epoll_data { u64_: 0 }, - }; - let list = Box::new([ev; 1024]); - Events { list, len: 0 } - } - - /// Iterates over I/O events. - pub fn iter(&self) -> impl Iterator + '_ { - self.list[..self.len].iter().map(|ev| Event { - key: unsafe { ev.data.u64_ } as usize, - readable: (ev.events & READ_FLAGS) != 0, - writable: (ev.events & WRITE_FLAGS) != 0, - }) - } -} diff --git a/tests/concurrent_modification.rs b/tests/concurrent_modification.rs index 0687ad5..7f31f05 100644 --- a/tests/concurrent_modification.rs +++ b/tests/concurrent_modification.rs @@ -43,7 +43,7 @@ fn concurrent_modify() -> io::Result<()> { Parallel::new() .add(|| { - poller.wait(&mut events, None)?; + poller.wait(&mut events, Some(Duration::from_secs(10)))?; Ok(()) }) .add(|| { diff --git a/tests/io.rs b/tests/io.rs new file mode 100644 index 0000000..ab0c8a8 --- /dev/null +++ b/tests/io.rs @@ -0,0 +1,38 @@ +use polling::{Event, Poller}; +use std::io::{self, Write}; +use std::net::{TcpListener, TcpStream}; +use std::time::Duration; + +#[test] +fn basic_io() { + let poller = Poller::new().unwrap(); + let (read, mut write) = tcp_pair().unwrap(); + poller.add(&read, Event::readable(1)).unwrap(); + + // Nothing should be available at first. + let mut events = vec![]; + assert_eq!( + poller + .wait(&mut events, Some(Duration::from_secs(0))) + .unwrap(), + 0 + ); + assert!(events.is_empty()); + + // After a write, the event should be available now. + write.write_all(&[1]).unwrap(); + assert_eq!( + poller + .wait(&mut events, Some(Duration::from_secs(1))) + .unwrap(), + 1 + ); + assert_eq!(&*events, &[Event::readable(1)]); +} + +fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> { + let listener = TcpListener::bind("127.0.0.1:0")?; + let a = TcpStream::connect(listener.local_addr()?)?; + let (b, _) = listener.accept()?; + Ok((a, b)) +} diff --git a/tests/precision.rs b/tests/precision.rs index d29bbce..de5d605 100644 --- a/tests/precision.rs +++ b/tests/precision.rs @@ -18,7 +18,7 @@ fn below_ms() -> io::Result<()> { let elapsed = now.elapsed(); assert_eq!(n, 0); - assert!(elapsed >= dur); + assert!(elapsed >= dur, "{:?} < {:?}", elapsed, dur); lowest = lowest.min(elapsed); } @@ -54,7 +54,7 @@ fn above_ms() -> io::Result<()> { let elapsed = now.elapsed(); assert_eq!(n, 0); - assert!(elapsed >= dur); + assert!(elapsed >= dur, "{:?} < {:?}", elapsed, dur); lowest = lowest.min(elapsed); }