add fn connect_timeout

This commit is contained in:
ovalek 2020-05-24 02:39:25 +02:00
parent c311b6897f
commit 1fa5cfa898
1 changed files with 68 additions and 0 deletions

View File

@ -7,6 +7,7 @@
use std::future::Future;
use std::io::{self, IoSlice, IoSliceMut, Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket};
use std::time::Duration;
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, IntoRawSocket, RawSocket};
use std::pin::Pin;
@ -710,6 +711,73 @@ impl Async<TcpStream> {
}
}
/// Creates a TCP connection to the specified address with a timeout.
///
/// # Examples
///
/// ```no_run
/// use smol::Async;
/// use std::net::TcpStream;
/// use std::time::Duration;
///
/// # smol::run(async {
/// let stream = Async::<TcpStream>::connect_timeout("example.com:80", Duration::from_secs(5)).await?;
/// # std::io::Result::Ok(()) });
/// ```
pub async fn connect_timeout<A: ToString>(addr: A, timeout: Duration) -> io::Result<Async<TcpStream>> {
let addr = addr.to_string();
let addr = Task::blocking(async move {
addr.to_socket_addrs()?.next().ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "could not resolve the address")
})
})
.await?;
// Create a socket.
let domain = if addr.is_ipv6() {
Domain::ipv6()
} else {
Domain::ipv4()
};
let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))?;
// Begin async connect and ignore the inevitable "in progress" error.
socket.set_nonblocking(true)?;
socket.connect_timeout(&addr.into(), timeout.into()).or_else(|err| {
// Check for EINPROGRESS on Unix and WSAEWOULDBLOCK on Windows.
#[cfg(unix)]
let in_progress = err.raw_os_error() == Some(nix::libc::EINPROGRESS);
#[cfg(windows)]
let in_progress = err.kind() == io::ErrorKind::WouldBlock;
// If connect results with an "in progress" error, that's not an error.
if in_progress {
Ok(())
} else {
Err(err)
}
})?;
let stream = Async::new(socket.into_tcp_stream())?;
// Waits for connect to complete.
let wait_connect = |mut stream: &TcpStream| match stream.write(&[]) {
Err(err) if err.kind() == io::ErrorKind::NotConnected => match stream.take_error()? {
Some(err) => Err(err),
None => Err(io::ErrorKind::WouldBlock.into()),
},
res => res.map(|_| ()),
};
// The stream becomes writable when connected.
match stream.write_with(|io| wait_connect(io)).await {
Ok(()) => Ok(stream),
Err(err) => match stream.get_ref().take_error()? {
Some(err) => Err(err),
None => Err(err),
},
}
}
/// Reads data from the stream without removing it from the buffer.
///
/// Returns the number of bytes read. Successive calls of this method read the same data.