breaking: Rework the API for I/O safety

* Rework the API for I/O safety

* Bump to rustix v0.38
This commit is contained in:
John Nunley 2023-08-03 20:15:59 -07:00 committed by GitHub
parent c86c3894c1
commit 6eb7679aa3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 232 additions and 158 deletions

View File

@ -28,11 +28,6 @@ freebsd_task:
- sudo sysctl net.inet.tcp.blackhole=0 - sudo sysctl net.inet.tcp.blackhole=0
- . $HOME/.cargo/env - . $HOME/.cargo/env
- cargo test --target $TARGET - cargo test --target $TARGET
# Test async-io
- git clone https://github.com/smol-rs/async-io.git
- echo '[patch.crates-io]' >> async-io/Cargo.toml
- echo 'polling = { path = ".." }' >> async-io/Cargo.toml
- cargo test --target $TARGET --manifest-path=async-io/Cargo.toml
netbsd_task: netbsd_task:
name: test ($TARGET) name: test ($TARGET)
@ -49,11 +44,6 @@ netbsd_task:
test_script: test_script:
- . $HOME/.cargo/env - . $HOME/.cargo/env
- cargo test - cargo test
# Test async-io
- git clone https://github.com/smol-rs/async-io.git
- echo '[patch.crates-io]' >> async-io/Cargo.toml
- echo 'polling = { path = ".." }' >> async-io/Cargo.toml
- cargo test --manifest-path=async-io/Cargo.toml
openbsd_task: openbsd_task:
name: test ($TARGET) name: test ($TARGET)
@ -69,8 +59,3 @@ openbsd_task:
- pkg_add git rust - pkg_add git rust
test_script: test_script:
- cargo test - cargo test
# Test async-io
- git clone https://github.com/smol-rs/async-io.git
- echo '[patch.crates-io]' >> async-io/Cargo.toml
- echo 'polling = { path = ".." }' >> async-io/Cargo.toml
- cargo test --manifest-path=async-io/Cargo.toml

View File

@ -56,14 +56,6 @@ jobs:
RUSTFLAGS: ${{ env.RUSTFLAGS }} --cfg polling_test_poll_backend RUSTFLAGS: ${{ env.RUSTFLAGS }} --cfg polling_test_poll_backend
if: startsWith(matrix.os, 'ubuntu') if: startsWith(matrix.os, 'ubuntu')
- run: cargo hack build --feature-powerset --no-dev-deps - run: cargo hack build --feature-powerset --no-dev-deps
- name: Clone async-io
run: git clone https://github.com/smol-rs/async-io.git
- name: Add patch section
run: echo '[patch.crates-io]' >> async-io/Cargo.toml
- name: Patch polling
run: echo 'polling = { path = ".." }' >> async-io/Cargo.toml
- name: Test async-io
run: cargo test --manifest-path=async-io/Cargo.toml
cross: cross:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}

View File

@ -27,7 +27,7 @@ tracing = { version = "0.1.37", default-features = false }
[target.'cfg(any(unix, target_os = "fuchsia", target_os = "vxworks"))'.dependencies] [target.'cfg(any(unix, target_os = "fuchsia", target_os = "vxworks"))'.dependencies]
libc = "0.2.77" libc = "0.2.77"
rustix = { version = "0.37.11", features = ["process", "time", "fs", "std"], default-features = false } rustix = { version = "0.38", features = ["event", "fs", "pipe", "process", "std", "time"], default-features = false }
[target.'cfg(windows)'.dependencies] [target.'cfg(windows)'.dependencies]
concurrent-queue = "2.2.0" concurrent-queue = "2.2.0"

View File

@ -10,8 +10,10 @@ fn main() -> io::Result<()> {
l2.set_nonblocking(true)?; l2.set_nonblocking(true)?;
let poller = Poller::new()?; let poller = Poller::new()?;
poller.add(&l1, Event::readable(1))?; unsafe {
poller.add(&l2, Event::readable(2))?; poller.add(&l1, Event::readable(1))?;
poller.add(&l2, Event::readable(2))?;
}
println!("You can connect to the server using `nc`:"); println!("You can connect to the server using `nc`:");
println!(" $ nc 127.0.0.1 8001"); println!(" $ nc 127.0.0.1 8001");

View File

@ -5,8 +5,9 @@ use std::io;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd}; use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd};
use std::time::Duration; use std::time::Duration;
use rustix::event::{epoll, eventfd, EventfdFlags};
use rustix::fd::OwnedFd; use rustix::fd::OwnedFd;
use rustix::io::{epoll, eventfd, read, write, EventfdFlags}; use rustix::io::{read, write};
use rustix::time::{ use rustix::time::{
timerfd_create, timerfd_settime, Itimerspec, TimerfdClockId, TimerfdFlags, TimerfdTimerFlags, timerfd_create, timerfd_settime, Itimerspec, TimerfdClockId, TimerfdFlags, TimerfdTimerFlags,
Timespec, Timespec,
@ -31,7 +32,7 @@ impl Poller {
// Create an epoll instance. // Create an epoll instance.
// //
// Use `epoll_create1` with `EPOLL_CLOEXEC`. // Use `epoll_create1` with `EPOLL_CLOEXEC`.
let epoll_fd = epoll::epoll_create(epoll::CreateFlags::CLOEXEC)?; let epoll_fd = epoll::create(epoll::CreateFlags::CLOEXEC)?;
// Set up eventfd and timerfd. // Set up eventfd and timerfd.
let event_fd = eventfd(0, EventfdFlags::CLOEXEC | EventfdFlags::NONBLOCK)?; let event_fd = eventfd(0, EventfdFlags::CLOEXEC | EventfdFlags::NONBLOCK)?;
@ -47,24 +48,26 @@ impl Poller {
timer_fd, timer_fd,
}; };
if let Some(ref timer_fd) = poller.timer_fd { unsafe {
if let Some(ref timer_fd) = poller.timer_fd {
poller.add(
timer_fd.as_raw_fd(),
Event::none(crate::NOTIFY_KEY),
PollMode::Oneshot,
)?;
}
poller.add( poller.add(
timer_fd.as_raw_fd(), poller.event_fd.as_raw_fd(),
Event::none(crate::NOTIFY_KEY), Event {
key: crate::NOTIFY_KEY,
readable: true,
writable: false,
},
PollMode::Oneshot, PollMode::Oneshot,
)?; )?;
} }
poller.add(
poller.event_fd.as_raw_fd(),
Event {
key: crate::NOTIFY_KEY,
readable: true,
writable: false,
},
PollMode::Oneshot,
)?;
tracing::trace!( tracing::trace!(
epoll_fd = ?poller.epoll_fd.as_raw_fd(), epoll_fd = ?poller.epoll_fd.as_raw_fd(),
event_fd = ?poller.event_fd.as_raw_fd(), event_fd = ?poller.event_fd.as_raw_fd(),
@ -85,7 +88,12 @@ impl Poller {
} }
/// Adds a new file descriptor. /// Adds a new file descriptor.
pub fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { ///
/// # Safety
///
/// The `fd` must be a valid file descriptor. The usual condition of remaining registered in
/// the `Poller` doesn't apply to `epoll`.
pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
let span = tracing::trace_span!( let span = tracing::trace_span!(
"add", "add",
epoll_fd = ?self.epoll_fd.as_raw_fd(), epoll_fd = ?self.epoll_fd.as_raw_fd(),
@ -94,10 +102,10 @@ impl Poller {
); );
let _enter = span.enter(); let _enter = span.enter();
epoll::epoll_add( epoll::add(
&self.epoll_fd, &self.epoll_fd,
unsafe { rustix::fd::BorrowedFd::borrow_raw(fd) }, unsafe { rustix::fd::BorrowedFd::borrow_raw(fd) },
ev.key as u64, epoll::EventData::new_u64(ev.key as u64),
epoll_flags(&ev, mode), epoll_flags(&ev, mode),
)?; )?;
@ -105,7 +113,7 @@ impl Poller {
} }
/// Modifies an existing file descriptor. /// Modifies an existing file descriptor.
pub fn modify(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { pub fn modify(&self, fd: BorrowedFd<'_>, ev: Event, mode: PollMode) -> io::Result<()> {
let span = tracing::trace_span!( let span = tracing::trace_span!(
"modify", "modify",
epoll_fd = ?self.epoll_fd.as_raw_fd(), epoll_fd = ?self.epoll_fd.as_raw_fd(),
@ -114,10 +122,10 @@ impl Poller {
); );
let _enter = span.enter(); let _enter = span.enter();
epoll::epoll_mod( epoll::modify(
&self.epoll_fd, &self.epoll_fd,
unsafe { rustix::fd::BorrowedFd::borrow_raw(fd) }, fd,
ev.key as u64, epoll::EventData::new_u64(ev.key as u64),
epoll_flags(&ev, mode), epoll_flags(&ev, mode),
)?; )?;
@ -125,7 +133,7 @@ impl Poller {
} }
/// Deletes a file descriptor. /// Deletes a file descriptor.
pub fn delete(&self, fd: RawFd) -> io::Result<()> { pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> {
let span = tracing::trace_span!( let span = tracing::trace_span!(
"delete", "delete",
epoll_fd = ?self.epoll_fd.as_raw_fd(), epoll_fd = ?self.epoll_fd.as_raw_fd(),
@ -133,9 +141,7 @@ impl Poller {
); );
let _enter = span.enter(); let _enter = span.enter();
epoll::epoll_del(&self.epoll_fd, unsafe { epoll::delete(&self.epoll_fd, fd)?;
rustix::fd::BorrowedFd::borrow_raw(fd)
})?;
Ok(()) Ok(())
} }
@ -170,7 +176,7 @@ impl Poller {
// Set interest in timerfd. // Set interest in timerfd.
self.modify( self.modify(
timer_fd.as_raw_fd(), timer_fd.as_fd(),
Event { Event {
key: crate::NOTIFY_KEY, key: crate::NOTIFY_KEY,
readable: true, readable: true,
@ -195,7 +201,7 @@ impl Poller {
}; };
// Wait for I/O events. // Wait for I/O events.
epoll::epoll_wait(&self.epoll_fd, &mut events.list, timeout_ms)?; epoll::wait(&self.epoll_fd, &mut events.list, timeout_ms)?;
tracing::trace!( tracing::trace!(
epoll_fd = ?self.epoll_fd.as_raw_fd(), epoll_fd = ?self.epoll_fd.as_raw_fd(),
res = ?events.list.len(), res = ?events.list.len(),
@ -206,7 +212,7 @@ impl Poller {
let mut buf = [0u8; 8]; let mut buf = [0u8; 8];
let _ = read(&self.event_fd, &mut buf); let _ = read(&self.event_fd, &mut buf);
self.modify( self.modify(
self.event_fd.as_raw_fd(), self.event_fd.as_fd(),
Event { Event {
key: crate::NOTIFY_KEY, key: crate::NOTIFY_KEY,
readable: true, readable: true,
@ -255,9 +261,9 @@ impl Drop for Poller {
let _enter = span.enter(); let _enter = span.enter();
if let Some(timer_fd) = self.timer_fd.take() { if let Some(timer_fd) = self.timer_fd.take() {
let _ = self.delete(timer_fd.as_raw_fd()); let _ = self.delete(timer_fd.as_fd());
} }
let _ = self.delete(self.event_fd.as_raw_fd()); let _ = self.delete(self.event_fd.as_fd());
} }
} }
@ -310,10 +316,13 @@ impl Events {
/// Iterates over I/O events. /// Iterates over I/O events.
pub fn iter(&self) -> impl Iterator<Item = Event> + '_ { pub fn iter(&self) -> impl Iterator<Item = Event> + '_ {
self.list.iter().map(|(flags, data)| Event { self.list.iter().map(|ev| {
key: data as usize, let flags = ev.flags;
readable: flags.intersects(read_flags()), Event {
writable: flags.intersects(write_flags()), key: ev.data.u64() as usize,
readable: flags.intersects(read_flags()),
writable: flags.intersects(write_flags()),
}
}) })
} }
} }

View File

@ -42,7 +42,9 @@ use std::collections::hash_map::{Entry, HashMap};
use std::fmt; use std::fmt;
use std::io; use std::io;
use std::marker::PhantomPinned; use std::marker::PhantomPinned;
use std::os::windows::io::{AsHandle, AsRawHandle, BorrowedHandle, RawHandle, RawSocket}; use std::os::windows::io::{
AsHandle, AsRawHandle, AsRawSocket, BorrowedHandle, BorrowedSocket, RawHandle, RawSocket,
};
use std::pin::Pin; use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak}; use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak};
@ -134,7 +136,16 @@ impl Poller {
} }
/// Add a new source to the poller. /// Add a new source to the poller.
pub(super) fn add(&self, socket: RawSocket, interest: Event, mode: PollMode) -> io::Result<()> { ///
/// # Safety
///
/// The socket must be a valid socket and must last until it is deleted.
pub(super) unsafe fn add(
&self,
socket: RawSocket,
interest: Event,
mode: PollMode,
) -> io::Result<()> {
let span = tracing::trace_span!( let span = tracing::trace_span!(
"add", "add",
handle = ?self.port, handle = ?self.port,
@ -192,7 +203,7 @@ impl Poller {
/// Update a source in the poller. /// Update a source in the poller.
pub(super) fn modify( pub(super) fn modify(
&self, &self,
socket: RawSocket, socket: BorrowedSocket<'_>,
interest: Event, interest: Event,
mode: PollMode, mode: PollMode,
) -> io::Result<()> { ) -> io::Result<()> {
@ -217,7 +228,7 @@ impl Poller {
let sources = lock!(self.sources.read()); let sources = lock!(self.sources.read());
sources sources
.get(&socket) .get(&socket.as_raw_socket())
.cloned() .cloned()
.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))? .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?
}; };
@ -231,7 +242,7 @@ impl Poller {
} }
/// Delete a source from the poller. /// Delete a source from the poller.
pub(super) fn delete(&self, socket: RawSocket) -> io::Result<()> { pub(super) fn delete(&self, socket: BorrowedSocket<'_>) -> io::Result<()> {
let span = tracing::trace_span!( let span = tracing::trace_span!(
"remove", "remove",
handle = ?self.port, handle = ?self.port,
@ -243,7 +254,7 @@ impl Poller {
let source = { let source = {
let mut sources = lock!(self.sources.write()); let mut sources = lock!(self.sources.write());
match sources.remove(&socket) { match sources.remove(&socket.as_raw_socket()) {
Some(s) => s, Some(s) => s,
None => { None => {
// If the source has already been removed, then we can just return. // If the source has already been removed, then we can just return.

View File

@ -1,11 +1,11 @@
//! Bindings to kqueue (macOS, iOS, tvOS, watchOS, FreeBSD, NetBSD, OpenBSD, DragonFly BSD). //! Bindings to kqueue (macOS, iOS, tvOS, watchOS, FreeBSD, NetBSD, OpenBSD, DragonFly BSD).
use std::io; use std::io;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd}; use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
use std::time::Duration; use std::time::Duration;
use rustix::fd::OwnedFd; use rustix::event::kqueue;
use rustix::io::{fcntl_setfd, kqueue, Errno, FdFlags}; use rustix::io::{fcntl_setfd, Errno, FdFlags};
use crate::{Event, PollMode}; use crate::{Event, PollMode};
@ -55,13 +55,17 @@ impl Poller {
} }
/// Adds a new file descriptor. /// Adds a new file descriptor.
pub fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { ///
/// # Safety
///
/// The file descriptor must be valid and it must last until it is deleted.
pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
// File descriptors don't need to be added explicitly, so just modify the interest. // File descriptors don't need to be added explicitly, so just modify the interest.
self.modify(fd, ev, mode) self.modify(BorrowedFd::borrow_raw(fd), ev, mode)
} }
/// Modifies an existing file descriptor. /// Modifies an existing file descriptor.
pub fn modify(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { pub fn modify(&self, fd: BorrowedFd<'_>, ev: Event, mode: PollMode) -> io::Result<()> {
let span = if !self.notify.has_fd(fd) { let span = if !self.notify.has_fd(fd) {
let span = tracing::trace_span!( let span = tracing::trace_span!(
"add", "add",
@ -91,12 +95,12 @@ impl Poller {
// A list of changes for kqueue. // A list of changes for kqueue.
let changelist = [ let changelist = [
kqueue::Event::new( kqueue::Event::new(
kqueue::EventFilter::Read(fd), kqueue::EventFilter::Read(fd.as_raw_fd()),
read_flags | kqueue::EventFlags::RECEIPT, read_flags | kqueue::EventFlags::RECEIPT,
ev.key as _, ev.key as _,
), ),
kqueue::Event::new( kqueue::Event::new(
kqueue::EventFilter::Write(fd), kqueue::EventFilter::Write(fd.as_raw_fd()),
write_flags | kqueue::EventFlags::RECEIPT, write_flags | kqueue::EventFlags::RECEIPT,
ev.key as _, ev.key as _,
), ),
@ -141,7 +145,7 @@ impl Poller {
} }
/// Deletes a file descriptor. /// Deletes a file descriptor.
pub fn delete(&self, fd: RawFd) -> io::Result<()> { pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> {
// Simply delete interest in the file descriptor. // Simply delete interest in the file descriptor.
self.modify(fd, Event::none(0), PollMode::Oneshot) self.modify(fd, Event::none(0), PollMode::Oneshot)
} }
@ -268,9 +272,9 @@ pub(crate) fn mode_to_flags(mode: PollMode) -> kqueue::EventFlags {
))] ))]
mod notify { mod notify {
use super::Poller; use super::Poller;
use rustix::io::kqueue; use rustix::event::kqueue;
use std::io; use std::io;
use std::os::unix::io::RawFd; use std::os::unix::io::BorrowedFd;
/// A notification pipe. /// A notification pipe.
/// ///
@ -335,7 +339,7 @@ mod notify {
} }
/// Whether this raw file descriptor is associated with this pipe. /// Whether this raw file descriptor is associated with this pipe.
pub(super) fn has_fd(&self, _fd: RawFd) -> bool { pub(super) fn has_fd(&self, _fd: BorrowedFd<'_>) -> bool {
false false
} }
} }
@ -354,7 +358,7 @@ mod notify {
use crate::{Event, PollMode, NOTIFY_KEY}; use crate::{Event, PollMode, NOTIFY_KEY};
use std::io::{self, prelude::*}; use std::io::{self, prelude::*};
use std::os::unix::{ use std::os::unix::{
io::{AsRawFd, RawFd}, io::{AsFd, AsRawFd, BorrowedFd},
net::UnixStream, net::UnixStream,
}; };
@ -386,11 +390,13 @@ mod notify {
/// Registers this notification pipe in the `Poller`. /// Registers this notification pipe in the `Poller`.
pub(super) fn register(&self, poller: &Poller) -> io::Result<()> { pub(super) fn register(&self, poller: &Poller) -> io::Result<()> {
// Register the read end of this pipe. // Register the read end of this pipe.
poller.add( unsafe {
self.read_stream.as_raw_fd(), poller.add(
Event::readable(NOTIFY_KEY), self.read_stream.as_raw_fd(),
PollMode::Oneshot, Event::readable(NOTIFY_KEY),
) PollMode::Oneshot,
)
}
} }
/// Reregister this notification pipe in the `Poller`. /// Reregister this notification pipe in the `Poller`.
@ -400,7 +406,7 @@ mod notify {
// Reregister the read end of this pipe. // Reregister the read end of this pipe.
poller.modify( poller.modify(
self.read_stream.as_raw_fd(), self.read_stream.as_fd(),
Event::readable(NOTIFY_KEY), Event::readable(NOTIFY_KEY),
PollMode::Oneshot, PollMode::Oneshot,
) )
@ -418,12 +424,12 @@ mod notify {
/// Deregisters this notification pipe from the `Poller`. /// Deregisters this notification pipe from the `Poller`.
pub(super) fn deregister(&self, poller: &Poller) -> io::Result<()> { pub(super) fn deregister(&self, poller: &Poller) -> io::Result<()> {
// Deregister the read end of the pipe. // Deregister the read end of the pipe.
poller.delete(self.read_stream.as_raw_fd()) poller.delete(self.read_stream.as_fd())
} }
/// Whether this raw file descriptor is associated with this pipe. /// Whether this raw file descriptor is associated with this pipe.
pub(super) fn has_fd(&self, fd: RawFd) -> bool { pub(super) fn has_fd(&self, fd: BorrowedFd<'_>) -> bool {
self.read_stream.as_raw_fd() == fd self.read_stream.as_raw_fd() == fd.as_raw_fd()
} }
} }
} }

View File

@ -28,7 +28,9 @@
//! //!
//! // Create a poller and register interest in readability on the socket. //! // Create a poller and register interest in readability on the socket.
//! let poller = Poller::new()?; //! let poller = Poller::new()?;
//! poller.add(&socket, Event::readable(key))?; //! unsafe {
//! poller.add(&socket, Event::readable(key))?;
//! }
//! //!
//! // The event loop. //! // The event loop.
//! let mut events = Vec::new(); //! let mut events = Vec::new();
@ -46,13 +48,15 @@
//! } //! }
//! } //! }
//! } //! }
//!
//! poller.delete(&socket)?;
//! # std::io::Result::Ok(()) //! # std::io::Result::Ok(())
//! ``` //! ```
#![cfg(feature = "std")] #![cfg(feature = "std")]
#![cfg_attr(not(feature = "std"), no_std)] #![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)] #![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
#![allow(clippy::useless_conversion, clippy::unnecessary_cast)] #![allow(clippy::useless_conversion, clippy::unnecessary_cast, unused_unsafe)]
#![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, feature(doc_cfg))]
#![doc( #![doc(
html_favicon_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png" html_favicon_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
@ -273,8 +277,11 @@ impl Poller {
/// [`modify()`][`Poller::modify()`] again after an event is delivered if we're interested in /// [`modify()`][`Poller::modify()`] again after an event is delivered if we're interested in
/// the next event of the same kind. /// the next event of the same kind.
/// ///
/// Don't forget to [`delete()`][`Poller::delete()`] the file descriptor or socket when it is /// # Safety
/// no longer used! ///
/// The source must be [`delete()`]d from this `Poller` before it is dropped.
///
/// [`delete()`]: Poller::delete
/// ///
/// # Errors /// # Errors
/// ///
@ -295,10 +302,13 @@ impl Poller {
/// let key = 7; /// let key = 7;
/// ///
/// let poller = Poller::new()?; /// let poller = Poller::new()?;
/// poller.add(&source, Event::all(key))?; /// unsafe {
/// poller.add(&source, Event::all(key))?;
/// }
/// poller.delete(&source)?;
/// # std::io::Result::Ok(()) /// # std::io::Result::Ok(())
/// ``` /// ```
pub fn add(&self, source: impl Source, interest: Event) -> io::Result<()> { pub unsafe fn add(&self, source: impl AsRawSource, interest: Event) -> io::Result<()> {
self.add_with_mode(source, interest, PollMode::Oneshot) self.add_with_mode(source, interest, PollMode::Oneshot)
} }
@ -307,13 +317,19 @@ impl Poller {
/// This is identical to the `add()` function, but allows specifying the /// This is identical to the `add()` function, but allows specifying the
/// polling mode to use for this socket. /// polling mode to use for this socket.
/// ///
/// # Safety
///
/// The source must be [`delete()`]d from this `Poller` before it is dropped.
///
/// [`delete()`]: Poller::delete
///
/// # Errors /// # Errors
/// ///
/// If the operating system does not support the specified mode, this function /// If the operating system does not support the specified mode, this function
/// will return an error. /// will return an error.
pub fn add_with_mode( pub unsafe fn add_with_mode(
&self, &self,
source: impl Source, source: impl AsRawSource,
interest: Event, interest: Event,
mode: PollMode, mode: PollMode,
) -> io::Result<()> { ) -> io::Result<()> {
@ -354,7 +370,7 @@ impl Poller {
/// # let source = std::net::TcpListener::bind("127.0.0.1:0")?; /// # let source = std::net::TcpListener::bind("127.0.0.1:0")?;
/// # let key = 7; /// # let key = 7;
/// # let poller = Poller::new()?; /// # let poller = Poller::new()?;
/// # poller.add(&source, Event::none(key))?; /// # unsafe { poller.add(&source, Event::none(key))?; }
/// poller.modify(&source, Event::all(key))?; /// poller.modify(&source, Event::all(key))?;
/// # std::io::Result::Ok(()) /// # std::io::Result::Ok(())
/// ``` /// ```
@ -366,8 +382,9 @@ impl Poller {
/// # let source = std::net::TcpListener::bind("127.0.0.1:0")?; /// # let source = std::net::TcpListener::bind("127.0.0.1:0")?;
/// # let key = 7; /// # let key = 7;
/// # let poller = Poller::new()?; /// # let poller = Poller::new()?;
/// # poller.add(&source, Event::none(key))?; /// # unsafe { poller.add(&source, Event::none(key))?; }
/// poller.modify(&source, Event::readable(key))?; /// poller.modify(&source, Event::readable(key))?;
/// # poller.delete(&source)?;
/// # std::io::Result::Ok(()) /// # std::io::Result::Ok(())
/// ``` /// ```
/// ///
@ -378,8 +395,9 @@ impl Poller {
/// # let poller = Poller::new()?; /// # let poller = Poller::new()?;
/// # let key = 7; /// # let key = 7;
/// # let source = std::net::TcpListener::bind("127.0.0.1:0")?; /// # let source = std::net::TcpListener::bind("127.0.0.1:0")?;
/// # poller.add(&source, Event::none(key))?; /// # unsafe { poller.add(&source, Event::none(key))? };
/// poller.modify(&source, Event::writable(key))?; /// poller.modify(&source, Event::writable(key))?;
/// # poller.delete(&source)?;
/// # std::io::Result::Ok(()) /// # std::io::Result::Ok(())
/// ``` /// ```
/// ///
@ -390,11 +408,12 @@ impl Poller {
/// # let source = std::net::TcpListener::bind("127.0.0.1:0")?; /// # let source = std::net::TcpListener::bind("127.0.0.1:0")?;
/// # let key = 7; /// # let key = 7;
/// # let poller = Poller::new()?; /// # let poller = Poller::new()?;
/// # poller.add(&source, Event::none(key))?; /// # unsafe { poller.add(&source, Event::none(key))?; }
/// poller.modify(&source, Event::none(key))?; /// poller.modify(&source, Event::none(key))?;
/// # poller.delete(&source)?;
/// # std::io::Result::Ok(()) /// # std::io::Result::Ok(())
/// ``` /// ```
pub fn modify(&self, source: impl Source, interest: Event) -> io::Result<()> { pub fn modify(&self, source: impl AsSource, interest: Event) -> io::Result<()> {
self.modify_with_mode(source, interest, PollMode::Oneshot) self.modify_with_mode(source, interest, PollMode::Oneshot)
} }
@ -415,7 +434,7 @@ impl Poller {
/// an error. /// an error.
pub fn modify_with_mode( pub fn modify_with_mode(
&self, &self,
source: impl Source, source: impl AsSource,
interest: Event, interest: Event,
mode: PollMode, mode: PollMode,
) -> io::Result<()> { ) -> io::Result<()> {
@ -425,7 +444,7 @@ impl Poller {
"the key is not allowed to be `usize::MAX`", "the key is not allowed to be `usize::MAX`",
)); ));
} }
self.poller.modify(source.raw(), interest, mode) self.poller.modify(source.source(), interest, mode)
} }
/// Removes a file descriptor or socket from the poller. /// Removes a file descriptor or socket from the poller.
@ -444,12 +463,12 @@ impl Poller {
/// let key = 7; /// let key = 7;
/// ///
/// let poller = Poller::new()?; /// let poller = Poller::new()?;
/// poller.add(&socket, Event::all(key))?; /// unsafe { poller.add(&socket, Event::all(key))?; }
/// poller.delete(&socket)?; /// poller.delete(&socket)?;
/// # std::io::Result::Ok(()) /// # std::io::Result::Ok(())
/// ``` /// ```
pub fn delete(&self, source: impl Source) -> io::Result<()> { pub fn delete(&self, source: impl AsSource) -> io::Result<()> {
self.poller.delete(source.raw()) self.poller.delete(source.source())
} }
/// Waits for at least one I/O event and returns the number of new events. /// Waits for at least one I/O event and returns the number of new events.
@ -482,10 +501,13 @@ impl Poller {
/// let key = 7; /// let key = 7;
/// ///
/// let poller = Poller::new()?; /// let poller = Poller::new()?;
/// poller.add(&socket, Event::all(key))?; /// unsafe {
/// poller.add(&socket, Event::all(key))?;
/// }
/// ///
/// let mut events = Vec::new(); /// let mut events = Vec::new();
/// let n = poller.wait(&mut events, Some(Duration::from_secs(1)))?; /// let n = poller.wait(&mut events, Some(Duration::from_secs(1)))?;
/// poller.delete(&socket)?;
/// # std::io::Result::Ok(()) /// # std::io::Result::Ok(())
/// ``` /// ```
pub fn wait(&self, events: &mut Vec<Event>, timeout: Option<Duration>) -> io::Result<usize> { pub fn wait(&self, events: &mut Vec<Event>, timeout: Option<Duration>) -> io::Result<usize> {
@ -624,45 +646,65 @@ impl fmt::Debug for Poller {
cfg_if! { cfg_if! {
if #[cfg(unix)] { if #[cfg(unix)] {
use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::io::{AsRawFd, RawFd, AsFd, BorrowedFd};
/// A [`RawFd`] or a reference to a type implementing [`AsRawFd`]. /// A resource with a raw file descriptor.
pub trait Source { pub trait AsRawSource {
/// Returns the [`RawFd`] for this I/O object. /// Returns the raw file descriptor.
fn raw(&self) -> RawFd; fn raw(&self) -> RawFd;
} }
impl Source for RawFd { impl<T: AsRawFd> AsRawSource for &T {
fn raw(&self) -> RawFd {
*self
}
}
impl<T: AsRawFd> Source for &T {
fn raw(&self) -> RawFd { fn raw(&self) -> RawFd {
self.as_raw_fd() self.as_raw_fd()
} }
} }
} else if #[cfg(windows)] {
use std::os::windows::io::{AsRawSocket, RawSocket};
/// A [`RawSocket`] or a reference to a type implementing [`AsRawSocket`]. impl AsRawSource for RawFd {
pub trait Source { fn raw(&self) -> RawFd {
/// Returns the [`RawSocket`] for this I/O object. *self
}
}
/// A resource with a borrowed file descriptor.
pub trait AsSource: AsFd {
/// Returns the borrowed file descriptor.
fn source(&self) -> BorrowedFd<'_> {
self.as_fd()
}
}
impl<T: AsFd> AsSource for T {}
} else if #[cfg(windows)] {
use std::os::windows::io::{AsRawSocket, RawSocket, AsSocket, BorrowedSocket};
/// A resource with a raw socket.
pub trait AsRawSource {
/// Returns the raw socket.
fn raw(&self) -> RawSocket; fn raw(&self) -> RawSocket;
} }
impl Source for RawSocket { impl<T: AsRawSocket> AsRawSource for &T {
fn raw(&self) -> RawSocket {
self.as_raw_socket()
}
}
impl AsRawSource for RawSocket {
fn raw(&self) -> RawSocket { fn raw(&self) -> RawSocket {
*self *self
} }
} }
impl<T: AsRawSocket> Source for &T { /// A resource with a borrowed socket.
fn raw(&self) -> RawSocket { pub trait AsSource: AsSocket {
self.as_raw_socket() /// Returns the borrowed socket.
fn source(&self) -> BorrowedSocket<'_> {
self.as_socket()
} }
} }
impl<T: AsSocket> AsSource for T {}
} }
} }

View File

@ -7,7 +7,7 @@ use std::io;
use std::process::Child; use std::process::Child;
use std::time::Duration; use std::time::Duration;
use rustix::io::kqueue; use rustix::event::kqueue;
use super::__private::PollerSealed; use super::__private::PollerSealed;
use __private::FilterSealed; use __private::FilterSealed;
@ -238,7 +238,7 @@ unsafe impl FilterSealed for Timer {
impl Filter for Timer {} impl Filter for Timer {}
mod __private { mod __private {
use rustix::io::kqueue; use rustix::event::kqueue;
#[doc(hidden)] #[doc(hidden)]
pub unsafe trait FilterSealed { pub unsafe trait FilterSealed {

View File

@ -7,12 +7,11 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Condvar, Mutex}; use std::sync::{Condvar, Mutex};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use rustix::event::{poll, PollFd, PollFlags};
use rustix::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd}; use rustix::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd};
use rustix::fs::{fcntl_getfl, fcntl_setfl, OFlags}; use rustix::fs::{fcntl_getfl, fcntl_setfl, OFlags};
use rustix::io::{ use rustix::io::{fcntl_getfd, fcntl_setfd, read, write, FdFlags};
fcntl_getfd, fcntl_setfd, pipe, pipe_with, poll, read, write, FdFlags, PipeFlags, PollFd, use rustix::pipe::{pipe, pipe_with, PipeFlags};
PollFlags,
};
// std::os::unix doesn't exist on Fuchsia // std::os::unix doesn't exist on Fuchsia
type RawFd = std::os::raw::c_int; type RawFd = std::os::raw::c_int;
@ -158,7 +157,7 @@ impl Poller {
} }
/// Modifies an existing file descriptor. /// Modifies an existing file descriptor.
pub fn modify(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { pub fn modify(&self, fd: BorrowedFd<'_>, ev: Event, mode: PollMode) -> io::Result<()> {
let span = tracing::trace_span!( let span = tracing::trace_span!(
"modify", "modify",
notify_read = ?self.notify_read, notify_read = ?self.notify_read,
@ -168,11 +167,19 @@ impl Poller {
let _enter = span.enter(); let _enter = span.enter();
self.modify_fds(|fds| { self.modify_fds(|fds| {
let data = fds.fd_data.get_mut(&fd).ok_or(io::ErrorKind::NotFound)?; let data = fds
.fd_data
.get_mut(&fd.as_raw_fd())
.ok_or(io::ErrorKind::NotFound)?;
data.key = ev.key; data.key = ev.key;
let poll_fds_index = data.poll_fds_index; let poll_fds_index = data.poll_fds_index;
fds.poll_fds[poll_fds_index] =
PollFd::from_borrowed_fd(unsafe { BorrowedFd::borrow_raw(fd) }, poll_events(ev)); // SAFETY: This is essentially transmuting a `PollFd<'a>` to a `PollFd<'static>`, which
// only works if it's removed in time with `delete()`.
fds.poll_fds[poll_fds_index] = PollFd::from_borrowed_fd(
unsafe { BorrowedFd::borrow_raw(fd.as_raw_fd()) },
poll_events(ev),
);
data.remove = cvt_mode_as_remove(mode)?; data.remove = cvt_mode_as_remove(mode)?;
Ok(()) Ok(())
@ -180,7 +187,7 @@ impl Poller {
} }
/// Deletes a file descriptor. /// Deletes a file descriptor.
pub fn delete(&self, fd: RawFd) -> io::Result<()> { pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> {
let span = tracing::trace_span!( let span = tracing::trace_span!(
"delete", "delete",
notify_read = ?self.notify_read, notify_read = ?self.notify_read,
@ -189,7 +196,10 @@ impl Poller {
let _enter = span.enter(); let _enter = span.enter();
self.modify_fds(|fds| { self.modify_fds(|fds| {
let data = fds.fd_data.remove(&fd).ok_or(io::ErrorKind::NotFound)?; let data = fds
.fd_data
.remove(&fd.as_raw_fd())
.ok_or(io::ErrorKind::NotFound)?;
fds.poll_fds.swap_remove(data.poll_fds_index); fds.poll_fds.swap_remove(data.poll_fds_index);
if let Some(swapped_pollfd) = fds.poll_fds.get(data.poll_fds_index) { if let Some(swapped_pollfd) = fds.poll_fds.get(data.poll_fds_index) {
fds.fd_data fds.fd_data

View File

@ -4,8 +4,9 @@ use std::io;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd}; use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd};
use std::time::Duration; use std::time::Duration;
use rustix::event::{port, PollFlags};
use rustix::fd::OwnedFd; use rustix::fd::OwnedFd;
use rustix::io::{fcntl_getfd, fcntl_setfd, port, FdFlags, PollFlags}; use rustix::io::{fcntl_getfd, fcntl_setfd, FdFlags};
use crate::{Event, PollMode}; use crate::{Event, PollMode};
@ -42,13 +43,17 @@ impl Poller {
} }
/// Adds a file descriptor. /// Adds a file descriptor.
pub fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { ///
/// # Safety
///
/// The `fd` must be a valid file descriptor and it must last until it is deleted.
pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
// File descriptors don't need to be added explicitly, so just modify the interest. // File descriptors don't need to be added explicitly, so just modify the interest.
self.modify(fd, ev, mode) self.modify(BorrowedFd::borrow_raw(fd), ev, mode)
} }
/// Modifies an existing file descriptor. /// Modifies an existing file descriptor.
pub fn modify(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { pub fn modify(&self, fd: BorrowedFd<'_>, ev: Event, mode: PollMode) -> io::Result<()> {
let span = tracing::trace_span!( let span = tracing::trace_span!(
"modify", "modify",
port_fd = ?self.port_fd.as_raw_fd(), port_fd = ?self.port_fd.as_raw_fd(),
@ -79,7 +84,7 @@ impl Poller {
} }
/// Deletes a file descriptor. /// Deletes a file descriptor.
pub fn delete(&self, fd: RawFd) -> io::Result<()> { pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> {
let span = tracing::trace_span!( let span = tracing::trace_span!(
"delete", "delete",
port_fd = ?self.port_fd.as_raw_fd(), port_fd = ?self.port_fd.as_raw_fd(),

View File

@ -13,20 +13,25 @@ fn concurrent_add() -> io::Result<()> {
let mut events = Vec::new(); let mut events = Vec::new();
Parallel::new() let result = Parallel::new()
.add(|| { .add(|| {
poller.wait(&mut events, None)?; poller.wait(&mut events, None)?;
Ok(()) Ok(())
}) })
.add(|| { .add(|| {
thread::sleep(Duration::from_millis(100)); thread::sleep(Duration::from_millis(100));
poller.add(&reader, Event::readable(0))?; unsafe {
poller.add(&reader, Event::readable(0))?;
}
writer.write_all(&[1])?; writer.write_all(&[1])?;
Ok(()) Ok(())
}) })
.run() .run()
.into_iter() .into_iter()
.collect::<io::Result<()>>()?; .collect::<io::Result<()>>();
poller.delete(&reader)?;
result?;
assert_eq!(events, [Event::readable(0)]); assert_eq!(events, [Event::readable(0)]);
@ -37,7 +42,9 @@ fn concurrent_add() -> io::Result<()> {
fn concurrent_modify() -> io::Result<()> { fn concurrent_modify() -> io::Result<()> {
let (reader, mut writer) = tcp_pair()?; let (reader, mut writer) = tcp_pair()?;
let poller = Poller::new()?; let poller = Poller::new()?;
poller.add(&reader, Event::none(0))?; unsafe {
poller.add(&reader, Event::none(0))?;
}
let mut events = Vec::new(); let mut events = Vec::new();

View File

@ -7,7 +7,9 @@ use std::time::Duration;
fn basic_io() { fn basic_io() {
let poller = Poller::new().unwrap(); let poller = Poller::new().unwrap();
let (read, mut write) = tcp_pair().unwrap(); let (read, mut write) = tcp_pair().unwrap();
poller.add(&read, Event::readable(1)).unwrap(); unsafe {
poller.add(&read, Event::readable(1)).unwrap();
}
// Nothing should be available at first. // Nothing should be available at first.
let mut events = vec![]; let mut events = vec![];
@ -28,6 +30,8 @@ fn basic_io() {
1 1
); );
assert_eq!(&*events, &[Event::readable(1)]); assert_eq!(&*events, &[Event::readable(1)]);
poller.delete(&read).unwrap();
} }
fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> { fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> {

View File

@ -20,7 +20,9 @@ fn many_connections() {
let poller = polling::Poller::new().unwrap(); let poller = polling::Poller::new().unwrap();
for (i, reader, _) in connections.iter() { for (i, reader, _) in connections.iter() {
poller.add(reader, polling::Event::readable(*i)).unwrap(); unsafe {
poller.add(reader, polling::Event::readable(*i)).unwrap();
}
} }
let mut events = vec![]; let mut events = vec![];

View File

@ -16,8 +16,7 @@ fn level_triggered() {
// Create our poller and register our streams. // Create our poller and register our streams.
let poller = Poller::new().unwrap(); let poller = Poller::new().unwrap();
if poller if unsafe { poller.add_with_mode(&reader, Event::readable(reader_token), PollMode::Level) }
.add_with_mode(&reader, Event::readable(reader_token), PollMode::Level)
.is_err() .is_err()
{ {
// Only panic if we're on a platform that should support level mode. // Only panic if we're on a platform that should support level mode.
@ -92,8 +91,7 @@ fn edge_triggered() {
// Create our poller and register our streams. // Create our poller and register our streams.
let poller = Poller::new().unwrap(); let poller = Poller::new().unwrap();
if poller if unsafe { poller.add_with_mode(&reader, Event::readable(reader_token), PollMode::Edge) }
.add_with_mode(&reader, Event::readable(reader_token), PollMode::Edge)
.is_err() .is_err()
{ {
// Only panic if we're on a platform that should support level mode. // Only panic if we're on a platform that should support level mode.
@ -170,13 +168,14 @@ fn edge_oneshot_triggered() {
// Create our poller and register our streams. // Create our poller and register our streams.
let poller = Poller::new().unwrap(); let poller = Poller::new().unwrap();
if poller if unsafe {
.add_with_mode( poller.add_with_mode(
&reader, &reader,
Event::readable(reader_token), Event::readable(reader_token),
PollMode::EdgeOneshot, PollMode::EdgeOneshot,
) )
.is_err() }
.is_err()
{ {
// Only panic if we're on a platform that should support level mode. // Only panic if we're on a platform that should support level mode.
cfg_if::cfg_if! { cfg_if::cfg_if! {