mirror of https://github.com/smol-rs/nb-connect
commit
65f51426d1
|
@ -13,7 +13,7 @@ jobs:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||||
rust: [nightly, beta, stable, 1.39.0]
|
rust: [nightly, beta, stable, 1.46.0]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
|
|
|
@ -15,11 +15,11 @@ keywords = ["TcpStream", "UnixStream", "socket2", "polling"]
|
||||||
categories = ["asynchronous", "network-programming", "os"]
|
categories = ["asynchronous", "network-programming", "os"]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
socket2 = { version = "0.3.19", features = ["unix"] }
|
||||||
|
|
||||||
[target."cfg(unix)".dependencies]
|
[target."cfg(unix)".dependencies]
|
||||||
libc = "0.2.77"
|
libc = "0.2.77"
|
||||||
|
|
||||||
[target.'cfg(windows)'.dependencies]
|
|
||||||
winapi = { version = "0.3.9", features = ["handleapi", "ws2tcpip"] }
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
polling = "2.0.0"
|
polling = "2.0.0"
|
||||||
|
|
289
src/lib.rs
289
src/lib.rs
|
@ -33,125 +33,49 @@
|
||||||
#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
|
#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
|
||||||
|
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::mem::{self, MaybeUninit};
|
|
||||||
use std::net::{SocketAddr, TcpStream};
|
use std::net::{SocketAddr, TcpStream};
|
||||||
use std::ptr;
|
|
||||||
|
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
use {
|
use std::{os::unix::net::UnixStream, path::Path};
|
||||||
libc::{sockaddr, sockaddr_storage, socklen_t},
|
|
||||||
std::os::unix::net::UnixStream,
|
|
||||||
std::os::unix::prelude::{FromRawFd, RawFd},
|
|
||||||
std::path::Path,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(windows)]
|
fn connect(addr: SockAddr, domain: Domain, protocol: Option<Protocol>) -> io::Result<Socket> {
|
||||||
use {
|
let sock_type = Type::stream();
|
||||||
std::os::windows::io::FromRawSocket,
|
#[cfg(any(
|
||||||
winapi::shared::ws2def::{SOCKADDR as sockaddr, SOCKADDR_STORAGE as sockaddr_storage},
|
target_os = "android",
|
||||||
winapi::um::ws2tcpip::socklen_t,
|
target_os = "dragonfly",
|
||||||
};
|
target_os = "freebsd",
|
||||||
|
target_os = "fuchsia",
|
||||||
/// A raw socket address.
|
target_os = "illumos",
|
||||||
struct Addr {
|
target_os = "linux",
|
||||||
storage: sockaddr_storage,
|
target_os = "netbsd",
|
||||||
len: socklen_t,
|
target_os = "openbsd"
|
||||||
}
|
))]
|
||||||
|
// If we can, set nonblocking at socket creation for unix
|
||||||
impl Addr {
|
let sock_type = sock_type.non_blocking();
|
||||||
/// Creates a raw socket address from `SocketAddr`.
|
// This automatically handles cloexec on unix, no_inherit on windows and nosigpipe on macos
|
||||||
fn new(addr: SocketAddr) -> Self {
|
let socket = Socket::new(domain, sock_type, protocol)?;
|
||||||
let (addr, len): (*const sockaddr, socklen_t) = match &addr {
|
#[cfg(not(any(
|
||||||
SocketAddr::V4(addr) => (addr as *const _ as *const _, mem::size_of_val(addr) as _),
|
target_os = "android",
|
||||||
SocketAddr::V6(addr) => (addr as *const _ as *const _, mem::size_of_val(addr) as _),
|
target_os = "dragonfly",
|
||||||
};
|
target_os = "freebsd",
|
||||||
unsafe { Self::from_raw_parts(addr, len) }
|
target_os = "fuchsia",
|
||||||
}
|
target_os = "illumos",
|
||||||
|
target_os = "linux",
|
||||||
/// Creates an `Addr` from its raw parts.
|
target_os = "netbsd",
|
||||||
unsafe fn from_raw_parts(addr: *const sockaddr, len: socklen_t) -> Self {
|
target_os = "openbsd"
|
||||||
let mut storage = MaybeUninit::<sockaddr_storage>::uninit();
|
)))]
|
||||||
ptr::copy_nonoverlapping(
|
// If the current platform doesn't support nonblocking at creation, enable it after creation
|
||||||
addr as *const _ as *const u8,
|
socket.set_nonblocking(true)?;
|
||||||
&mut storage as *mut _ as *mut u8,
|
match socket.connect(&addr) {
|
||||||
len as usize,
|
|
||||||
);
|
|
||||||
Self {
|
|
||||||
storage: storage.assume_init(),
|
|
||||||
len,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(unix)]
|
|
||||||
fn connect(addr: Addr, family: libc::c_int, protocol: libc::c_int) -> io::Result<RawFd> {
|
|
||||||
/// Calls a libc function and results in `io::Result`.
|
|
||||||
macro_rules! syscall {
|
|
||||||
($fn:ident $args:tt) => {{
|
|
||||||
let res = unsafe { libc::$fn $args };
|
|
||||||
if res == -1 {
|
|
||||||
Err(std::io::Error::last_os_error())
|
|
||||||
} else {
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
}};
|
|
||||||
}
|
|
||||||
|
|
||||||
// A guard that closes the file descriptor if an error occurs before the end.
|
|
||||||
let mut guard;
|
|
||||||
|
|
||||||
// On linux, we pass the `SOCK_CLOEXEC` flag to atomically create the socket and set it as
|
|
||||||
// CLOEXEC.
|
|
||||||
#[cfg(target_os = "linux")]
|
|
||||||
let fd = {
|
|
||||||
let fd = syscall!(socket(
|
|
||||||
family,
|
|
||||||
libc::SOCK_STREAM | libc::SOCK_CLOEXEC,
|
|
||||||
protocol,
|
|
||||||
))?;
|
|
||||||
guard = CallOnDrop(Some(move || drop(syscall!(close(fd)))));
|
|
||||||
fd
|
|
||||||
};
|
|
||||||
|
|
||||||
// On other systems, we first create the socket and then set it as CLOEXEC.
|
|
||||||
#[cfg(not(target_os = "linux"))]
|
|
||||||
let fd = {
|
|
||||||
let fd = syscall!(socket(family, libc::SOCK_STREAM, protocol))?;
|
|
||||||
guard = CallOnDrop(Some(move || drop(syscall!(close(fd)))));
|
|
||||||
|
|
||||||
let flags = syscall!(fcntl(fd, libc::F_GETFD))? | libc::FD_CLOEXEC;
|
|
||||||
syscall!(fcntl(fd, libc::F_SETFD, flags))?;
|
|
||||||
|
|
||||||
#[cfg(any(target_os = "macos", target_os = "ios"))]
|
|
||||||
{
|
|
||||||
let payload = &1i32 as *const i32 as *const libc::c_void;
|
|
||||||
syscall!(setsockopt(
|
|
||||||
fd,
|
|
||||||
libc::SOL_SOCKET,
|
|
||||||
libc::SO_NOSIGPIPE,
|
|
||||||
payload,
|
|
||||||
std::mem::size_of::<i32>() as libc::socklen_t,
|
|
||||||
))?;
|
|
||||||
}
|
|
||||||
fd
|
|
||||||
};
|
|
||||||
|
|
||||||
// Put socket into non-blocking mode.
|
|
||||||
let flags = syscall!(fcntl(fd, libc::F_GETFL))? | libc::O_NONBLOCK;
|
|
||||||
syscall!(fcntl(fd, libc::F_SETFL, flags))?;
|
|
||||||
|
|
||||||
// Start connecting.
|
|
||||||
match syscall!(connect(fd, &addr.storage as *const _ as *const _, addr.len)) {
|
|
||||||
Ok(_) => {}
|
Ok(_) => {}
|
||||||
|
#[cfg(unix)]
|
||||||
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {}
|
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {}
|
||||||
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
|
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
|
||||||
Err(err) => return Err(err),
|
Err(err) => return Err(err),
|
||||||
}
|
}
|
||||||
|
Ok(socket)
|
||||||
// Disarm the guard so that it doesn't close the file descriptor.
|
|
||||||
guard.0.take();
|
|
||||||
|
|
||||||
Ok(fd)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a pending Unix connection to the specified path.
|
/// Creates a pending Unix connection to the specified path.
|
||||||
|
@ -184,52 +108,8 @@ fn connect(addr: Addr, family: libc::c_int, protocol: libc::c_int) -> io::Result
|
||||||
/// ```
|
/// ```
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
pub fn unix<P: AsRef<Path>>(path: P) -> io::Result<UnixStream> {
|
pub fn unix<P: AsRef<Path>>(path: P) -> io::Result<UnixStream> {
|
||||||
use std::cmp::Ordering;
|
let socket = connect(SockAddr::unix(path)?, Domain::unix(), None)?;
|
||||||
use std::os::unix::ffi::OsStrExt;
|
Ok(socket.into())
|
||||||
|
|
||||||
let addr = unsafe {
|
|
||||||
let mut addr = mem::zeroed::<libc::sockaddr_un>();
|
|
||||||
addr.sun_family = libc::AF_UNIX as libc::sa_family_t;
|
|
||||||
|
|
||||||
let bytes = path.as_ref().as_os_str().as_bytes();
|
|
||||||
|
|
||||||
match (bytes.get(0), bytes.len().cmp(&addr.sun_path.len())) {
|
|
||||||
// Abstract paths don't need a null terminator
|
|
||||||
(Some(&0), Ordering::Greater) => {
|
|
||||||
return Err(io::Error::new(
|
|
||||||
io::ErrorKind::InvalidInput,
|
|
||||||
"path must be no longer than SUN_LEN",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
(Some(&0), _) => {}
|
|
||||||
(_, Ordering::Greater) | (_, Ordering::Equal) => {
|
|
||||||
return Err(io::Error::new(
|
|
||||||
io::ErrorKind::InvalidInput,
|
|
||||||
"path must be shorter than SUN_LEN",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (dst, src) in addr.sun_path.iter_mut().zip(bytes) {
|
|
||||||
*dst = *src as libc::c_char;
|
|
||||||
}
|
|
||||||
// null byte for pathname is already there since we zeroed up front
|
|
||||||
|
|
||||||
let base = &addr as *const _ as usize;
|
|
||||||
let path = &addr.sun_path as *const _ as usize;
|
|
||||||
let sun_path_offset = path - base;
|
|
||||||
|
|
||||||
let mut len = sun_path_offset + bytes.len();
|
|
||||||
match bytes.get(0) {
|
|
||||||
Some(&0) | None => {}
|
|
||||||
Some(_) => len += 1,
|
|
||||||
}
|
|
||||||
Addr::from_raw_parts(&addr as *const _ as *const _, len as libc::socklen_t)
|
|
||||||
};
|
|
||||||
|
|
||||||
let fd = connect(addr, libc::AF_UNIX, 0)?;
|
|
||||||
unsafe { Ok(UnixStream::from_raw_fd(fd)) }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a pending TCP connection to the specified address.
|
/// Creates a pending TCP connection to the specified address.
|
||||||
|
@ -263,99 +143,8 @@ pub fn unix<P: AsRef<Path>>(path: P) -> io::Result<UnixStream> {
|
||||||
/// # std::io::Result::Ok(())
|
/// # std::io::Result::Ok(())
|
||||||
/// ```
|
/// ```
|
||||||
pub fn tcp<A: Into<SocketAddr>>(addr: A) -> io::Result<TcpStream> {
|
pub fn tcp<A: Into<SocketAddr>>(addr: A) -> io::Result<TcpStream> {
|
||||||
tcp_connect(addr.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(unix)]
|
|
||||||
fn tcp_connect(addr: SocketAddr) -> io::Result<TcpStream> {
|
|
||||||
let addr = addr.into();
|
let addr = addr.into();
|
||||||
let fd = connect(
|
let domain = if addr.is_ipv6() { Domain::ipv6() } else { Domain::ipv4() };
|
||||||
Addr::new(addr),
|
let socket = connect(addr.into(), domain, Some(Protocol::tcp()))?;
|
||||||
if addr.is_ipv6() {
|
Ok(socket.into())
|
||||||
libc::AF_INET6
|
|
||||||
} else {
|
|
||||||
libc::AF_INET
|
|
||||||
},
|
|
||||||
libc::IPPROTO_TCP,
|
|
||||||
)?;
|
|
||||||
unsafe { Ok(TcpStream::from_raw_fd(fd)) }
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(windows)]
|
|
||||||
fn tcp_connect(addr: SocketAddr) -> io::Result<TcpStream> {
|
|
||||||
use std::net::UdpSocket;
|
|
||||||
use std::sync::Once;
|
|
||||||
|
|
||||||
use winapi::ctypes::{c_int, c_ulong};
|
|
||||||
use winapi::shared::minwindef::DWORD;
|
|
||||||
use winapi::shared::ntdef::HANDLE;
|
|
||||||
use winapi::shared::ws2def::{AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM};
|
|
||||||
use winapi::um::handleapi::SetHandleInformation;
|
|
||||||
use winapi::um::winsock2 as sock;
|
|
||||||
|
|
||||||
static INIT: Once = Once::new();
|
|
||||||
INIT.call_once(|| {
|
|
||||||
// Initialize winsock through the standard library by just creating a dummy socket.
|
|
||||||
// Whether this is successful or not we drop the result as libstd will be sure to have
|
|
||||||
// initialized winsock.
|
|
||||||
let _ = UdpSocket::bind("127.0.0.1:34254");
|
|
||||||
});
|
|
||||||
|
|
||||||
const HANDLE_FLAG_INHERIT: DWORD = 0x00000001;
|
|
||||||
const WSA_FLAG_OVERLAPPED: DWORD = 0x01;
|
|
||||||
|
|
||||||
let family = if addr.is_ipv6() { AF_INET6 } else { AF_INET };
|
|
||||||
let addr = Addr::new(addr);
|
|
||||||
|
|
||||||
unsafe {
|
|
||||||
let socket = match sock::WSASocketW(
|
|
||||||
family,
|
|
||||||
SOCK_STREAM,
|
|
||||||
IPPROTO_TCP as _,
|
|
||||||
ptr::null_mut(),
|
|
||||||
0,
|
|
||||||
WSA_FLAG_OVERLAPPED,
|
|
||||||
) {
|
|
||||||
sock::INVALID_SOCKET => {
|
|
||||||
return Err(io::Error::from_raw_os_error(sock::WSAGetLastError()))
|
|
||||||
}
|
|
||||||
socket => socket,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create a TCP stream now so that it closes the socket if an error occurs before the end.
|
|
||||||
let stream = TcpStream::from_raw_socket(socket as _);
|
|
||||||
|
|
||||||
// Set no inherit.
|
|
||||||
if SetHandleInformation(socket as HANDLE, HANDLE_FLAG_INHERIT, 0) == 0 {
|
|
||||||
return Err(io::Error::last_os_error());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put socket into nonblocking mode.
|
|
||||||
let mut nonblocking = true as c_ulong;
|
|
||||||
if sock::ioctlsocket(socket, sock::FIONBIO as c_int, &mut nonblocking) != 0 {
|
|
||||||
return Err(io::Error::last_os_error());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start connecting.
|
|
||||||
match sock::connect(socket, &addr.storage as *const _ as *const _, addr.len) {
|
|
||||||
0 => {}
|
|
||||||
_ => match io::Error::from_raw_os_error(sock::WSAGetLastError()) {
|
|
||||||
err if err.kind() == io::ErrorKind::WouldBlock => {}
|
|
||||||
err => return Err(err),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(stream)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Runs a closure when dropped.
|
|
||||||
struct CallOnDrop<F: FnOnce()>(Option<F>);
|
|
||||||
|
|
||||||
impl<F: FnOnce()> Drop for CallOnDrop<F> {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
if let Some(f) = self.0.take() {
|
|
||||||
f();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue