bugfix: Handle interrupts while polling

Previous, `Poller::wait` would bubble signal interruption error to the user.
However, this may be unexpected for simple use cases. Thus, this commit makes
it so, if `ErrorKind::Interrupted` is received by the underlying `wait()` call,
it clears the events and tries to wait again.

This also adds a test for this interruption written by @psychon.

Co-Authored-By: Uli Schlachter <psychon@users.noreply.github.com>
Signed-off-by: John Nunley <dev@notgull.net>
This commit is contained in:
John Nunley 2023-10-27 07:02:08 -07:00 committed by GitHub
parent 0575cbd4bc
commit b9ab821df1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 74 additions and 7 deletions

View File

@ -44,3 +44,7 @@ features = [
[dev-dependencies]
easy-parallel = "3.1.0"
fastrand = "2.0.0"
[target.'cfg(unix)'.dev-dependencies]
libc = "0.2"
signal-hook = "0.3.17"

View File

@ -70,7 +70,7 @@ use std::marker::PhantomData;
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use std::time::Duration;
use std::time::{Duration, Instant};
use cfg_if::cfg_if;
@ -651,14 +651,30 @@ impl Poller {
let _enter = span.enter();
if let Ok(_lock) = self.lock.try_lock() {
// Wait for I/O events.
self.poller.wait(&mut events.events, timeout)?;
let deadline = timeout.and_then(|timeout| Instant::now().checked_add(timeout));
// Clear the notification, if any.
self.notified.swap(false, Ordering::SeqCst);
loop {
// Figure out how long to wait for.
let timeout =
deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
// Indicate number of events.
Ok(events.len())
// Wait for I/O events.
if let Err(e) = self.poller.wait(&mut events.events, timeout) {
// If the wait was interrupted by a signal, clear events and try again.
if e.kind() == io::ErrorKind::Interrupted {
events.clear();
continue;
} else {
return Err(e);
}
}
// Clear the notification, if any.
self.notified.swap(false, Ordering::SeqCst);
// Indicate number of events.
return Ok(events.len());
}
} else {
tracing::trace!("wait: skipping because another thread is already waiting on I/O");
Ok(0)

View File

@ -76,6 +76,53 @@ fn concurrent_modify() -> io::Result<()> {
Ok(())
}
#[cfg(unix)]
#[test]
fn concurrent_interruption() -> io::Result<()> {
struct MakeItSend<T>(T);
unsafe impl<T> Send for MakeItSend<T> {}
let (reader, _writer) = tcp_pair()?;
let poller = Poller::new()?;
unsafe {
poller.add(&reader, Event::none(0))?;
}
let mut events = Events::new();
let events_borrow = &mut events;
let (sender, receiver) = std::sync::mpsc::channel();
Parallel::new()
.add(move || {
// Register a signal handler so that the syscall is actually interrupted. A signal that
// is ignored by default does not cause an interrupted syscall.
signal_hook::flag::register(signal_hook::consts::signal::SIGURG, Default::default())?;
// Signal to the other thread how to send a signal to us
sender
.send(MakeItSend(unsafe { libc::pthread_self() }))
.unwrap();
poller.wait(events_borrow, Some(Duration::from_secs(1)))?;
Ok(())
})
.add(move || {
let MakeItSend(target_thread) = receiver.recv().unwrap();
thread::sleep(Duration::from_millis(100));
assert_eq!(0, unsafe {
libc::pthread_kill(target_thread, libc::SIGURG)
});
Ok(())
})
.run()
.into_iter()
.collect::<io::Result<()>>()?;
assert_eq!(events.len(), 0);
Ok(())
}
fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> {
let listener = TcpListener::bind("127.0.0.1:0")?;
let a = TcpStream::connect(listener.local_addr()?)?;