mirror of https://github.com/smol-rs/nb-connect
Close the socket on error
This commit is contained in:
parent
f1f84f4ad4
commit
5db2513e08
38
src/lib.rs
38
src/lib.rs
|
@ -98,19 +98,28 @@ fn connect(addr: Addr, family: libc::c_int, protocol: libc::c_int) -> io::Result
|
||||||
}};
|
}};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A guard that closes the file descriptor if an error occurs before the end.
|
||||||
|
let mut guard;
|
||||||
|
|
||||||
// On linux, we pass the `SOCK_CLOEXEC` flag to atomically create the socket and set it as
|
// On linux, we pass the `SOCK_CLOEXEC` flag to atomically create the socket and set it as
|
||||||
// CLOEXEC.
|
// CLOEXEC.
|
||||||
#[cfg(target_os = "linux")]
|
#[cfg(target_os = "linux")]
|
||||||
let fd = syscall!(socket(
|
let fd = {
|
||||||
family,
|
let fd = syscall!(socket(
|
||||||
libc::SOCK_STREAM | libc::SOCK_CLOEXEC,
|
family,
|
||||||
protocol
|
libc::SOCK_STREAM | libc::SOCK_CLOEXEC,
|
||||||
))?;
|
protocol,
|
||||||
|
))?;
|
||||||
|
guard = CallOnDrop(Some(move || drop(syscall!(close(fd)))));
|
||||||
|
fd
|
||||||
|
};
|
||||||
|
|
||||||
// On other systems, we first create the socket and then set it as CLOEXEC.
|
// On other systems, we first create the socket and then set it as CLOEXEC.
|
||||||
#[cfg(not(target_os = "linux"))]
|
#[cfg(not(target_os = "linux"))]
|
||||||
let fd = {
|
let fd = {
|
||||||
let fd = syscall!(socket(family, libc::SOCK_STREAM, protocol))?;
|
let fd = syscall!(socket(family, libc::SOCK_STREAM, protocol))?;
|
||||||
|
guard = CallOnDrop(Some(move || drop(syscall!(close(fd)))));
|
||||||
|
|
||||||
let flags = syscall!(fcntl(fd, libc::F_GETFD))? | libc::FD_CLOEXEC;
|
let flags = syscall!(fcntl(fd, libc::F_GETFD))? | libc::FD_CLOEXEC;
|
||||||
syscall!(fcntl(fd, libc::F_SETFD, flags))?;
|
syscall!(fcntl(fd, libc::F_SETFD, flags))?;
|
||||||
|
|
||||||
|
@ -140,6 +149,9 @@ fn connect(addr: Addr, family: libc::c_int, protocol: libc::c_int) -> io::Result
|
||||||
Err(err) => return Err(err),
|
Err(err) => return Err(err),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Disarm the guard so that it doesn't close the file descriptor.
|
||||||
|
guard.0.take();
|
||||||
|
|
||||||
Ok(fd)
|
Ok(fd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -313,6 +325,9 @@ fn tcp_connect(addr: SocketAddr) -> io::Result<TcpStream> {
|
||||||
socket => socket,
|
socket => socket,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Create a TCP stream now so that it closes the socket if an error occurs before the end.
|
||||||
|
let stream = TcpStream::from_raw_socket(socket as _);
|
||||||
|
|
||||||
// Set no inherit.
|
// Set no inherit.
|
||||||
if SetHandleInformation(socket as HANDLE, HANDLE_FLAG_INHERIT, 0) == 0 {
|
if SetHandleInformation(socket as HANDLE, HANDLE_FLAG_INHERIT, 0) == 0 {
|
||||||
return Err(io::Error::last_os_error());
|
return Err(io::Error::last_os_error());
|
||||||
|
@ -333,6 +348,17 @@ fn tcp_connect(addr: SocketAddr) -> io::Result<TcpStream> {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(TcpStream::from_raw_socket(socket as _))
|
Ok(stream)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runs a closure when dropped.
|
||||||
|
struct CallOnDrop<F: FnOnce()>(Option<F>);
|
||||||
|
|
||||||
|
impl<F: FnOnce()> Drop for CallOnDrop<F> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Some(f) = self.0.take() {
|
||||||
|
f();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue