//! A safe wrapper around the Windows I/O API. use std::convert::{TryFrom, TryInto}; use std::fmt; use std::io; use std::marker::PhantomData; use std::mem::MaybeUninit; use std::ops::Deref; use std::os::windows::io::{AsRawHandle, RawHandle}; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; use windows_sys::Win32::Foundation::{CloseHandle, HANDLE, INVALID_HANDLE_VALUE}; use windows_sys::Win32::Storage::FileSystem::SetFileCompletionNotificationModes; use windows_sys::Win32::System::WindowsProgramming::{FILE_SKIP_SET_EVENT_ON_HANDLE, INFINITE}; use windows_sys::Win32::System::IO::{ CreateIoCompletionPort, GetQueuedCompletionStatusEx, PostQueuedCompletionStatus, OVERLAPPED, OVERLAPPED_ENTRY, }; /// A completion block which can be used with I/O completion ports. /// /// # Safety /// /// This must be a valid completion block. pub(super) unsafe trait Completion { /// Signal to the completion block that we are about to start an operation. fn try_lock(self: Pin<&Self>) -> bool; /// Unlock the completion block. unsafe fn unlock(self: Pin<&Self>); } /// The pointer to a completion block. /// /// # Safety /// /// This must be a valid completion block. pub(super) unsafe trait CompletionHandle: Deref + Sized { /// Type of the completion block. type Completion: Completion; /// Get a pointer to the completion block. /// /// The pointer is pinned since the underlying object should not be moved /// after creation. This prevents it from being invalidated while it's /// used in an overlapped operation. fn get(&self) -> Pin<&Self::Completion>; /// Convert this block into a pointer that can be passed as `*mut OVERLAPPED`. fn into_ptr(this: Self) -> *mut OVERLAPPED; /// Convert a pointer that was passed as `*mut OVERLAPPED` into a pointer to this block. /// /// # Safety /// /// This must be a valid pointer to a completion block. unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self; /// Convert to a pointer without losing ownership. fn as_ptr(&self) -> *mut OVERLAPPED; } unsafe impl<'a, T: Completion> CompletionHandle for Pin<&'a T> { type Completion = T; fn get(&self) -> Pin<&Self::Completion> { *self } fn into_ptr(this: Self) -> *mut OVERLAPPED { unsafe { Pin::into_inner_unchecked(this) as *const T as *mut OVERLAPPED } } unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self { Pin::new_unchecked(&*(ptr as *const T)) } fn as_ptr(&self) -> *mut OVERLAPPED { self.get_ref() as *const T as *mut OVERLAPPED } } unsafe impl CompletionHandle for Pin> { type Completion = T; fn get(&self) -> Pin<&Self::Completion> { self.as_ref() } fn into_ptr(this: Self) -> *mut OVERLAPPED { unsafe { Arc::into_raw(Pin::into_inner_unchecked(this)) as *const T as *mut OVERLAPPED } } unsafe fn from_ptr(ptr: *mut OVERLAPPED) -> Self { Pin::new_unchecked(Arc::from_raw(ptr as *const T)) } fn as_ptr(&self) -> *mut OVERLAPPED { self.as_ref().get_ref() as *const T as *mut OVERLAPPED } } /// A handle to the I/O completion port. pub(super) struct IoCompletionPort { /// The underlying handle. handle: HANDLE, /// We own the status block. _marker: PhantomData, } impl Drop for IoCompletionPort { fn drop(&mut self) { unsafe { CloseHandle(self.handle); } } } impl AsRawHandle for IoCompletionPort { fn as_raw_handle(&self) -> RawHandle { self.handle as _ } } impl fmt::Debug for IoCompletionPort { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { struct WriteAsHex(HANDLE); impl fmt::Debug for WriteAsHex { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{:010x}", self.0) } } f.debug_struct("IoCompletionPort") .field("handle", &WriteAsHex(self.handle)) .finish() } } impl IoCompletionPort { /// Create a new I/O completion port. pub(super) fn new(threads: usize) -> io::Result { let handle = unsafe { CreateIoCompletionPort( INVALID_HANDLE_VALUE, 0, 0, threads.try_into().expect("too many threads"), ) }; if handle == 0 { Err(io::Error::last_os_error()) } else { Ok(Self { handle, _marker: PhantomData, }) } } /// Register a handle with this I/O completion port. pub(super) fn register( &self, handle: &impl AsRawHandle, // TODO change to AsHandle skip_set_event_on_handle: bool, ) -> io::Result<()> { let handle = handle.as_raw_handle(); let result = unsafe { CreateIoCompletionPort(handle as _, self.handle, handle as usize, 0) }; if result == 0 { return Err(io::Error::last_os_error()); } if skip_set_event_on_handle { // Set the skip event on handle. let result = unsafe { SetFileCompletionNotificationModes(handle as _, FILE_SKIP_SET_EVENT_ON_HANDLE as _) }; if result == 0 { return Err(io::Error::last_os_error()); } } Ok(()) } /// Post a completion packet to this port. pub(super) fn post(&self, bytes_transferred: usize, id: usize, packet: T) -> io::Result<()> { let result = unsafe { PostQueuedCompletionStatus( self.handle, bytes_transferred .try_into() .expect("too many bytes transferred"), id, T::into_ptr(packet), ) }; if result == 0 { Err(io::Error::last_os_error()) } else { Ok(()) } } /// Wait for completion packets to arrive. pub(super) fn wait( &self, packets: &mut Vec>, timeout: Option, ) -> io::Result { // Drop the current packets. packets.clear(); let mut count = MaybeUninit::::uninit(); let timeout = timeout.map_or(INFINITE, dur2timeout); let result = unsafe { GetQueuedCompletionStatusEx( self.handle, packets.as_mut_ptr() as _, packets.capacity().try_into().expect("too many packets"), count.as_mut_ptr(), timeout, 0, ) }; if result == 0 { let io_error = io::Error::last_os_error(); if io_error.kind() == io::ErrorKind::TimedOut { Ok(0) } else { Err(io_error) } } else { let count = unsafe { count.assume_init() }; unsafe { packets.set_len(count as _); } Ok(count as _) } } } /// An `OVERLAPPED_ENTRY` resulting from an I/O completion port. #[repr(transparent)] pub(super) struct OverlappedEntry { /// The underlying entry. entry: OVERLAPPED_ENTRY, /// We own the status block. _marker: PhantomData, } impl fmt::Debug for OverlappedEntry { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("OverlappedEntry { .. }") } } impl OverlappedEntry { /// Convert into the completion packet. pub(super) fn into_packet(self) -> T { let packet = unsafe { self.packet() }; std::mem::forget(self); packet } /// Get the packet reference that this entry refers to. /// /// # Safety /// /// This function should only be called once, since it moves /// out the `T` from the `OVERLAPPED_ENTRY`. unsafe fn packet(&self) -> T { let packet = T::from_ptr(self.entry.lpOverlapped); packet.get().unlock(); packet } } impl Drop for OverlappedEntry { fn drop(&mut self) { drop(unsafe { self.packet() }); } } // Implementation taken from https://github.com/rust-lang/rust/blob/db5476571d9b27c862b95c1e64764b0ac8980e23/src/libstd/sys/windows/mod.rs fn dur2timeout(dur: Duration) -> u32 { // Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the // timeouts in windows APIs are typically u32 milliseconds. To translate, we // have two pieces to take care of: // // * Nanosecond precision is rounded up // * Greater than u32::MAX milliseconds (50 days) is rounded up to INFINITE // (never time out). dur.as_secs() .checked_mul(1000) .and_then(|ms| ms.checked_add((dur.subsec_nanos() as u64) / 1_000_000)) .and_then(|ms| { if dur.subsec_nanos() % 1_000_000 > 0 { ms.checked_add(1) } else { Some(ms) } }) .and_then(|x| u32::try_from(x).ok()) .unwrap_or(INFINITE) } struct CallOnDrop(F); impl Drop for CallOnDrop { fn drop(&mut self) { (self.0)(); } }