diff --git a/src/lib.rs b/src/lib.rs index 0101d03..d371b4c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,19 +98,28 @@ fn connect(addr: Addr, family: libc::c_int, protocol: libc::c_int) -> io::Result }}; } + // A guard that closes the file descriptor if an error occurs before the end. + let mut guard; + // On linux, we pass the `SOCK_CLOEXEC` flag to atomically create the socket and set it as // CLOEXEC. #[cfg(target_os = "linux")] - let fd = syscall!(socket( - family, - libc::SOCK_STREAM | libc::SOCK_CLOEXEC, - protocol - ))?; + let fd = { + let fd = syscall!(socket( + family, + libc::SOCK_STREAM | libc::SOCK_CLOEXEC, + protocol, + ))?; + guard = CallOnDrop(Some(move || drop(syscall!(close(fd))))); + fd + }; // On other systems, we first create the socket and then set it as CLOEXEC. #[cfg(not(target_os = "linux"))] let fd = { let fd = syscall!(socket(family, libc::SOCK_STREAM, protocol))?; + guard = CallOnDrop(Some(move || drop(syscall!(close(fd))))); + let flags = syscall!(fcntl(fd, libc::F_GETFD))? | libc::FD_CLOEXEC; syscall!(fcntl(fd, libc::F_SETFD, flags))?; @@ -140,6 +149,9 @@ fn connect(addr: Addr, family: libc::c_int, protocol: libc::c_int) -> io::Result Err(err) => return Err(err), } + // Disarm the guard so that it doesn't close the file descriptor. + guard.0.take(); + Ok(fd) } @@ -313,6 +325,9 @@ fn tcp_connect(addr: SocketAddr) -> io::Result { socket => socket, }; + // Create a TCP stream now so that it closes the socket if an error occurs before the end. + let stream = TcpStream::from_raw_socket(socket as _); + // Set no inherit. if SetHandleInformation(socket as HANDLE, HANDLE_FLAG_INHERIT, 0) == 0 { return Err(io::Error::last_os_error()); @@ -333,6 +348,17 @@ fn tcp_connect(addr: SocketAddr) -> io::Result { }, } - Ok(TcpStream::from_raw_socket(socket as _)) + Ok(stream) + } +} + +/// Runs a closure when dropped. +struct CallOnDrop(Option); + +impl Drop for CallOnDrop { + fn drop(&mut self) { + if let Some(f) = self.0.take() { + f(); + } } }