mirror of https://github.com/http-rs/http-client
h1: Fix connection with multiple IPs for a hostname
When trying to connect to multiple IPs for a hostname (e.g. IPv4 and IPv6) we ought to try all prior returning error. Running a wget to the running mockito server has this output: ,---- | $ wget -O- http://localhost:1234/report | --2021-03-08 16:13:12-- http://localhost:1234/report | Resolving localhost (localhost)... ::1, 127.0.0.1 | Connecting to localhost (localhost)|::1|:1234... failed: Connection refused. | Connecting to localhost (localhost)|127.0.0.1|:1234... connected. | HTTP request sent, awaiting response... 200 OK `---- Fixes: #79. Signed-off-by: Otavio Salvador <otavio@ossystems.com.br>
This commit is contained in:
parent
db0025ddff
commit
e5bbc272d1
130
src/h1/mod.rs
130
src/h1/mod.rs
|
@ -134,72 +134,84 @@ impl HttpClient for H1Client {
|
|||
));
|
||||
}
|
||||
|
||||
let addr = req
|
||||
.url()
|
||||
.socket_addrs(|| match req.url().scheme() {
|
||||
"http" => Some(80),
|
||||
#[cfg(any(feature = "native-tls", feature = "rustls"))]
|
||||
"https" => Some(443),
|
||||
_ => None,
|
||||
})?
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| Error::from_str(StatusCode::BadRequest, "missing valid address"))?;
|
||||
let addrs = req.url().socket_addrs(|| match req.url().scheme() {
|
||||
"http" => Some(80),
|
||||
#[cfg(any(feature = "native-tls", feature = "rustls"))]
|
||||
"https" => Some(443),
|
||||
_ => None,
|
||||
})?;
|
||||
|
||||
log::trace!("> Scheme: {}", scheme);
|
||||
|
||||
match scheme {
|
||||
"http" => {
|
||||
let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) {
|
||||
pool_ref
|
||||
} else {
|
||||
let manager = TcpConnection::new(addr);
|
||||
let pool = Pool::<TcpStream, std::io::Error>::new(
|
||||
manager,
|
||||
self.max_concurrent_connections,
|
||||
);
|
||||
self.http_pools.insert(addr, pool);
|
||||
self.http_pools.get(&addr).unwrap()
|
||||
};
|
||||
let max_addrs_idx = addrs.len() - 1;
|
||||
for (idx, addr) in addrs.into_iter().enumerate() {
|
||||
let has_another_addr = idx != max_addrs_idx;
|
||||
|
||||
// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
|
||||
let pool = pool_ref.clone();
|
||||
std::mem::drop(pool_ref);
|
||||
match scheme {
|
||||
"http" => {
|
||||
let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) {
|
||||
pool_ref
|
||||
} else {
|
||||
let manager = TcpConnection::new(addr);
|
||||
let pool = Pool::<TcpStream, std::io::Error>::new(
|
||||
manager,
|
||||
self.max_concurrent_connections,
|
||||
);
|
||||
self.http_pools.insert(addr, pool);
|
||||
self.http_pools.get(&addr).unwrap()
|
||||
};
|
||||
|
||||
let stream = pool.get().await?;
|
||||
req.set_peer_addr(stream.peer_addr().ok());
|
||||
req.set_local_addr(stream.local_addr().ok());
|
||||
client::connect(TcpConnWrapper::new(stream), req).await
|
||||
// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
|
||||
let pool = pool_ref.clone();
|
||||
std::mem::drop(pool_ref);
|
||||
|
||||
let stream = match pool.get().await {
|
||||
Ok(s) => s,
|
||||
Err(_) if has_another_addr => continue,
|
||||
Err(e) => return Err(Error::from_str(400, e.to_string()))?,
|
||||
};
|
||||
|
||||
req.set_peer_addr(stream.peer_addr().ok());
|
||||
req.set_local_addr(stream.local_addr().ok());
|
||||
return client::connect(TcpConnWrapper::new(stream), req).await;
|
||||
}
|
||||
#[cfg(any(feature = "native-tls", feature = "rustls"))]
|
||||
"https" => {
|
||||
let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) {
|
||||
pool_ref
|
||||
} else {
|
||||
let manager = TlsConnection::new(host.clone(), addr);
|
||||
let pool = Pool::<TlsStream<TcpStream>, Error>::new(
|
||||
manager,
|
||||
self.max_concurrent_connections,
|
||||
);
|
||||
self.https_pools.insert(addr, pool);
|
||||
self.https_pools.get(&addr).unwrap()
|
||||
};
|
||||
|
||||
// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
|
||||
let pool = pool_ref.clone();
|
||||
std::mem::drop(pool_ref);
|
||||
|
||||
let stream = match pool.get().await {
|
||||
Ok(s) => s,
|
||||
Err(_) if has_another_addr => continue,
|
||||
Err(e) => return Err(Error::from_str(400, e.to_string()))?,
|
||||
};
|
||||
|
||||
req.set_peer_addr(stream.get_ref().peer_addr().ok());
|
||||
req.set_local_addr(stream.get_ref().local_addr().ok());
|
||||
|
||||
return client::connect(TlsConnWrapper::new(stream), req).await;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
#[cfg(any(feature = "native-tls", feature = "rustls"))]
|
||||
"https" => {
|
||||
let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) {
|
||||
pool_ref
|
||||
} else {
|
||||
let manager = TlsConnection::new(host.clone(), addr);
|
||||
let pool = Pool::<TlsStream<TcpStream>, Error>::new(
|
||||
manager,
|
||||
self.max_concurrent_connections,
|
||||
);
|
||||
self.https_pools.insert(addr, pool);
|
||||
self.https_pools.get(&addr).unwrap()
|
||||
};
|
||||
|
||||
// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
|
||||
let pool = pool_ref.clone();
|
||||
std::mem::drop(pool_ref);
|
||||
|
||||
let stream = pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| Error::from_str(400, e.to_string()))?;
|
||||
req.set_peer_addr(stream.get_ref().peer_addr().ok());
|
||||
req.set_local_addr(stream.get_ref().local_addr().ok());
|
||||
|
||||
client::connect(TlsConnWrapper::new(stream), req).await
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
Err(Error::from_str(
|
||||
StatusCode::BadRequest,
|
||||
"missing valid address",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -149,3 +149,19 @@ async fn keep_alive() {
|
|||
client.send(req.clone()).await.unwrap();
|
||||
client.send(req.clone()).await.unwrap();
|
||||
}
|
||||
|
||||
#[atest]
|
||||
async fn fallback_to_ipv4() {
|
||||
let client = DefaultClient::new();
|
||||
let _mock_guard = mock("GET", "/")
|
||||
.with_status(200)
|
||||
.expect_at_least(2)
|
||||
.create();
|
||||
|
||||
// Kips the initial "http://127.0.0.1:" to get only the port number
|
||||
let mock_port = &mockito::server_url()[17..];
|
||||
|
||||
let url = &format!("http://localhost:{}", mock_port);
|
||||
let req = Request::new(http_types::Method::Get, Url::parse(url).unwrap());
|
||||
client.send(req.clone()).await.unwrap();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue