Close the socket on error

This commit is contained in:
Stjepan Glavina 2020-09-29 19:24:55 +02:00
parent f1f84f4ad4
commit 5db2513e08
1 changed files with 32 additions and 6 deletions

View File

@ -98,19 +98,28 @@ fn connect(addr: Addr, family: libc::c_int, protocol: libc::c_int) -> io::Result
}}; }};
} }
// A guard that closes the file descriptor if an error occurs before the end.
let mut guard;
// On linux, we pass the `SOCK_CLOEXEC` flag to atomically create the socket and set it as // On linux, we pass the `SOCK_CLOEXEC` flag to atomically create the socket and set it as
// CLOEXEC. // CLOEXEC.
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
let fd = syscall!(socket( let fd = {
family, let fd = syscall!(socket(
libc::SOCK_STREAM | libc::SOCK_CLOEXEC, family,
protocol libc::SOCK_STREAM | libc::SOCK_CLOEXEC,
))?; protocol,
))?;
guard = CallOnDrop(Some(move || drop(syscall!(close(fd)))));
fd
};
// On other systems, we first create the socket and then set it as CLOEXEC. // On other systems, we first create the socket and then set it as CLOEXEC.
#[cfg(not(target_os = "linux"))] #[cfg(not(target_os = "linux"))]
let fd = { let fd = {
let fd = syscall!(socket(family, libc::SOCK_STREAM, protocol))?; let fd = syscall!(socket(family, libc::SOCK_STREAM, protocol))?;
guard = CallOnDrop(Some(move || drop(syscall!(close(fd)))));
let flags = syscall!(fcntl(fd, libc::F_GETFD))? | libc::FD_CLOEXEC; let flags = syscall!(fcntl(fd, libc::F_GETFD))? | libc::FD_CLOEXEC;
syscall!(fcntl(fd, libc::F_SETFD, flags))?; syscall!(fcntl(fd, libc::F_SETFD, flags))?;
@ -140,6 +149,9 @@ fn connect(addr: Addr, family: libc::c_int, protocol: libc::c_int) -> io::Result
Err(err) => return Err(err), Err(err) => return Err(err),
} }
// Disarm the guard so that it doesn't close the file descriptor.
guard.0.take();
Ok(fd) Ok(fd)
} }
@ -313,6 +325,9 @@ fn tcp_connect(addr: SocketAddr) -> io::Result<TcpStream> {
socket => socket, socket => socket,
}; };
// Create a TCP stream now so that it closes the socket if an error occurs before the end.
let stream = TcpStream::from_raw_socket(socket as _);
// Set no inherit. // Set no inherit.
if SetHandleInformation(socket as HANDLE, HANDLE_FLAG_INHERIT, 0) == 0 { if SetHandleInformation(socket as HANDLE, HANDLE_FLAG_INHERIT, 0) == 0 {
return Err(io::Error::last_os_error()); return Err(io::Error::last_os_error());
@ -333,6 +348,17 @@ fn tcp_connect(addr: SocketAddr) -> io::Result<TcpStream> {
}, },
} }
Ok(TcpStream::from_raw_socket(socket as _)) Ok(stream)
}
}
/// Runs a closure when dropped.
struct CallOnDrop<F: FnOnce()>(Option<F>);
impl<F: FnOnce()> Drop for CallOnDrop<F> {
fn drop(&mut self) {
if let Some(f) = self.0.take() {
f();
}
} }
} }