mirror of https://github.com/smol-rs/nb-connect
Close the socket on error
This commit is contained in:
parent
f1f84f4ad4
commit
5db2513e08
38
src/lib.rs
38
src/lib.rs
|
@ -98,19 +98,28 @@ fn connect(addr: Addr, family: libc::c_int, protocol: libc::c_int) -> io::Result
|
|||
}};
|
||||
}
|
||||
|
||||
// A guard that closes the file descriptor if an error occurs before the end.
|
||||
let mut guard;
|
||||
|
||||
// On linux, we pass the `SOCK_CLOEXEC` flag to atomically create the socket and set it as
|
||||
// CLOEXEC.
|
||||
#[cfg(target_os = "linux")]
|
||||
let fd = syscall!(socket(
|
||||
family,
|
||||
libc::SOCK_STREAM | libc::SOCK_CLOEXEC,
|
||||
protocol
|
||||
))?;
|
||||
let fd = {
|
||||
let fd = syscall!(socket(
|
||||
family,
|
||||
libc::SOCK_STREAM | libc::SOCK_CLOEXEC,
|
||||
protocol,
|
||||
))?;
|
||||
guard = CallOnDrop(Some(move || drop(syscall!(close(fd)))));
|
||||
fd
|
||||
};
|
||||
|
||||
// On other systems, we first create the socket and then set it as CLOEXEC.
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
let fd = {
|
||||
let fd = syscall!(socket(family, libc::SOCK_STREAM, protocol))?;
|
||||
guard = CallOnDrop(Some(move || drop(syscall!(close(fd)))));
|
||||
|
||||
let flags = syscall!(fcntl(fd, libc::F_GETFD))? | libc::FD_CLOEXEC;
|
||||
syscall!(fcntl(fd, libc::F_SETFD, flags))?;
|
||||
|
||||
|
@ -140,6 +149,9 @@ fn connect(addr: Addr, family: libc::c_int, protocol: libc::c_int) -> io::Result
|
|||
Err(err) => return Err(err),
|
||||
}
|
||||
|
||||
// Disarm the guard so that it doesn't close the file descriptor.
|
||||
guard.0.take();
|
||||
|
||||
Ok(fd)
|
||||
}
|
||||
|
||||
|
@ -313,6 +325,9 @@ fn tcp_connect(addr: SocketAddr) -> io::Result<TcpStream> {
|
|||
socket => socket,
|
||||
};
|
||||
|
||||
// Create a TCP stream now so that it closes the socket if an error occurs before the end.
|
||||
let stream = TcpStream::from_raw_socket(socket as _);
|
||||
|
||||
// Set no inherit.
|
||||
if SetHandleInformation(socket as HANDLE, HANDLE_FLAG_INHERIT, 0) == 0 {
|
||||
return Err(io::Error::last_os_error());
|
||||
|
@ -333,6 +348,17 @@ fn tcp_connect(addr: SocketAddr) -> io::Result<TcpStream> {
|
|||
},
|
||||
}
|
||||
|
||||
Ok(TcpStream::from_raw_socket(socket as _))
|
||||
Ok(stream)
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs a closure when dropped.
|
||||
struct CallOnDrop<F: FnOnce()>(Option<F>);
|
||||
|
||||
impl<F: FnOnce()> Drop for CallOnDrop<F> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(f) = self.0.take() {
|
||||
f();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue