mirror of https://github.com/smol-rs/polling
bugfix: Manage sources being inserted into kqueue
Thus far, our kqueue implementation has been a relatively thin layer on top of the OS kqueue. However, kqueue doesn't keep track of when the same source is inserted twice, or when a source that doesn't exist is removed. In the interest of keeping consistent behavior between backends this commit adds a system for tracking when sources are inserted. Closes #151 Signed-off-by: John Nunley <dev@notgull.net>
This commit is contained in:
parent
45ebe3b904
commit
9e143a38e1
|
@ -1,7 +1,9 @@
|
||||||
//! Bindings to kqueue (macOS, iOS, tvOS, watchOS, FreeBSD, NetBSD, OpenBSD, DragonFly BSD).
|
//! Bindings to kqueue (macOS, iOS, tvOS, watchOS, FreeBSD, NetBSD, OpenBSD, DragonFly BSD).
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
|
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
|
||||||
|
use std::sync::RwLock;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use rustix::event::kqueue;
|
use rustix::event::kqueue;
|
||||||
|
@ -15,6 +17,11 @@ pub struct Poller {
|
||||||
/// File descriptor for the kqueue instance.
|
/// File descriptor for the kqueue instance.
|
||||||
kqueue_fd: OwnedFd,
|
kqueue_fd: OwnedFd,
|
||||||
|
|
||||||
|
/// List of sources currently registered in this poller.
|
||||||
|
///
|
||||||
|
/// This is used to make sure the same source is not registered twice.
|
||||||
|
sources: RwLock<HashSet<SourceId>>,
|
||||||
|
|
||||||
/// Notification pipe for waking up the poller.
|
/// Notification pipe for waking up the poller.
|
||||||
///
|
///
|
||||||
/// On platforms that support `EVFILT_USER`, this uses that to wake up the poller. Otherwise, it
|
/// On platforms that support `EVFILT_USER`, this uses that to wake up the poller. Otherwise, it
|
||||||
|
@ -22,6 +29,23 @@ pub struct Poller {
|
||||||
notify: notify::Notify,
|
notify: notify::Notify,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Identifier for a source.
|
||||||
|
#[doc(hidden)]
|
||||||
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
|
pub enum SourceId {
|
||||||
|
/// Registered file descriptor.
|
||||||
|
Fd(RawFd),
|
||||||
|
|
||||||
|
/// Signal.
|
||||||
|
Signal(std::os::raw::c_int),
|
||||||
|
|
||||||
|
/// Process ID.
|
||||||
|
Pid(rustix::process::Pid),
|
||||||
|
|
||||||
|
/// Timer ID.
|
||||||
|
Timer(usize),
|
||||||
|
}
|
||||||
|
|
||||||
impl Poller {
|
impl Poller {
|
||||||
/// Creates a new poller.
|
/// Creates a new poller.
|
||||||
pub fn new() -> io::Result<Poller> {
|
pub fn new() -> io::Result<Poller> {
|
||||||
|
@ -31,6 +55,7 @@ impl Poller {
|
||||||
|
|
||||||
let poller = Poller {
|
let poller = Poller {
|
||||||
kqueue_fd,
|
kqueue_fd,
|
||||||
|
sources: RwLock::new(HashSet::new()),
|
||||||
notify: notify::Notify::new()?,
|
notify: notify::Notify::new()?,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -60,6 +85,8 @@ impl Poller {
|
||||||
///
|
///
|
||||||
/// The file descriptor must be valid and it must last until it is deleted.
|
/// The file descriptor must be valid and it must last until it is deleted.
|
||||||
pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
|
pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
|
||||||
|
self.add_source(SourceId::Fd(fd))?;
|
||||||
|
|
||||||
// File descriptors don't need to be added explicitly, so just modify the interest.
|
// File descriptors don't need to be added explicitly, so just modify the interest.
|
||||||
self.modify(BorrowedFd::borrow_raw(fd), ev, mode)
|
self.modify(BorrowedFd::borrow_raw(fd), ev, mode)
|
||||||
}
|
}
|
||||||
|
@ -79,6 +106,8 @@ impl Poller {
|
||||||
};
|
};
|
||||||
let _enter = span.as_ref().map(|s| s.enter());
|
let _enter = span.as_ref().map(|s| s.enter());
|
||||||
|
|
||||||
|
self.has_source(SourceId::Fd(fd.as_raw_fd()))?;
|
||||||
|
|
||||||
let mode_flags = mode_to_flags(mode);
|
let mode_flags = mode_to_flags(mode);
|
||||||
|
|
||||||
let read_flags = if ev.readable {
|
let read_flags = if ev.readable {
|
||||||
|
@ -143,10 +172,57 @@ impl Poller {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Add a source to the sources set.
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn add_source(&self, source: SourceId) -> io::Result<()> {
|
||||||
|
if self
|
||||||
|
.sources
|
||||||
|
.write()
|
||||||
|
.unwrap_or_else(|e| e.into_inner())
|
||||||
|
.insert(source)
|
||||||
|
{
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(io::Error::from(io::ErrorKind::AlreadyExists))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tell if a source is currently inside the set.
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn has_source(&self, source: SourceId) -> io::Result<()> {
|
||||||
|
if self
|
||||||
|
.sources
|
||||||
|
.read()
|
||||||
|
.unwrap_or_else(|e| e.into_inner())
|
||||||
|
.contains(&source)
|
||||||
|
{
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(io::Error::from(io::ErrorKind::NotFound))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove a source from the sources set.
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn remove_source(&self, source: SourceId) -> io::Result<()> {
|
||||||
|
if self
|
||||||
|
.sources
|
||||||
|
.write()
|
||||||
|
.unwrap_or_else(|e| e.into_inner())
|
||||||
|
.remove(&source)
|
||||||
|
{
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(io::Error::from(io::ErrorKind::NotFound))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Deletes a file descriptor.
|
/// Deletes a file descriptor.
|
||||||
pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> {
|
pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> {
|
||||||
// Simply delete interest in the file descriptor.
|
// Simply delete interest in the file descriptor.
|
||||||
self.modify(fd, Event::none(0), PollMode::Oneshot)
|
self.modify(fd, Event::none(0), PollMode::Oneshot)?;
|
||||||
|
|
||||||
|
self.remove_source(SourceId::Fd(fd.as_raw_fd()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Waits for I/O events with an optional timeout.
|
/// Waits for I/O events with an optional timeout.
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
//! Functionality that is only available for `kqueue`-based platforms.
|
//! Functionality that is only available for `kqueue`-based platforms.
|
||||||
|
|
||||||
use crate::sys::mode_to_flags;
|
use crate::sys::{mode_to_flags, SourceId};
|
||||||
use crate::{PollMode, Poller};
|
use crate::{PollMode, Poller};
|
||||||
|
|
||||||
use std::io;
|
use std::io;
|
||||||
|
@ -98,10 +98,13 @@ impl<F: Filter> PollerKqueueExt<F> for Poller {
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn add_filter(&self, filter: F, key: usize, mode: PollMode) -> io::Result<()> {
|
fn add_filter(&self, filter: F, key: usize, mode: PollMode) -> io::Result<()> {
|
||||||
// No difference between adding and modifying in kqueue.
|
// No difference between adding and modifying in kqueue.
|
||||||
|
self.poller.add_source(filter.source_id())?;
|
||||||
self.modify_filter(filter, key, mode)
|
self.modify_filter(filter, key, mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn modify_filter(&self, filter: F, key: usize, mode: PollMode) -> io::Result<()> {
|
fn modify_filter(&self, filter: F, key: usize, mode: PollMode) -> io::Result<()> {
|
||||||
|
self.poller.has_source(filter.source_id())?;
|
||||||
|
|
||||||
// Convert the filter into a kevent.
|
// Convert the filter into a kevent.
|
||||||
let event = filter.filter(kqueue::EventFlags::ADD | mode_to_flags(mode), key);
|
let event = filter.filter(kqueue::EventFlags::ADD | mode_to_flags(mode), key);
|
||||||
|
|
||||||
|
@ -114,7 +117,9 @@ impl<F: Filter> PollerKqueueExt<F> for Poller {
|
||||||
let event = filter.filter(kqueue::EventFlags::DELETE, 0);
|
let event = filter.filter(kqueue::EventFlags::DELETE, 0);
|
||||||
|
|
||||||
// Delete the filter.
|
// Delete the filter.
|
||||||
self.poller.submit_changes([event])
|
self.poller.submit_changes([event])?;
|
||||||
|
|
||||||
|
self.poller.remove_source(filter.source_id())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,6 +131,11 @@ unsafe impl<T: FilterSealed + ?Sized> FilterSealed for &T {
|
||||||
fn filter(&self, flags: kqueue::EventFlags, key: usize) -> kqueue::Event {
|
fn filter(&self, flags: kqueue::EventFlags, key: usize) -> kqueue::Event {
|
||||||
(**self).filter(flags, key)
|
(**self).filter(flags, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn source_id(&self) -> SourceId {
|
||||||
|
(**self).source_id()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Filter + ?Sized> Filter for &T {}
|
impl<T: Filter + ?Sized> Filter for &T {}
|
||||||
|
@ -149,6 +159,11 @@ unsafe impl FilterSealed for Signal {
|
||||||
key as _,
|
key as _,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn source_id(&self) -> SourceId {
|
||||||
|
SourceId::Signal(self.0)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Filter for Signal {}
|
impl Filter for Signal {}
|
||||||
|
@ -207,6 +222,11 @@ unsafe impl FilterSealed for Process<'_> {
|
||||||
key as _,
|
key as _,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn source_id(&self) -> SourceId {
|
||||||
|
SourceId::Pid(rustix::process::Pid::from_child(self.child))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Filter for Process<'_> {}
|
impl Filter for Process<'_> {}
|
||||||
|
@ -234,11 +254,17 @@ unsafe impl FilterSealed for Timer {
|
||||||
key as _,
|
key as _,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn source_id(&self) -> SourceId {
|
||||||
|
SourceId::Timer(self.id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Filter for Timer {}
|
impl Filter for Timer {}
|
||||||
|
|
||||||
mod __private {
|
mod __private {
|
||||||
|
use crate::sys::SourceId;
|
||||||
use rustix::event::kqueue;
|
use rustix::event::kqueue;
|
||||||
|
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
|
@ -247,5 +273,8 @@ mod __private {
|
||||||
///
|
///
|
||||||
/// This filter's flags must have `EV_RECEIPT`.
|
/// This filter's flags must have `EV_RECEIPT`.
|
||||||
fn filter(&self, flags: kqueue::EventFlags, key: usize) -> kqueue::Event;
|
fn filter(&self, flags: kqueue::EventFlags, key: usize) -> kqueue::Event;
|
||||||
|
|
||||||
|
/// Get the source ID for this source.
|
||||||
|
fn source_id(&self) -> SourceId;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
43
tests/io.rs
43
tests/io.rs
|
@ -1,6 +1,7 @@
|
||||||
use polling::{Event, Events, Poller};
|
use polling::{Event, Events, Poller};
|
||||||
use std::io::{self, Write};
|
use std::io::{self, Write};
|
||||||
use std::net::{TcpListener, TcpStream};
|
use std::net::{TcpListener, TcpStream};
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -38,6 +39,48 @@ fn basic_io() {
|
||||||
poller.delete(&read).unwrap();
|
poller.delete(&read).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn insert_twice() {
|
||||||
|
#[cfg(unix)]
|
||||||
|
use std::os::unix::io::AsRawFd;
|
||||||
|
#[cfg(windows)]
|
||||||
|
use std::os::windows::io::AsRawSocket;
|
||||||
|
|
||||||
|
let (read, mut write) = tcp_pair().unwrap();
|
||||||
|
let read = Arc::new(read);
|
||||||
|
|
||||||
|
let poller = Poller::new().unwrap();
|
||||||
|
unsafe {
|
||||||
|
#[cfg(unix)]
|
||||||
|
let read = read.as_raw_fd();
|
||||||
|
#[cfg(windows)]
|
||||||
|
let read = read.as_raw_socket();
|
||||||
|
|
||||||
|
poller.add(read, Event::readable(1)).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
poller.add(read, Event::readable(1)).unwrap_err().kind(),
|
||||||
|
io::ErrorKind::AlreadyExists
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
write.write_all(&[1]).unwrap();
|
||||||
|
let mut events = Events::new();
|
||||||
|
assert_eq!(
|
||||||
|
poller
|
||||||
|
.wait(&mut events, Some(Duration::from_secs(1)))
|
||||||
|
.unwrap(),
|
||||||
|
1
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(events.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
events.iter().next().unwrap().with_no_extra(),
|
||||||
|
Event::readable(1)
|
||||||
|
);
|
||||||
|
|
||||||
|
poller.delete(&read).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> {
|
fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> {
|
||||||
let listener = TcpListener::bind("127.0.0.1:0")?;
|
let listener = TcpListener::bind("127.0.0.1:0")?;
|
||||||
let a = TcpStream::connect(listener.local_addr()?)?;
|
let a = TcpStream::connect(listener.local_addr()?)?;
|
||||||
|
|
Loading…
Reference in New Issue