mirror of https://github.com/stjepang/smol
add fn connect_timeout
This commit is contained in:
parent
c311b6897f
commit
1fa5cfa898
|
@ -7,6 +7,7 @@
|
|||
use std::future::Future;
|
||||
use std::io::{self, IoSlice, IoSliceMut, Read, Write};
|
||||
use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket};
|
||||
use std::time::Duration;
|
||||
#[cfg(windows)]
|
||||
use std::os::windows::io::{AsRawSocket, IntoRawSocket, RawSocket};
|
||||
use std::pin::Pin;
|
||||
|
@ -710,6 +711,73 @@ impl Async<TcpStream> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Creates a TCP connection to the specified address with a timeout.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```no_run
|
||||
/// use smol::Async;
|
||||
/// use std::net::TcpStream;
|
||||
/// use std::time::Duration;
|
||||
///
|
||||
/// # smol::run(async {
|
||||
/// let stream = Async::<TcpStream>::connect_timeout("example.com:80", Duration::from_secs(5)).await?;
|
||||
/// # std::io::Result::Ok(()) });
|
||||
/// ```
|
||||
pub async fn connect_timeout<A: ToString>(addr: A, timeout: Duration) -> io::Result<Async<TcpStream>> {
|
||||
let addr = addr.to_string();
|
||||
let addr = Task::blocking(async move {
|
||||
addr.to_socket_addrs()?.next().ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::InvalidInput, "could not resolve the address")
|
||||
})
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Create a socket.
|
||||
let domain = if addr.is_ipv6() {
|
||||
Domain::ipv6()
|
||||
} else {
|
||||
Domain::ipv4()
|
||||
};
|
||||
let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))?;
|
||||
|
||||
// Begin async connect and ignore the inevitable "in progress" error.
|
||||
socket.set_nonblocking(true)?;
|
||||
socket.connect_timeout(&addr.into(), timeout.into()).or_else(|err| {
|
||||
// Check for EINPROGRESS on Unix and WSAEWOULDBLOCK on Windows.
|
||||
#[cfg(unix)]
|
||||
let in_progress = err.raw_os_error() == Some(nix::libc::EINPROGRESS);
|
||||
#[cfg(windows)]
|
||||
let in_progress = err.kind() == io::ErrorKind::WouldBlock;
|
||||
|
||||
// If connect results with an "in progress" error, that's not an error.
|
||||
if in_progress {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(err)
|
||||
}
|
||||
})?;
|
||||
let stream = Async::new(socket.into_tcp_stream())?;
|
||||
|
||||
// Waits for connect to complete.
|
||||
let wait_connect = |mut stream: &TcpStream| match stream.write(&[]) {
|
||||
Err(err) if err.kind() == io::ErrorKind::NotConnected => match stream.take_error()? {
|
||||
Some(err) => Err(err),
|
||||
None => Err(io::ErrorKind::WouldBlock.into()),
|
||||
},
|
||||
res => res.map(|_| ()),
|
||||
};
|
||||
|
||||
// The stream becomes writable when connected.
|
||||
match stream.write_with(|io| wait_connect(io)).await {
|
||||
Ok(()) => Ok(stream),
|
||||
Err(err) => match stream.get_ref().take_error()? {
|
||||
Some(err) => Err(err),
|
||||
None => Err(err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads data from the stream without removing it from the buffer.
|
||||
///
|
||||
/// Returns the number of bytes read. Successive calls of this method read the same data.
|
||||
|
|
Loading…
Reference in New Issue