mirror of https://github.com/smol-rs/polling
bugfix: Handle interrupts while polling
Previous, `Poller::wait` would bubble signal interruption error to the user. However, this may be unexpected for simple use cases. Thus, this commit makes it so, if `ErrorKind::Interrupted` is received by the underlying `wait()` call, it clears the events and tries to wait again. This also adds a test for this interruption written by @psychon. Co-Authored-By: Uli Schlachter <psychon@users.noreply.github.com> Signed-off-by: John Nunley <dev@notgull.net>
This commit is contained in:
parent
0575cbd4bc
commit
b9ab821df1
|
@ -44,3 +44,7 @@ features = [
|
|||
[dev-dependencies]
|
||||
easy-parallel = "3.1.0"
|
||||
fastrand = "2.0.0"
|
||||
|
||||
[target.'cfg(unix)'.dev-dependencies]
|
||||
libc = "0.2"
|
||||
signal-hook = "0.3.17"
|
||||
|
|
30
src/lib.rs
30
src/lib.rs
|
@ -70,7 +70,7 @@ use std::marker::PhantomData;
|
|||
use std::num::NonZeroUsize;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use cfg_if::cfg_if;
|
||||
|
||||
|
@ -651,14 +651,30 @@ impl Poller {
|
|||
let _enter = span.enter();
|
||||
|
||||
if let Ok(_lock) = self.lock.try_lock() {
|
||||
// Wait for I/O events.
|
||||
self.poller.wait(&mut events.events, timeout)?;
|
||||
let deadline = timeout.and_then(|timeout| Instant::now().checked_add(timeout));
|
||||
|
||||
// Clear the notification, if any.
|
||||
self.notified.swap(false, Ordering::SeqCst);
|
||||
loop {
|
||||
// Figure out how long to wait for.
|
||||
let timeout =
|
||||
deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
|
||||
|
||||
// Indicate number of events.
|
||||
Ok(events.len())
|
||||
// Wait for I/O events.
|
||||
if let Err(e) = self.poller.wait(&mut events.events, timeout) {
|
||||
// If the wait was interrupted by a signal, clear events and try again.
|
||||
if e.kind() == io::ErrorKind::Interrupted {
|
||||
events.clear();
|
||||
continue;
|
||||
} else {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the notification, if any.
|
||||
self.notified.swap(false, Ordering::SeqCst);
|
||||
|
||||
// Indicate number of events.
|
||||
return Ok(events.len());
|
||||
}
|
||||
} else {
|
||||
tracing::trace!("wait: skipping because another thread is already waiting on I/O");
|
||||
Ok(0)
|
||||
|
|
|
@ -76,6 +76,53 @@ fn concurrent_modify() -> io::Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[test]
|
||||
fn concurrent_interruption() -> io::Result<()> {
|
||||
struct MakeItSend<T>(T);
|
||||
unsafe impl<T> Send for MakeItSend<T> {}
|
||||
|
||||
let (reader, _writer) = tcp_pair()?;
|
||||
let poller = Poller::new()?;
|
||||
unsafe {
|
||||
poller.add(&reader, Event::none(0))?;
|
||||
}
|
||||
|
||||
let mut events = Events::new();
|
||||
let events_borrow = &mut events;
|
||||
let (sender, receiver) = std::sync::mpsc::channel();
|
||||
|
||||
Parallel::new()
|
||||
.add(move || {
|
||||
// Register a signal handler so that the syscall is actually interrupted. A signal that
|
||||
// is ignored by default does not cause an interrupted syscall.
|
||||
signal_hook::flag::register(signal_hook::consts::signal::SIGURG, Default::default())?;
|
||||
|
||||
// Signal to the other thread how to send a signal to us
|
||||
sender
|
||||
.send(MakeItSend(unsafe { libc::pthread_self() }))
|
||||
.unwrap();
|
||||
|
||||
poller.wait(events_borrow, Some(Duration::from_secs(1)))?;
|
||||
Ok(())
|
||||
})
|
||||
.add(move || {
|
||||
let MakeItSend(target_thread) = receiver.recv().unwrap();
|
||||
thread::sleep(Duration::from_millis(100));
|
||||
assert_eq!(0, unsafe {
|
||||
libc::pthread_kill(target_thread, libc::SIGURG)
|
||||
});
|
||||
Ok(())
|
||||
})
|
||||
.run()
|
||||
.into_iter()
|
||||
.collect::<io::Result<()>>()?;
|
||||
|
||||
assert_eq!(events.len(), 0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")?;
|
||||
let a = TcpStream::connect(listener.local_addr()?)?;
|
||||
|
|
Loading…
Reference in New Issue