Close the socket on error

This commit is contained in:
Stjepan Glavina 2020-09-29 19:24:55 +02:00
parent f1f84f4ad4
commit 5db2513e08
1 changed files with 32 additions and 6 deletions

View File

@ -98,19 +98,28 @@ fn connect(addr: Addr, family: libc::c_int, protocol: libc::c_int) -> io::Result
}};
}
// A guard that closes the file descriptor if an error occurs before the end.
let mut guard;
// On linux, we pass the `SOCK_CLOEXEC` flag to atomically create the socket and set it as
// CLOEXEC.
#[cfg(target_os = "linux")]
let fd = syscall!(socket(
family,
libc::SOCK_STREAM | libc::SOCK_CLOEXEC,
protocol
))?;
let fd = {
let fd = syscall!(socket(
family,
libc::SOCK_STREAM | libc::SOCK_CLOEXEC,
protocol,
))?;
guard = CallOnDrop(Some(move || drop(syscall!(close(fd)))));
fd
};
// On other systems, we first create the socket and then set it as CLOEXEC.
#[cfg(not(target_os = "linux"))]
let fd = {
let fd = syscall!(socket(family, libc::SOCK_STREAM, protocol))?;
guard = CallOnDrop(Some(move || drop(syscall!(close(fd)))));
let flags = syscall!(fcntl(fd, libc::F_GETFD))? | libc::FD_CLOEXEC;
syscall!(fcntl(fd, libc::F_SETFD, flags))?;
@ -140,6 +149,9 @@ fn connect(addr: Addr, family: libc::c_int, protocol: libc::c_int) -> io::Result
Err(err) => return Err(err),
}
// Disarm the guard so that it doesn't close the file descriptor.
guard.0.take();
Ok(fd)
}
@ -313,6 +325,9 @@ fn tcp_connect(addr: SocketAddr) -> io::Result<TcpStream> {
socket => socket,
};
// Create a TCP stream now so that it closes the socket if an error occurs before the end.
let stream = TcpStream::from_raw_socket(socket as _);
// Set no inherit.
if SetHandleInformation(socket as HANDLE, HANDLE_FLAG_INHERIT, 0) == 0 {
return Err(io::Error::last_os_error());
@ -333,6 +348,17 @@ fn tcp_connect(addr: SocketAddr) -> io::Result<TcpStream> {
},
}
Ok(TcpStream::from_raw_socket(socket as _))
Ok(stream)
}
}
/// Runs a closure when dropped.
struct CallOnDrop<F: FnOnce()>(Option<F>);
impl<F: FnOnce()> Drop for CallOnDrop<F> {
fn drop(&mut self) {
if let Some(f) = self.0.take() {
f();
}
}
}