bugfix: Manage sources being inserted into kqueue

Thus far, our kqueue implementation has been a relatively thin layer on
top of the OS kqueue. However, kqueue doesn't keep track of when the
same source is inserted twice, or when a source that doesn't exist is
removed. In the interest of keeping consistent behavior between backends
this commit adds a system for tracking when sources are inserted.

Closes #151

Signed-off-by: John Nunley <dev@notgull.net>
This commit is contained in:
John Nunley 2023-09-27 21:30:46 -07:00 committed by GitHub
parent 45ebe3b904
commit 9e143a38e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 151 additions and 3 deletions

View File

@ -1,7 +1,9 @@
//! Bindings to kqueue (macOS, iOS, tvOS, watchOS, FreeBSD, NetBSD, OpenBSD, DragonFly BSD).
use std::collections::HashSet;
use std::io;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
use std::sync::RwLock;
use std::time::Duration;
use rustix::event::kqueue;
@ -15,6 +17,11 @@ pub struct Poller {
/// File descriptor for the kqueue instance.
kqueue_fd: OwnedFd,
/// List of sources currently registered in this poller.
///
/// This is used to make sure the same source is not registered twice.
sources: RwLock<HashSet<SourceId>>,
/// Notification pipe for waking up the poller.
///
/// On platforms that support `EVFILT_USER`, this uses that to wake up the poller. Otherwise, it
@ -22,6 +29,23 @@ pub struct Poller {
notify: notify::Notify,
}
/// Identifier for a source.
#[doc(hidden)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum SourceId {
/// Registered file descriptor.
Fd(RawFd),
/// Signal.
Signal(std::os::raw::c_int),
/// Process ID.
Pid(rustix::process::Pid),
/// Timer ID.
Timer(usize),
}
impl Poller {
/// Creates a new poller.
pub fn new() -> io::Result<Poller> {
@ -31,6 +55,7 @@ impl Poller {
let poller = Poller {
kqueue_fd,
sources: RwLock::new(HashSet::new()),
notify: notify::Notify::new()?,
};
@ -60,6 +85,8 @@ impl Poller {
///
/// The file descriptor must be valid and it must last until it is deleted.
pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
self.add_source(SourceId::Fd(fd))?;
// File descriptors don't need to be added explicitly, so just modify the interest.
self.modify(BorrowedFd::borrow_raw(fd), ev, mode)
}
@ -79,6 +106,8 @@ impl Poller {
};
let _enter = span.as_ref().map(|s| s.enter());
self.has_source(SourceId::Fd(fd.as_raw_fd()))?;
let mode_flags = mode_to_flags(mode);
let read_flags = if ev.readable {
@ -143,10 +172,57 @@ impl Poller {
Ok(())
}
/// Add a source to the sources set.
#[inline]
pub(crate) fn add_source(&self, source: SourceId) -> io::Result<()> {
if self
.sources
.write()
.unwrap_or_else(|e| e.into_inner())
.insert(source)
{
Ok(())
} else {
Err(io::Error::from(io::ErrorKind::AlreadyExists))
}
}
/// Tell if a source is currently inside the set.
#[inline]
pub(crate) fn has_source(&self, source: SourceId) -> io::Result<()> {
if self
.sources
.read()
.unwrap_or_else(|e| e.into_inner())
.contains(&source)
{
Ok(())
} else {
Err(io::Error::from(io::ErrorKind::NotFound))
}
}
/// Remove a source from the sources set.
#[inline]
pub(crate) fn remove_source(&self, source: SourceId) -> io::Result<()> {
if self
.sources
.write()
.unwrap_or_else(|e| e.into_inner())
.remove(&source)
{
Ok(())
} else {
Err(io::Error::from(io::ErrorKind::NotFound))
}
}
/// Deletes a file descriptor.
pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> {
// Simply delete interest in the file descriptor.
self.modify(fd, Event::none(0), PollMode::Oneshot)
self.modify(fd, Event::none(0), PollMode::Oneshot)?;
self.remove_source(SourceId::Fd(fd.as_raw_fd()))
}
/// Waits for I/O events with an optional timeout.

View File

@ -1,6 +1,6 @@
//! Functionality that is only available for `kqueue`-based platforms.
use crate::sys::mode_to_flags;
use crate::sys::{mode_to_flags, SourceId};
use crate::{PollMode, Poller};
use std::io;
@ -98,10 +98,13 @@ impl<F: Filter> PollerKqueueExt<F> for Poller {
#[inline(always)]
fn add_filter(&self, filter: F, key: usize, mode: PollMode) -> io::Result<()> {
// No difference between adding and modifying in kqueue.
self.poller.add_source(filter.source_id())?;
self.modify_filter(filter, key, mode)
}
fn modify_filter(&self, filter: F, key: usize, mode: PollMode) -> io::Result<()> {
self.poller.has_source(filter.source_id())?;
// Convert the filter into a kevent.
let event = filter.filter(kqueue::EventFlags::ADD | mode_to_flags(mode), key);
@ -114,7 +117,9 @@ impl<F: Filter> PollerKqueueExt<F> for Poller {
let event = filter.filter(kqueue::EventFlags::DELETE, 0);
// Delete the filter.
self.poller.submit_changes([event])
self.poller.submit_changes([event])?;
self.poller.remove_source(filter.source_id())
}
}
@ -126,6 +131,11 @@ unsafe impl<T: FilterSealed + ?Sized> FilterSealed for &T {
fn filter(&self, flags: kqueue::EventFlags, key: usize) -> kqueue::Event {
(**self).filter(flags, key)
}
#[inline(always)]
fn source_id(&self) -> SourceId {
(**self).source_id()
}
}
impl<T: Filter + ?Sized> Filter for &T {}
@ -149,6 +159,11 @@ unsafe impl FilterSealed for Signal {
key as _,
)
}
#[inline(always)]
fn source_id(&self) -> SourceId {
SourceId::Signal(self.0)
}
}
impl Filter for Signal {}
@ -207,6 +222,11 @@ unsafe impl FilterSealed for Process<'_> {
key as _,
)
}
#[inline(always)]
fn source_id(&self) -> SourceId {
SourceId::Pid(rustix::process::Pid::from_child(self.child))
}
}
impl Filter for Process<'_> {}
@ -234,11 +254,17 @@ unsafe impl FilterSealed for Timer {
key as _,
)
}
#[inline(always)]
fn source_id(&self) -> SourceId {
SourceId::Timer(self.id)
}
}
impl Filter for Timer {}
mod __private {
use crate::sys::SourceId;
use rustix::event::kqueue;
#[doc(hidden)]
@ -247,5 +273,8 @@ mod __private {
///
/// This filter's flags must have `EV_RECEIPT`.
fn filter(&self, flags: kqueue::EventFlags, key: usize) -> kqueue::Event;
/// Get the source ID for this source.
fn source_id(&self) -> SourceId;
}
}

View File

@ -1,6 +1,7 @@
use polling::{Event, Events, Poller};
use std::io::{self, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::Arc;
use std::time::Duration;
#[test]
@ -38,6 +39,48 @@ fn basic_io() {
poller.delete(&read).unwrap();
}
#[test]
fn insert_twice() {
#[cfg(unix)]
use std::os::unix::io::AsRawFd;
#[cfg(windows)]
use std::os::windows::io::AsRawSocket;
let (read, mut write) = tcp_pair().unwrap();
let read = Arc::new(read);
let poller = Poller::new().unwrap();
unsafe {
#[cfg(unix)]
let read = read.as_raw_fd();
#[cfg(windows)]
let read = read.as_raw_socket();
poller.add(read, Event::readable(1)).unwrap();
assert_eq!(
poller.add(read, Event::readable(1)).unwrap_err().kind(),
io::ErrorKind::AlreadyExists
);
}
write.write_all(&[1]).unwrap();
let mut events = Events::new();
assert_eq!(
poller
.wait(&mut events, Some(Duration::from_secs(1)))
.unwrap(),
1
);
assert_eq!(events.len(), 1);
assert_eq!(
events.iter().next().unwrap().with_no_extra(),
Event::readable(1)
);
poller.delete(&read).unwrap();
}
fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> {
let listener = TcpListener::bind("127.0.0.1:0")?;
let a = TcpStream::connect(listener.local_addr()?)?;