Make poll's poller modifications interrupt `wait`

This commit is contained in:
Koxiaet 2020-12-18 14:18:03 +00:00
parent 0f2f6ed15a
commit e0789a8ee0
2 changed files with 168 additions and 79 deletions

View File

@ -74,10 +74,10 @@ macro_rules! syscall {
}
cfg_if! {
if #[cfg(any(target_os = "linux", target_os = "android"))] {
/*if #[cfg(any(target_os = "linux", target_os = "android"))] {
mod epoll;
use epoll as sys;
} else if #[cfg(any(
} else*/ if #[cfg(any(
target_os = "illumos",
target_os = "solaris",
))] {

View File

@ -3,8 +3,9 @@
use std::collections::HashMap;
use std::convert::TryInto;
use std::io;
use std::sync::Mutex;
use std::time::Duration;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Condvar, Mutex};
use std::time::{Duration, Instant};
// std::os::unix doesn't exist on Fuchsia
use libc::c_int as RawFd;
@ -19,11 +20,27 @@ const REMOVE_FD: RawFd = -2;
pub struct Poller {
/// File descriptors to poll.
fds: Mutex<Fds>,
/// The file descriptor of the read half of the notify pipe. This is also stored as the first
/// file descriptor in `fds.poll_fds`.
notify_read: RawFd,
/// The file descriptor of the write half of the notify pipe.
///
/// Data is written to this to wake up the current instance of `wait`, which can occur when the
/// user notifies it (in which case `notified` would have been set) or when an operation needs
/// to occur (in which case `waiting_operations` would have been incremented).
notify_write: RawFd,
/// The number of operations (`add`, `modify` or `delete`) that are currently waiting on the
/// mutex to become free. When this is nonzero, `wait` must be suspended until it reaches zero
/// again.
waiting_operations: AtomicUsize,
/// The condition variable that gets notified when `waiting_operations` reaches zero. This is
/// used with the `fds` mutex.
operations_complete: Condvar,
/// Whether `wait` has been notified by the user.
notified: AtomicBool,
}
/// The file descriptors to poll in a `Poller`.
@ -76,6 +93,9 @@ impl Poller {
}),
notify_read: notify_pipe[0],
notify_write: notify_pipe[1],
waiting_operations: AtomicUsize::new(0),
operations_complete: Condvar::new(),
notified: AtomicBool::new(false),
})
}
@ -85,112 +105,155 @@ impl Poller {
return Err(io::Error::from(io::ErrorKind::InvalidInput));
}
let mut fds = self.fds.lock().unwrap();
self.modify_fds(|fds| {
if fds.fd_data.contains_key(&fd) {
return Err(io::Error::from(io::ErrorKind::AlreadyExists));
}
if fds.fd_data.contains_key(&fd) {
return Err(io::Error::from(io::ErrorKind::AlreadyExists));
}
let poll_fds_index = fds.poll_fds.len();
fds.fd_data.insert(
fd,
FdData {
poll_fds_index,
key: ev.key,
},
);
let poll_fds_index = fds.poll_fds.len();
fds.fd_data.insert(
fd,
FdData {
poll_fds_index,
key: ev.key,
},
);
fds.poll_fds.push(libc::pollfd {
fd,
events: poll_events(ev),
revents: 0,
});
fds.poll_fds.push(libc::pollfd {
fd,
events: poll_events(ev),
revents: 0,
});
Ok(())
Ok(())
})
}
/// Modifies an existing file descriptor.
pub fn modify(&self, fd: RawFd, ev: Event) -> io::Result<()> {
let mut fds = self.fds.lock().unwrap();
self.modify_fds(|fds| {
let data = fds.fd_data.get_mut(&fd).ok_or(io::ErrorKind::NotFound)?;
data.key = ev.key;
let poll_fds_index = data.poll_fds_index;
fds.poll_fds[poll_fds_index].events = poll_events(ev);
let data = fds.fd_data.get_mut(&fd).ok_or(io::ErrorKind::NotFound)?;
data.key = ev.key;
let poll_fds_index = data.poll_fds_index;
fds.poll_fds[poll_fds_index].events = poll_events(ev);
Ok(())
Ok(())
})
}
/// Deletes a file descriptor.
pub fn delete(&self, fd: RawFd) -> io::Result<()> {
let mut fds = self.fds.lock().unwrap();
self.modify_fds(|fds| {
let data = fds.fd_data.remove(&fd).ok_or(io::ErrorKind::NotFound)?;
fds.poll_fds[data.poll_fds_index].fd = REMOVE_FD;
let data = fds.fd_data.remove(&fd).ok_or(io::ErrorKind::NotFound)?;
fds.poll_fds[data.poll_fds_index].fd = REMOVE_FD;
Ok(())
Ok(())
})
}
/// Waits for I/O events with an optional timeout.
pub fn wait(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
let deadline = timeout.map(|t| Instant::now() + t);
events.inner.clear();
let timeout_ms = timeout
.map(|timeout| {
// Round up to a whole millisecond.
let mut ms = timeout.as_millis().try_into().unwrap_or(std::u64::MAX);
if Duration::from_millis(ms) < timeout {
ms += 1;
}
ms.try_into().unwrap_or(std::i32::MAX)
})
.unwrap_or(-1);
let mut fds = self.fds.lock().unwrap();
let fds = &mut *fds;
// Remove all fds that have been marked to be removed.
fds.poll_fds.retain(|poll_fd| poll_fd.fd != REMOVE_FD);
let num_events = loop {
match syscall!(poll(
fds.poll_fds.as_mut_ptr(),
fds.poll_fds.len() as u64,
timeout_ms,
)) {
Ok(num_events) => break num_events as usize,
// EAGAIN is translated into WouldBlock, and EWOULDBLOCK cannot be returned by
// poll.
Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue,
Err(e) => return Err(e),
};
};
// Store any events that occured and remove interest.
events.inner.reserve(num_events);
for fd_data in fds.fd_data.values_mut() {
let mut poll_fd = fds.poll_fds[fd_data.poll_fds_index];
if poll_fd.revents != 0 {
events.inner.push(Event {
key: fd_data.key,
readable: poll_fd.revents & READ_REVENTS != 0,
writable: poll_fd.revents & WRITE_REVENTS != 0,
});
poll_fd.events = 0;
loop {
// Complete all current operations.
while self.waiting_operations.load(Ordering::SeqCst) != 0 {
fds = self.operations_complete.wait(fds).unwrap();
}
}
// Read all notifications.
while syscall!(read(self.notify_read, &mut [0; 64] as *mut _ as *mut _, 64)).is_ok() {}
// Remove all fds that have been marked to be removed.
fds.poll_fds.retain(|poll_fd| poll_fd.fd != REMOVE_FD);
// Perform the poll.
let num_events = poll(&mut fds.poll_fds, deadline)?;
let notified = fds.poll_fds[0].revents != 0;
let num_fd_events = if notified { num_events - 1 } else { num_events };
// Read all notifications.
if notified {
while syscall!(read(self.notify_read, &mut [0; 64] as *mut _ as *mut _, 64)).is_ok()
{
}
}
// If the only event that occurred during polling was notification and it wasn't to
// exit, another thread is trying to perform an operation on the fds. Continue the
// loop.
if !self.notified.swap(false, Ordering::SeqCst) && num_fd_events == 0 && notified {
continue;
}
// Store the events if there were any.
if num_fd_events > 0 {
let fds = &mut *fds;
events.inner.reserve(num_fd_events);
for fd_data in fds.fd_data.values_mut() {
let mut poll_fd = fds.poll_fds[fd_data.poll_fds_index];
if poll_fd.revents != 0 {
// Store event
events.inner.push(Event {
key: fd_data.key,
readable: poll_fd.revents & READ_REVENTS != 0,
writable: poll_fd.revents & WRITE_REVENTS != 0,
});
// Remove interest
poll_fd.events = 0;
if events.inner.len() == num_fd_events {
break;
}
}
}
}
break;
}
Ok(())
}
/// Sends a notification to wake up the current or next `wait()` call.
pub fn notify(&self) -> io::Result<()> {
syscall!(write(self.notify_write, &0_u8 as *const _ as *const _, 1))?;
if !self.notified.swap(true, Ordering::SeqCst) {
self.notify_inner()?;
}
Ok(())
}
/// Perform a modification on `fds`, interrupting the current caller of `wait` if it's running.
fn modify_fds(&self, f: impl FnOnce(&mut Fds) -> io::Result<()>) -> io::Result<()> {
self.waiting_operations.fetch_add(1, Ordering::SeqCst);
// Wake up the current caller of `wait` if there is one.
let sent_notification = self.notify_inner().is_ok();
let mut fds = self.fds.lock().unwrap();
// If there was no caller of `wait` our byte was not removed from the pipe, so attempt to
// remove one byte from the pipe.
if sent_notification {
let _ = syscall!(read(self.notify_read, &mut [0; 1] as *mut _ as *mut _, 1));
}
let res = f(&mut *fds);
if self.waiting_operations.fetch_sub(1, Ordering::SeqCst) == 1 {
self.operations_complete.notify_one();
}
res
}
/// Wake the current thread that is calling `wait`.
fn notify_inner(&self) -> io::Result<()> {
syscall!(write(self.notify_write, &0_u8 as *const _ as *const _, 1)).map(drop)
}
}
impl Drop for Poller {
@ -236,3 +299,29 @@ impl Events {
self.inner.iter().copied()
}
}
/// Helper function to call poll.
fn poll(fds: &mut [libc::pollfd], deadline: Option<Instant>) -> io::Result<usize> {
loop {
// Convert the timeout to milliseconds.
let timeout_ms = deadline
.map(|deadline| {
let timeout = deadline.saturating_duration_since(Instant::now());
// Round up to a whole millisecond.
let mut ms = timeout.as_millis().try_into().unwrap_or(std::u64::MAX);
if Duration::from_millis(ms) < timeout {
ms += 1;
}
ms.try_into().unwrap_or(std::i32::MAX)
})
.unwrap_or(-1);
match syscall!(poll(fds.as_mut_ptr(), fds.len() as u64, timeout_ms,)) {
Ok(num_events) => break Ok(num_events as usize),
// poll returns EAGAIN if we can retry it.
Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue,
Err(e) => return Err(e),
}
}
}