mirror of https://github.com/smol-rs/polling
bugfix: Manage sources being inserted into kqueue
Thus far, our kqueue implementation has been a relatively thin layer on top of the OS kqueue. However, kqueue doesn't keep track of when the same source is inserted twice, or when a source that doesn't exist is removed. In the interest of keeping consistent behavior between backends this commit adds a system for tracking when sources are inserted. Closes #151 Signed-off-by: John Nunley <dev@notgull.net>
This commit is contained in:
parent
45ebe3b904
commit
9e143a38e1
|
@ -1,7 +1,9 @@
|
|||
//! Bindings to kqueue (macOS, iOS, tvOS, watchOS, FreeBSD, NetBSD, OpenBSD, DragonFly BSD).
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::io;
|
||||
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
|
||||
use std::sync::RwLock;
|
||||
use std::time::Duration;
|
||||
|
||||
use rustix::event::kqueue;
|
||||
|
@ -15,6 +17,11 @@ pub struct Poller {
|
|||
/// File descriptor for the kqueue instance.
|
||||
kqueue_fd: OwnedFd,
|
||||
|
||||
/// List of sources currently registered in this poller.
|
||||
///
|
||||
/// This is used to make sure the same source is not registered twice.
|
||||
sources: RwLock<HashSet<SourceId>>,
|
||||
|
||||
/// Notification pipe for waking up the poller.
|
||||
///
|
||||
/// On platforms that support `EVFILT_USER`, this uses that to wake up the poller. Otherwise, it
|
||||
|
@ -22,6 +29,23 @@ pub struct Poller {
|
|||
notify: notify::Notify,
|
||||
}
|
||||
|
||||
/// Identifier for a source.
|
||||
#[doc(hidden)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum SourceId {
|
||||
/// Registered file descriptor.
|
||||
Fd(RawFd),
|
||||
|
||||
/// Signal.
|
||||
Signal(std::os::raw::c_int),
|
||||
|
||||
/// Process ID.
|
||||
Pid(rustix::process::Pid),
|
||||
|
||||
/// Timer ID.
|
||||
Timer(usize),
|
||||
}
|
||||
|
||||
impl Poller {
|
||||
/// Creates a new poller.
|
||||
pub fn new() -> io::Result<Poller> {
|
||||
|
@ -31,6 +55,7 @@ impl Poller {
|
|||
|
||||
let poller = Poller {
|
||||
kqueue_fd,
|
||||
sources: RwLock::new(HashSet::new()),
|
||||
notify: notify::Notify::new()?,
|
||||
};
|
||||
|
||||
|
@ -60,6 +85,8 @@ impl Poller {
|
|||
///
|
||||
/// The file descriptor must be valid and it must last until it is deleted.
|
||||
pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
|
||||
self.add_source(SourceId::Fd(fd))?;
|
||||
|
||||
// File descriptors don't need to be added explicitly, so just modify the interest.
|
||||
self.modify(BorrowedFd::borrow_raw(fd), ev, mode)
|
||||
}
|
||||
|
@ -79,6 +106,8 @@ impl Poller {
|
|||
};
|
||||
let _enter = span.as_ref().map(|s| s.enter());
|
||||
|
||||
self.has_source(SourceId::Fd(fd.as_raw_fd()))?;
|
||||
|
||||
let mode_flags = mode_to_flags(mode);
|
||||
|
||||
let read_flags = if ev.readable {
|
||||
|
@ -143,10 +172,57 @@ impl Poller {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Add a source to the sources set.
|
||||
#[inline]
|
||||
pub(crate) fn add_source(&self, source: SourceId) -> io::Result<()> {
|
||||
if self
|
||||
.sources
|
||||
.write()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.insert(source)
|
||||
{
|
||||
Ok(())
|
||||
} else {
|
||||
Err(io::Error::from(io::ErrorKind::AlreadyExists))
|
||||
}
|
||||
}
|
||||
|
||||
/// Tell if a source is currently inside the set.
|
||||
#[inline]
|
||||
pub(crate) fn has_source(&self, source: SourceId) -> io::Result<()> {
|
||||
if self
|
||||
.sources
|
||||
.read()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.contains(&source)
|
||||
{
|
||||
Ok(())
|
||||
} else {
|
||||
Err(io::Error::from(io::ErrorKind::NotFound))
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a source from the sources set.
|
||||
#[inline]
|
||||
pub(crate) fn remove_source(&self, source: SourceId) -> io::Result<()> {
|
||||
if self
|
||||
.sources
|
||||
.write()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.remove(&source)
|
||||
{
|
||||
Ok(())
|
||||
} else {
|
||||
Err(io::Error::from(io::ErrorKind::NotFound))
|
||||
}
|
||||
}
|
||||
|
||||
/// Deletes a file descriptor.
|
||||
pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> {
|
||||
// Simply delete interest in the file descriptor.
|
||||
self.modify(fd, Event::none(0), PollMode::Oneshot)
|
||||
self.modify(fd, Event::none(0), PollMode::Oneshot)?;
|
||||
|
||||
self.remove_source(SourceId::Fd(fd.as_raw_fd()))
|
||||
}
|
||||
|
||||
/// Waits for I/O events with an optional timeout.
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
//! Functionality that is only available for `kqueue`-based platforms.
|
||||
|
||||
use crate::sys::mode_to_flags;
|
||||
use crate::sys::{mode_to_flags, SourceId};
|
||||
use crate::{PollMode, Poller};
|
||||
|
||||
use std::io;
|
||||
|
@ -98,10 +98,13 @@ impl<F: Filter> PollerKqueueExt<F> for Poller {
|
|||
#[inline(always)]
|
||||
fn add_filter(&self, filter: F, key: usize, mode: PollMode) -> io::Result<()> {
|
||||
// No difference between adding and modifying in kqueue.
|
||||
self.poller.add_source(filter.source_id())?;
|
||||
self.modify_filter(filter, key, mode)
|
||||
}
|
||||
|
||||
fn modify_filter(&self, filter: F, key: usize, mode: PollMode) -> io::Result<()> {
|
||||
self.poller.has_source(filter.source_id())?;
|
||||
|
||||
// Convert the filter into a kevent.
|
||||
let event = filter.filter(kqueue::EventFlags::ADD | mode_to_flags(mode), key);
|
||||
|
||||
|
@ -114,7 +117,9 @@ impl<F: Filter> PollerKqueueExt<F> for Poller {
|
|||
let event = filter.filter(kqueue::EventFlags::DELETE, 0);
|
||||
|
||||
// Delete the filter.
|
||||
self.poller.submit_changes([event])
|
||||
self.poller.submit_changes([event])?;
|
||||
|
||||
self.poller.remove_source(filter.source_id())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -126,6 +131,11 @@ unsafe impl<T: FilterSealed + ?Sized> FilterSealed for &T {
|
|||
fn filter(&self, flags: kqueue::EventFlags, key: usize) -> kqueue::Event {
|
||||
(**self).filter(flags, key)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn source_id(&self) -> SourceId {
|
||||
(**self).source_id()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Filter + ?Sized> Filter for &T {}
|
||||
|
@ -149,6 +159,11 @@ unsafe impl FilterSealed for Signal {
|
|||
key as _,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn source_id(&self) -> SourceId {
|
||||
SourceId::Signal(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Filter for Signal {}
|
||||
|
@ -207,6 +222,11 @@ unsafe impl FilterSealed for Process<'_> {
|
|||
key as _,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn source_id(&self) -> SourceId {
|
||||
SourceId::Pid(rustix::process::Pid::from_child(self.child))
|
||||
}
|
||||
}
|
||||
|
||||
impl Filter for Process<'_> {}
|
||||
|
@ -234,11 +254,17 @@ unsafe impl FilterSealed for Timer {
|
|||
key as _,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn source_id(&self) -> SourceId {
|
||||
SourceId::Timer(self.id)
|
||||
}
|
||||
}
|
||||
|
||||
impl Filter for Timer {}
|
||||
|
||||
mod __private {
|
||||
use crate::sys::SourceId;
|
||||
use rustix::event::kqueue;
|
||||
|
||||
#[doc(hidden)]
|
||||
|
@ -247,5 +273,8 @@ mod __private {
|
|||
///
|
||||
/// This filter's flags must have `EV_RECEIPT`.
|
||||
fn filter(&self, flags: kqueue::EventFlags, key: usize) -> kqueue::Event;
|
||||
|
||||
/// Get the source ID for this source.
|
||||
fn source_id(&self) -> SourceId;
|
||||
}
|
||||
}
|
||||
|
|
43
tests/io.rs
43
tests/io.rs
|
@ -1,6 +1,7 @@
|
|||
use polling::{Event, Events, Poller};
|
||||
use std::io::{self, Write};
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[test]
|
||||
|
@ -38,6 +39,48 @@ fn basic_io() {
|
|||
poller.delete(&read).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn insert_twice() {
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::io::AsRawFd;
|
||||
#[cfg(windows)]
|
||||
use std::os::windows::io::AsRawSocket;
|
||||
|
||||
let (read, mut write) = tcp_pair().unwrap();
|
||||
let read = Arc::new(read);
|
||||
|
||||
let poller = Poller::new().unwrap();
|
||||
unsafe {
|
||||
#[cfg(unix)]
|
||||
let read = read.as_raw_fd();
|
||||
#[cfg(windows)]
|
||||
let read = read.as_raw_socket();
|
||||
|
||||
poller.add(read, Event::readable(1)).unwrap();
|
||||
assert_eq!(
|
||||
poller.add(read, Event::readable(1)).unwrap_err().kind(),
|
||||
io::ErrorKind::AlreadyExists
|
||||
);
|
||||
}
|
||||
|
||||
write.write_all(&[1]).unwrap();
|
||||
let mut events = Events::new();
|
||||
assert_eq!(
|
||||
poller
|
||||
.wait(&mut events, Some(Duration::from_secs(1)))
|
||||
.unwrap(),
|
||||
1
|
||||
);
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
assert_eq!(
|
||||
events.iter().next().unwrap().with_no_extra(),
|
||||
Event::readable(1)
|
||||
);
|
||||
|
||||
poller.delete(&read).unwrap();
|
||||
}
|
||||
|
||||
fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")?;
|
||||
let a = TcpStream::connect(listener.local_addr()?)?;
|
||||
|
|
Loading…
Reference in New Issue