diff --git a/src/iocp/mod.rs b/src/iocp/mod.rs index 4ada533..447c466 100644 --- a/src/iocp/mod.rs +++ b/src/iocp/mod.rs @@ -323,6 +323,11 @@ impl Poller { self.port.post(0, 0, self.notifier.clone()) } + /// Push an IOCP packet into the queue. + pub(super) fn post(&self, packet: CompletionPacket) -> io::Result<()> { + self.port.post(0, 0, packet.0) + } + /// Run an update on a packet. fn update_packet(&self, mut packet: Packet) -> io::Result<()> { loop { @@ -443,6 +448,27 @@ impl Events { } } +/// A packet used to wake up the poller with an event. +#[derive(Debug, Clone)] +pub struct CompletionPacket(Packet); + +impl CompletionPacket { + /// Create a new completion packet with a custom event. + pub fn new(event: Event) -> Self { + Self(Arc::pin(IoStatusBlock::from(PacketInner::Custom { event }))) + } + + /// Get the event associated with this packet. + pub fn event(&self) -> &Event { + let data = self.0.as_ref().data().project_ref(); + + match data { + PacketInnerProj::Custom { event } => event, + _ => unreachable!(), + } + } +} + /// The type of our completion packet. type Packet = Pin>; type PacketUnwrapped = IoStatusBlock; @@ -462,6 +488,11 @@ pin_project! { socket: Mutex }, + /// A custom event sent by the user. + Custom { + event: Event, + }, + // A packet used to wake up the poller. Wakeup { #[pin] _pinned: PhantomPinned }, } @@ -471,6 +502,7 @@ impl fmt::Debug for PacketInner { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Wakeup { .. } => f.write_str("Wakeup { .. }"), + Self::Custom { event } => f.debug_struct("Custom").field("event", event).finish(), Self::Socket { socket, .. } => f .debug_struct("Socket") .field("packet", &"..") @@ -484,7 +516,7 @@ impl HasAfdInfo for PacketInner { fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell> { match self.project_ref() { PacketInnerProj::Socket { packet, .. } => packet, - PacketInnerProj::Wakeup { .. } => unreachable!(), + _ => unreachable!(), } } } @@ -591,6 +623,10 @@ impl PacketUnwrapped { let (afd_info, socket) = match inner { PacketInnerProj::Socket { packet, socket } => (packet, socket), + PacketInnerProj::Custom { event } => { + // This is a custom event. + return Ok(FeedEventResult::Event(*event)); + } PacketInnerProj::Wakeup { .. } => { // The poller was notified. return Ok(FeedEventResult::Notified); @@ -712,8 +748,8 @@ impl PacketUnwrapped { let inner = self.data().project_ref(); let state = match inner { - PacketInnerProj::Wakeup { .. } => return None, PacketInnerProj::Socket { socket, .. } => socket, + _ => return None, }; Some(lock!(state.lock())) diff --git a/src/os.rs b/src/os.rs index 280adf3..2a5d6e6 100644 --- a/src/os.rs +++ b/src/os.rs @@ -15,7 +15,12 @@ ))] pub mod kqueue; +#[cfg(target_os = "windows")] +pub mod iocp; + mod __private { #[doc(hidden)] pub trait PollerSealed {} + + impl PollerSealed for crate::Poller {} } diff --git a/src/os/iocp.rs b/src/os/iocp.rs new file mode 100644 index 0000000..edaf00f --- /dev/null +++ b/src/os/iocp.rs @@ -0,0 +1,52 @@ +//! Functionality that is only availale for IOCP-based platforms. + +pub use crate::sys::CompletionPacket; + +use super::__private::PollerSealed; +use crate::Poller; +use std::io; + +/// Extension trait for the [`Poller`] type that provides functionality specific to IOCP-based +/// platforms. +/// +/// [`Poller`]: crate::Poller +pub trait PollerIocpExt: PollerSealed { + /// Post a new [`Event`] to the poller. + /// + /// # Examples + /// + /// ```rust + /// use polling::{Poller, Event}; + /// use polling::os::iocp::{CompletionPacket, PollerIocpExt}; + /// + /// use std::thread; + /// use std::sync::Arc; + /// use std::time::Duration; + /// + /// # fn main() -> std::io::Result<()> { + /// // Spawn a thread to wake us up after 100ms. + /// let poller = Arc::new(Poller::new()?); + /// thread::spawn({ + /// let poller = poller.clone(); + /// move || { + /// let packet = CompletionPacket::new(Event::readable(0)); + /// thread::sleep(Duration::from_millis(100)); + /// poller.post(packet).unwrap(); + /// } + /// }); + /// + /// // Wait for the event. + /// let mut events = vec![]; + /// poller.wait(&mut events, None)?; + /// + /// assert_eq!(events.len(), 1); + /// # Ok(()) } + /// ``` + fn post(&self, packet: CompletionPacket) -> io::Result<()>; +} + +impl PollerIocpExt for Poller { + fn post(&self, packet: CompletionPacket) -> io::Result<()> { + self.poller.post(packet) + } +} diff --git a/src/os/kqueue.rs b/src/os/kqueue.rs index ad6aa1d..9b399da 100644 --- a/src/os/kqueue.rs +++ b/src/os/kqueue.rs @@ -93,8 +93,6 @@ pub trait PollerKqueueExt: PollerSealed { fn delete_filter(&self, filter: F) -> io::Result<()>; } -impl PollerSealed for Poller {} - impl PollerKqueueExt for Poller { #[inline(always)] fn add_filter(&self, filter: F, key: usize, mode: PollMode) -> io::Result<()> { diff --git a/tests/windows_post.rs b/tests/windows_post.rs new file mode 100644 index 0000000..488fab3 --- /dev/null +++ b/tests/windows_post.rs @@ -0,0 +1,57 @@ +//! Tests for the post() function on Windows. + +#![cfg(windows)] + +use polling::os::iocp::{CompletionPacket, PollerIocpExt}; +use polling::{Event, Poller}; + +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +#[test] +fn post_smoke() { + let poller = Poller::new().unwrap(); + let mut events = Vec::new(); + + poller + .post(CompletionPacket::new(Event::readable(1))) + .unwrap(); + poller.wait(&mut events, None).unwrap(); + + assert_eq!(events.len(), 1); + assert_eq!(events[0], Event::readable(1)); +} + +#[test] +fn post_multithread() { + let poller = Arc::new(Poller::new().unwrap()); + let mut events = Vec::new(); + + thread::spawn({ + let poller = Arc::clone(&poller); + move || { + for i in 0..3 { + poller + .post(CompletionPacket::new(Event::writable(i))) + .unwrap(); + + thread::sleep(Duration::from_millis(100)); + } + } + }); + + for i in 0..3 { + poller + .wait(&mut events, Some(Duration::from_secs(5))) + .unwrap(); + + assert_eq!(events.len(), 1); + assert_eq!(events.pop(), Some(Event::writable(i))); + } + + poller + .wait(&mut events, Some(Duration::from_millis(10))) + .unwrap(); + assert_eq!(events.len(), 0); +}