breaking: Rework the API for I/O safety

* Rework the API for I/O safety

* Bump to rustix v0.38
This commit is contained in:
John Nunley 2023-08-03 20:15:59 -07:00 committed by GitHub
parent c86c3894c1
commit 6eb7679aa3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 232 additions and 158 deletions

View File

@ -28,11 +28,6 @@ freebsd_task:
- sudo sysctl net.inet.tcp.blackhole=0
- . $HOME/.cargo/env
- cargo test --target $TARGET
# Test async-io
- git clone https://github.com/smol-rs/async-io.git
- echo '[patch.crates-io]' >> async-io/Cargo.toml
- echo 'polling = { path = ".." }' >> async-io/Cargo.toml
- cargo test --target $TARGET --manifest-path=async-io/Cargo.toml
netbsd_task:
name: test ($TARGET)
@ -49,11 +44,6 @@ netbsd_task:
test_script:
- . $HOME/.cargo/env
- cargo test
# Test async-io
- git clone https://github.com/smol-rs/async-io.git
- echo '[patch.crates-io]' >> async-io/Cargo.toml
- echo 'polling = { path = ".." }' >> async-io/Cargo.toml
- cargo test --manifest-path=async-io/Cargo.toml
openbsd_task:
name: test ($TARGET)
@ -69,8 +59,3 @@ openbsd_task:
- pkg_add git rust
test_script:
- cargo test
# Test async-io
- git clone https://github.com/smol-rs/async-io.git
- echo '[patch.crates-io]' >> async-io/Cargo.toml
- echo 'polling = { path = ".." }' >> async-io/Cargo.toml
- cargo test --manifest-path=async-io/Cargo.toml

View File

@ -56,14 +56,6 @@ jobs:
RUSTFLAGS: ${{ env.RUSTFLAGS }} --cfg polling_test_poll_backend
if: startsWith(matrix.os, 'ubuntu')
- run: cargo hack build --feature-powerset --no-dev-deps
- name: Clone async-io
run: git clone https://github.com/smol-rs/async-io.git
- name: Add patch section
run: echo '[patch.crates-io]' >> async-io/Cargo.toml
- name: Patch polling
run: echo 'polling = { path = ".." }' >> async-io/Cargo.toml
- name: Test async-io
run: cargo test --manifest-path=async-io/Cargo.toml
cross:
runs-on: ${{ matrix.os }}

View File

@ -27,7 +27,7 @@ tracing = { version = "0.1.37", default-features = false }
[target.'cfg(any(unix, target_os = "fuchsia", target_os = "vxworks"))'.dependencies]
libc = "0.2.77"
rustix = { version = "0.37.11", features = ["process", "time", "fs", "std"], default-features = false }
rustix = { version = "0.38", features = ["event", "fs", "pipe", "process", "std", "time"], default-features = false }
[target.'cfg(windows)'.dependencies]
concurrent-queue = "2.2.0"

View File

@ -10,8 +10,10 @@ fn main() -> io::Result<()> {
l2.set_nonblocking(true)?;
let poller = Poller::new()?;
poller.add(&l1, Event::readable(1))?;
poller.add(&l2, Event::readable(2))?;
unsafe {
poller.add(&l1, Event::readable(1))?;
poller.add(&l2, Event::readable(2))?;
}
println!("You can connect to the server using `nc`:");
println!(" $ nc 127.0.0.1 8001");

View File

@ -5,8 +5,9 @@ use std::io;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd};
use std::time::Duration;
use rustix::event::{epoll, eventfd, EventfdFlags};
use rustix::fd::OwnedFd;
use rustix::io::{epoll, eventfd, read, write, EventfdFlags};
use rustix::io::{read, write};
use rustix::time::{
timerfd_create, timerfd_settime, Itimerspec, TimerfdClockId, TimerfdFlags, TimerfdTimerFlags,
Timespec,
@ -31,7 +32,7 @@ impl Poller {
// Create an epoll instance.
//
// Use `epoll_create1` with `EPOLL_CLOEXEC`.
let epoll_fd = epoll::epoll_create(epoll::CreateFlags::CLOEXEC)?;
let epoll_fd = epoll::create(epoll::CreateFlags::CLOEXEC)?;
// Set up eventfd and timerfd.
let event_fd = eventfd(0, EventfdFlags::CLOEXEC | EventfdFlags::NONBLOCK)?;
@ -47,24 +48,26 @@ impl Poller {
timer_fd,
};
if let Some(ref timer_fd) = poller.timer_fd {
unsafe {
if let Some(ref timer_fd) = poller.timer_fd {
poller.add(
timer_fd.as_raw_fd(),
Event::none(crate::NOTIFY_KEY),
PollMode::Oneshot,
)?;
}
poller.add(
timer_fd.as_raw_fd(),
Event::none(crate::NOTIFY_KEY),
poller.event_fd.as_raw_fd(),
Event {
key: crate::NOTIFY_KEY,
readable: true,
writable: false,
},
PollMode::Oneshot,
)?;
}
poller.add(
poller.event_fd.as_raw_fd(),
Event {
key: crate::NOTIFY_KEY,
readable: true,
writable: false,
},
PollMode::Oneshot,
)?;
tracing::trace!(
epoll_fd = ?poller.epoll_fd.as_raw_fd(),
event_fd = ?poller.event_fd.as_raw_fd(),
@ -85,7 +88,12 @@ impl Poller {
}
/// Adds a new file descriptor.
pub fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
///
/// # Safety
///
/// The `fd` must be a valid file descriptor. The usual condition of remaining registered in
/// the `Poller` doesn't apply to `epoll`.
pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
let span = tracing::trace_span!(
"add",
epoll_fd = ?self.epoll_fd.as_raw_fd(),
@ -94,10 +102,10 @@ impl Poller {
);
let _enter = span.enter();
epoll::epoll_add(
epoll::add(
&self.epoll_fd,
unsafe { rustix::fd::BorrowedFd::borrow_raw(fd) },
ev.key as u64,
epoll::EventData::new_u64(ev.key as u64),
epoll_flags(&ev, mode),
)?;
@ -105,7 +113,7 @@ impl Poller {
}
/// Modifies an existing file descriptor.
pub fn modify(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
pub fn modify(&self, fd: BorrowedFd<'_>, ev: Event, mode: PollMode) -> io::Result<()> {
let span = tracing::trace_span!(
"modify",
epoll_fd = ?self.epoll_fd.as_raw_fd(),
@ -114,10 +122,10 @@ impl Poller {
);
let _enter = span.enter();
epoll::epoll_mod(
epoll::modify(
&self.epoll_fd,
unsafe { rustix::fd::BorrowedFd::borrow_raw(fd) },
ev.key as u64,
fd,
epoll::EventData::new_u64(ev.key as u64),
epoll_flags(&ev, mode),
)?;
@ -125,7 +133,7 @@ impl Poller {
}
/// Deletes a file descriptor.
pub fn delete(&self, fd: RawFd) -> io::Result<()> {
pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> {
let span = tracing::trace_span!(
"delete",
epoll_fd = ?self.epoll_fd.as_raw_fd(),
@ -133,9 +141,7 @@ impl Poller {
);
let _enter = span.enter();
epoll::epoll_del(&self.epoll_fd, unsafe {
rustix::fd::BorrowedFd::borrow_raw(fd)
})?;
epoll::delete(&self.epoll_fd, fd)?;
Ok(())
}
@ -170,7 +176,7 @@ impl Poller {
// Set interest in timerfd.
self.modify(
timer_fd.as_raw_fd(),
timer_fd.as_fd(),
Event {
key: crate::NOTIFY_KEY,
readable: true,
@ -195,7 +201,7 @@ impl Poller {
};
// Wait for I/O events.
epoll::epoll_wait(&self.epoll_fd, &mut events.list, timeout_ms)?;
epoll::wait(&self.epoll_fd, &mut events.list, timeout_ms)?;
tracing::trace!(
epoll_fd = ?self.epoll_fd.as_raw_fd(),
res = ?events.list.len(),
@ -206,7 +212,7 @@ impl Poller {
let mut buf = [0u8; 8];
let _ = read(&self.event_fd, &mut buf);
self.modify(
self.event_fd.as_raw_fd(),
self.event_fd.as_fd(),
Event {
key: crate::NOTIFY_KEY,
readable: true,
@ -255,9 +261,9 @@ impl Drop for Poller {
let _enter = span.enter();
if let Some(timer_fd) = self.timer_fd.take() {
let _ = self.delete(timer_fd.as_raw_fd());
let _ = self.delete(timer_fd.as_fd());
}
let _ = self.delete(self.event_fd.as_raw_fd());
let _ = self.delete(self.event_fd.as_fd());
}
}
@ -310,10 +316,13 @@ impl Events {
/// Iterates over I/O events.
pub fn iter(&self) -> impl Iterator<Item = Event> + '_ {
self.list.iter().map(|(flags, data)| Event {
key: data as usize,
readable: flags.intersects(read_flags()),
writable: flags.intersects(write_flags()),
self.list.iter().map(|ev| {
let flags = ev.flags;
Event {
key: ev.data.u64() as usize,
readable: flags.intersects(read_flags()),
writable: flags.intersects(write_flags()),
}
})
}
}

View File

@ -42,7 +42,9 @@ use std::collections::hash_map::{Entry, HashMap};
use std::fmt;
use std::io;
use std::marker::PhantomPinned;
use std::os::windows::io::{AsHandle, AsRawHandle, BorrowedHandle, RawHandle, RawSocket};
use std::os::windows::io::{
AsHandle, AsRawHandle, AsRawSocket, BorrowedHandle, BorrowedSocket, RawHandle, RawSocket,
};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak};
@ -134,7 +136,16 @@ impl Poller {
}
/// Add a new source to the poller.
pub(super) fn add(&self, socket: RawSocket, interest: Event, mode: PollMode) -> io::Result<()> {
///
/// # Safety
///
/// The socket must be a valid socket and must last until it is deleted.
pub(super) unsafe fn add(
&self,
socket: RawSocket,
interest: Event,
mode: PollMode,
) -> io::Result<()> {
let span = tracing::trace_span!(
"add",
handle = ?self.port,
@ -192,7 +203,7 @@ impl Poller {
/// Update a source in the poller.
pub(super) fn modify(
&self,
socket: RawSocket,
socket: BorrowedSocket<'_>,
interest: Event,
mode: PollMode,
) -> io::Result<()> {
@ -217,7 +228,7 @@ impl Poller {
let sources = lock!(self.sources.read());
sources
.get(&socket)
.get(&socket.as_raw_socket())
.cloned()
.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?
};
@ -231,7 +242,7 @@ impl Poller {
}
/// Delete a source from the poller.
pub(super) fn delete(&self, socket: RawSocket) -> io::Result<()> {
pub(super) fn delete(&self, socket: BorrowedSocket<'_>) -> io::Result<()> {
let span = tracing::trace_span!(
"remove",
handle = ?self.port,
@ -243,7 +254,7 @@ impl Poller {
let source = {
let mut sources = lock!(self.sources.write());
match sources.remove(&socket) {
match sources.remove(&socket.as_raw_socket()) {
Some(s) => s,
None => {
// If the source has already been removed, then we can just return.

View File

@ -1,11 +1,11 @@
//! Bindings to kqueue (macOS, iOS, tvOS, watchOS, FreeBSD, NetBSD, OpenBSD, DragonFly BSD).
use std::io;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd};
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
use std::time::Duration;
use rustix::fd::OwnedFd;
use rustix::io::{fcntl_setfd, kqueue, Errno, FdFlags};
use rustix::event::kqueue;
use rustix::io::{fcntl_setfd, Errno, FdFlags};
use crate::{Event, PollMode};
@ -55,13 +55,17 @@ impl Poller {
}
/// Adds a new file descriptor.
pub fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
///
/// # Safety
///
/// The file descriptor must be valid and it must last until it is deleted.
pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
// File descriptors don't need to be added explicitly, so just modify the interest.
self.modify(fd, ev, mode)
self.modify(BorrowedFd::borrow_raw(fd), ev, mode)
}
/// Modifies an existing file descriptor.
pub fn modify(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
pub fn modify(&self, fd: BorrowedFd<'_>, ev: Event, mode: PollMode) -> io::Result<()> {
let span = if !self.notify.has_fd(fd) {
let span = tracing::trace_span!(
"add",
@ -91,12 +95,12 @@ impl Poller {
// A list of changes for kqueue.
let changelist = [
kqueue::Event::new(
kqueue::EventFilter::Read(fd),
kqueue::EventFilter::Read(fd.as_raw_fd()),
read_flags | kqueue::EventFlags::RECEIPT,
ev.key as _,
),
kqueue::Event::new(
kqueue::EventFilter::Write(fd),
kqueue::EventFilter::Write(fd.as_raw_fd()),
write_flags | kqueue::EventFlags::RECEIPT,
ev.key as _,
),
@ -141,7 +145,7 @@ impl Poller {
}
/// Deletes a file descriptor.
pub fn delete(&self, fd: RawFd) -> io::Result<()> {
pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> {
// Simply delete interest in the file descriptor.
self.modify(fd, Event::none(0), PollMode::Oneshot)
}
@ -268,9 +272,9 @@ pub(crate) fn mode_to_flags(mode: PollMode) -> kqueue::EventFlags {
))]
mod notify {
use super::Poller;
use rustix::io::kqueue;
use rustix::event::kqueue;
use std::io;
use std::os::unix::io::RawFd;
use std::os::unix::io::BorrowedFd;
/// A notification pipe.
///
@ -335,7 +339,7 @@ mod notify {
}
/// Whether this raw file descriptor is associated with this pipe.
pub(super) fn has_fd(&self, _fd: RawFd) -> bool {
pub(super) fn has_fd(&self, _fd: BorrowedFd<'_>) -> bool {
false
}
}
@ -354,7 +358,7 @@ mod notify {
use crate::{Event, PollMode, NOTIFY_KEY};
use std::io::{self, prelude::*};
use std::os::unix::{
io::{AsRawFd, RawFd},
io::{AsFd, AsRawFd, BorrowedFd},
net::UnixStream,
};
@ -386,11 +390,13 @@ mod notify {
/// Registers this notification pipe in the `Poller`.
pub(super) fn register(&self, poller: &Poller) -> io::Result<()> {
// Register the read end of this pipe.
poller.add(
self.read_stream.as_raw_fd(),
Event::readable(NOTIFY_KEY),
PollMode::Oneshot,
)
unsafe {
poller.add(
self.read_stream.as_raw_fd(),
Event::readable(NOTIFY_KEY),
PollMode::Oneshot,
)
}
}
/// Reregister this notification pipe in the `Poller`.
@ -400,7 +406,7 @@ mod notify {
// Reregister the read end of this pipe.
poller.modify(
self.read_stream.as_raw_fd(),
self.read_stream.as_fd(),
Event::readable(NOTIFY_KEY),
PollMode::Oneshot,
)
@ -418,12 +424,12 @@ mod notify {
/// Deregisters this notification pipe from the `Poller`.
pub(super) fn deregister(&self, poller: &Poller) -> io::Result<()> {
// Deregister the read end of the pipe.
poller.delete(self.read_stream.as_raw_fd())
poller.delete(self.read_stream.as_fd())
}
/// Whether this raw file descriptor is associated with this pipe.
pub(super) fn has_fd(&self, fd: RawFd) -> bool {
self.read_stream.as_raw_fd() == fd
pub(super) fn has_fd(&self, fd: BorrowedFd<'_>) -> bool {
self.read_stream.as_raw_fd() == fd.as_raw_fd()
}
}
}

View File

@ -28,7 +28,9 @@
//!
//! // Create a poller and register interest in readability on the socket.
//! let poller = Poller::new()?;
//! poller.add(&socket, Event::readable(key))?;
//! unsafe {
//! poller.add(&socket, Event::readable(key))?;
//! }
//!
//! // The event loop.
//! let mut events = Vec::new();
@ -46,13 +48,15 @@
//! }
//! }
//! }
//!
//! poller.delete(&socket)?;
//! # std::io::Result::Ok(())
//! ```
#![cfg(feature = "std")]
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
#![allow(clippy::useless_conversion, clippy::unnecessary_cast)]
#![allow(clippy::useless_conversion, clippy::unnecessary_cast, unused_unsafe)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![doc(
html_favicon_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
@ -273,8 +277,11 @@ impl Poller {
/// [`modify()`][`Poller::modify()`] again after an event is delivered if we're interested in
/// the next event of the same kind.
///
/// Don't forget to [`delete()`][`Poller::delete()`] the file descriptor or socket when it is
/// no longer used!
/// # Safety
///
/// The source must be [`delete()`]d from this `Poller` before it is dropped.
///
/// [`delete()`]: Poller::delete
///
/// # Errors
///
@ -295,10 +302,13 @@ impl Poller {
/// let key = 7;
///
/// let poller = Poller::new()?;
/// poller.add(&source, Event::all(key))?;
/// unsafe {
/// poller.add(&source, Event::all(key))?;
/// }
/// poller.delete(&source)?;
/// # std::io::Result::Ok(())
/// ```
pub fn add(&self, source: impl Source, interest: Event) -> io::Result<()> {
pub unsafe fn add(&self, source: impl AsRawSource, interest: Event) -> io::Result<()> {
self.add_with_mode(source, interest, PollMode::Oneshot)
}
@ -307,13 +317,19 @@ impl Poller {
/// This is identical to the `add()` function, but allows specifying the
/// polling mode to use for this socket.
///
/// # Safety
///
/// The source must be [`delete()`]d from this `Poller` before it is dropped.
///
/// [`delete()`]: Poller::delete
///
/// # Errors
///
/// If the operating system does not support the specified mode, this function
/// will return an error.
pub fn add_with_mode(
pub unsafe fn add_with_mode(
&self,
source: impl Source,
source: impl AsRawSource,
interest: Event,
mode: PollMode,
) -> io::Result<()> {
@ -354,7 +370,7 @@ impl Poller {
/// # let source = std::net::TcpListener::bind("127.0.0.1:0")?;
/// # let key = 7;
/// # let poller = Poller::new()?;
/// # poller.add(&source, Event::none(key))?;
/// # unsafe { poller.add(&source, Event::none(key))?; }
/// poller.modify(&source, Event::all(key))?;
/// # std::io::Result::Ok(())
/// ```
@ -366,8 +382,9 @@ impl Poller {
/// # let source = std::net::TcpListener::bind("127.0.0.1:0")?;
/// # let key = 7;
/// # let poller = Poller::new()?;
/// # poller.add(&source, Event::none(key))?;
/// # unsafe { poller.add(&source, Event::none(key))?; }
/// poller.modify(&source, Event::readable(key))?;
/// # poller.delete(&source)?;
/// # std::io::Result::Ok(())
/// ```
///
@ -378,8 +395,9 @@ impl Poller {
/// # let poller = Poller::new()?;
/// # let key = 7;
/// # let source = std::net::TcpListener::bind("127.0.0.1:0")?;
/// # poller.add(&source, Event::none(key))?;
/// # unsafe { poller.add(&source, Event::none(key))? };
/// poller.modify(&source, Event::writable(key))?;
/// # poller.delete(&source)?;
/// # std::io::Result::Ok(())
/// ```
///
@ -390,11 +408,12 @@ impl Poller {
/// # let source = std::net::TcpListener::bind("127.0.0.1:0")?;
/// # let key = 7;
/// # let poller = Poller::new()?;
/// # poller.add(&source, Event::none(key))?;
/// # unsafe { poller.add(&source, Event::none(key))?; }
/// poller.modify(&source, Event::none(key))?;
/// # poller.delete(&source)?;
/// # std::io::Result::Ok(())
/// ```
pub fn modify(&self, source: impl Source, interest: Event) -> io::Result<()> {
pub fn modify(&self, source: impl AsSource, interest: Event) -> io::Result<()> {
self.modify_with_mode(source, interest, PollMode::Oneshot)
}
@ -415,7 +434,7 @@ impl Poller {
/// an error.
pub fn modify_with_mode(
&self,
source: impl Source,
source: impl AsSource,
interest: Event,
mode: PollMode,
) -> io::Result<()> {
@ -425,7 +444,7 @@ impl Poller {
"the key is not allowed to be `usize::MAX`",
));
}
self.poller.modify(source.raw(), interest, mode)
self.poller.modify(source.source(), interest, mode)
}
/// Removes a file descriptor or socket from the poller.
@ -444,12 +463,12 @@ impl Poller {
/// let key = 7;
///
/// let poller = Poller::new()?;
/// poller.add(&socket, Event::all(key))?;
/// unsafe { poller.add(&socket, Event::all(key))?; }
/// poller.delete(&socket)?;
/// # std::io::Result::Ok(())
/// ```
pub fn delete(&self, source: impl Source) -> io::Result<()> {
self.poller.delete(source.raw())
pub fn delete(&self, source: impl AsSource) -> io::Result<()> {
self.poller.delete(source.source())
}
/// Waits for at least one I/O event and returns the number of new events.
@ -482,10 +501,13 @@ impl Poller {
/// let key = 7;
///
/// let poller = Poller::new()?;
/// poller.add(&socket, Event::all(key))?;
/// unsafe {
/// poller.add(&socket, Event::all(key))?;
/// }
///
/// let mut events = Vec::new();
/// let n = poller.wait(&mut events, Some(Duration::from_secs(1)))?;
/// poller.delete(&socket)?;
/// # std::io::Result::Ok(())
/// ```
pub fn wait(&self, events: &mut Vec<Event>, timeout: Option<Duration>) -> io::Result<usize> {
@ -624,45 +646,65 @@ impl fmt::Debug for Poller {
cfg_if! {
if #[cfg(unix)] {
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::io::{AsRawFd, RawFd, AsFd, BorrowedFd};
/// A [`RawFd`] or a reference to a type implementing [`AsRawFd`].
pub trait Source {
/// Returns the [`RawFd`] for this I/O object.
/// A resource with a raw file descriptor.
pub trait AsRawSource {
/// Returns the raw file descriptor.
fn raw(&self) -> RawFd;
}
impl Source for RawFd {
fn raw(&self) -> RawFd {
*self
}
}
impl<T: AsRawFd> Source for &T {
impl<T: AsRawFd> AsRawSource for &T {
fn raw(&self) -> RawFd {
self.as_raw_fd()
}
}
} else if #[cfg(windows)] {
use std::os::windows::io::{AsRawSocket, RawSocket};
/// A [`RawSocket`] or a reference to a type implementing [`AsRawSocket`].
pub trait Source {
/// Returns the [`RawSocket`] for this I/O object.
impl AsRawSource for RawFd {
fn raw(&self) -> RawFd {
*self
}
}
/// A resource with a borrowed file descriptor.
pub trait AsSource: AsFd {
/// Returns the borrowed file descriptor.
fn source(&self) -> BorrowedFd<'_> {
self.as_fd()
}
}
impl<T: AsFd> AsSource for T {}
} else if #[cfg(windows)] {
use std::os::windows::io::{AsRawSocket, RawSocket, AsSocket, BorrowedSocket};
/// A resource with a raw socket.
pub trait AsRawSource {
/// Returns the raw socket.
fn raw(&self) -> RawSocket;
}
impl Source for RawSocket {
impl<T: AsRawSocket> AsRawSource for &T {
fn raw(&self) -> RawSocket {
self.as_raw_socket()
}
}
impl AsRawSource for RawSocket {
fn raw(&self) -> RawSocket {
*self
}
}
impl<T: AsRawSocket> Source for &T {
fn raw(&self) -> RawSocket {
self.as_raw_socket()
/// A resource with a borrowed socket.
pub trait AsSource: AsSocket {
/// Returns the borrowed socket.
fn source(&self) -> BorrowedSocket<'_> {
self.as_socket()
}
}
impl<T: AsSocket> AsSource for T {}
}
}

View File

@ -7,7 +7,7 @@ use std::io;
use std::process::Child;
use std::time::Duration;
use rustix::io::kqueue;
use rustix::event::kqueue;
use super::__private::PollerSealed;
use __private::FilterSealed;
@ -238,7 +238,7 @@ unsafe impl FilterSealed for Timer {
impl Filter for Timer {}
mod __private {
use rustix::io::kqueue;
use rustix::event::kqueue;
#[doc(hidden)]
pub unsafe trait FilterSealed {

View File

@ -7,12 +7,11 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Condvar, Mutex};
use std::time::{Duration, Instant};
use rustix::event::{poll, PollFd, PollFlags};
use rustix::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd};
use rustix::fs::{fcntl_getfl, fcntl_setfl, OFlags};
use rustix::io::{
fcntl_getfd, fcntl_setfd, pipe, pipe_with, poll, read, write, FdFlags, PipeFlags, PollFd,
PollFlags,
};
use rustix::io::{fcntl_getfd, fcntl_setfd, read, write, FdFlags};
use rustix::pipe::{pipe, pipe_with, PipeFlags};
// std::os::unix doesn't exist on Fuchsia
type RawFd = std::os::raw::c_int;
@ -158,7 +157,7 @@ impl Poller {
}
/// Modifies an existing file descriptor.
pub fn modify(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
pub fn modify(&self, fd: BorrowedFd<'_>, ev: Event, mode: PollMode) -> io::Result<()> {
let span = tracing::trace_span!(
"modify",
notify_read = ?self.notify_read,
@ -168,11 +167,19 @@ impl Poller {
let _enter = span.enter();
self.modify_fds(|fds| {
let data = fds.fd_data.get_mut(&fd).ok_or(io::ErrorKind::NotFound)?;
let data = fds
.fd_data
.get_mut(&fd.as_raw_fd())
.ok_or(io::ErrorKind::NotFound)?;
data.key = ev.key;
let poll_fds_index = data.poll_fds_index;
fds.poll_fds[poll_fds_index] =
PollFd::from_borrowed_fd(unsafe { BorrowedFd::borrow_raw(fd) }, poll_events(ev));
// SAFETY: This is essentially transmuting a `PollFd<'a>` to a `PollFd<'static>`, which
// only works if it's removed in time with `delete()`.
fds.poll_fds[poll_fds_index] = PollFd::from_borrowed_fd(
unsafe { BorrowedFd::borrow_raw(fd.as_raw_fd()) },
poll_events(ev),
);
data.remove = cvt_mode_as_remove(mode)?;
Ok(())
@ -180,7 +187,7 @@ impl Poller {
}
/// Deletes a file descriptor.
pub fn delete(&self, fd: RawFd) -> io::Result<()> {
pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> {
let span = tracing::trace_span!(
"delete",
notify_read = ?self.notify_read,
@ -189,7 +196,10 @@ impl Poller {
let _enter = span.enter();
self.modify_fds(|fds| {
let data = fds.fd_data.remove(&fd).ok_or(io::ErrorKind::NotFound)?;
let data = fds
.fd_data
.remove(&fd.as_raw_fd())
.ok_or(io::ErrorKind::NotFound)?;
fds.poll_fds.swap_remove(data.poll_fds_index);
if let Some(swapped_pollfd) = fds.poll_fds.get(data.poll_fds_index) {
fds.fd_data

View File

@ -4,8 +4,9 @@ use std::io;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd};
use std::time::Duration;
use rustix::event::{port, PollFlags};
use rustix::fd::OwnedFd;
use rustix::io::{fcntl_getfd, fcntl_setfd, port, FdFlags, PollFlags};
use rustix::io::{fcntl_getfd, fcntl_setfd, FdFlags};
use crate::{Event, PollMode};
@ -42,13 +43,17 @@ impl Poller {
}
/// Adds a file descriptor.
pub fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
///
/// # Safety
///
/// The `fd` must be a valid file descriptor and it must last until it is deleted.
pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
// File descriptors don't need to be added explicitly, so just modify the interest.
self.modify(fd, ev, mode)
self.modify(BorrowedFd::borrow_raw(fd), ev, mode)
}
/// Modifies an existing file descriptor.
pub fn modify(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
pub fn modify(&self, fd: BorrowedFd<'_>, ev: Event, mode: PollMode) -> io::Result<()> {
let span = tracing::trace_span!(
"modify",
port_fd = ?self.port_fd.as_raw_fd(),
@ -79,7 +84,7 @@ impl Poller {
}
/// Deletes a file descriptor.
pub fn delete(&self, fd: RawFd) -> io::Result<()> {
pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> {
let span = tracing::trace_span!(
"delete",
port_fd = ?self.port_fd.as_raw_fd(),

View File

@ -13,20 +13,25 @@ fn concurrent_add() -> io::Result<()> {
let mut events = Vec::new();
Parallel::new()
let result = Parallel::new()
.add(|| {
poller.wait(&mut events, None)?;
Ok(())
})
.add(|| {
thread::sleep(Duration::from_millis(100));
poller.add(&reader, Event::readable(0))?;
unsafe {
poller.add(&reader, Event::readable(0))?;
}
writer.write_all(&[1])?;
Ok(())
})
.run()
.into_iter()
.collect::<io::Result<()>>()?;
.collect::<io::Result<()>>();
poller.delete(&reader)?;
result?;
assert_eq!(events, [Event::readable(0)]);
@ -37,7 +42,9 @@ fn concurrent_add() -> io::Result<()> {
fn concurrent_modify() -> io::Result<()> {
let (reader, mut writer) = tcp_pair()?;
let poller = Poller::new()?;
poller.add(&reader, Event::none(0))?;
unsafe {
poller.add(&reader, Event::none(0))?;
}
let mut events = Vec::new();

View File

@ -7,7 +7,9 @@ use std::time::Duration;
fn basic_io() {
let poller = Poller::new().unwrap();
let (read, mut write) = tcp_pair().unwrap();
poller.add(&read, Event::readable(1)).unwrap();
unsafe {
poller.add(&read, Event::readable(1)).unwrap();
}
// Nothing should be available at first.
let mut events = vec![];
@ -28,6 +30,8 @@ fn basic_io() {
1
);
assert_eq!(&*events, &[Event::readable(1)]);
poller.delete(&read).unwrap();
}
fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> {

View File

@ -20,7 +20,9 @@ fn many_connections() {
let poller = polling::Poller::new().unwrap();
for (i, reader, _) in connections.iter() {
poller.add(reader, polling::Event::readable(*i)).unwrap();
unsafe {
poller.add(reader, polling::Event::readable(*i)).unwrap();
}
}
let mut events = vec![];

View File

@ -16,8 +16,7 @@ fn level_triggered() {
// Create our poller and register our streams.
let poller = Poller::new().unwrap();
if poller
.add_with_mode(&reader, Event::readable(reader_token), PollMode::Level)
if unsafe { poller.add_with_mode(&reader, Event::readable(reader_token), PollMode::Level) }
.is_err()
{
// Only panic if we're on a platform that should support level mode.
@ -92,8 +91,7 @@ fn edge_triggered() {
// Create our poller and register our streams.
let poller = Poller::new().unwrap();
if poller
.add_with_mode(&reader, Event::readable(reader_token), PollMode::Edge)
if unsafe { poller.add_with_mode(&reader, Event::readable(reader_token), PollMode::Edge) }
.is_err()
{
// Only panic if we're on a platform that should support level mode.
@ -170,13 +168,14 @@ fn edge_oneshot_triggered() {
// Create our poller and register our streams.
let poller = Poller::new().unwrap();
if poller
.add_with_mode(
if unsafe {
poller.add_with_mode(
&reader,
Event::readable(reader_token),
PollMode::EdgeOneshot,
)
.is_err()
}
.is_err()
{
// Only panic if we're on a platform that should support level mode.
cfg_if::cfg_if! {