h1: Fix connection with multiple IPs for a hostname

When trying to connect to multiple IPs for a hostname (e.g. IPv4 and
IPv6) we ought to try all prior returning error.

Running a wget to the running mockito server has this output:

,----
| $ wget -O- http://localhost:1234/report
| --2021-03-08 16:13:12--  http://localhost:1234/report
| Resolving localhost (localhost)... ::1, 127.0.0.1
| Connecting to localhost (localhost)|::1|:1234... failed: Connection refused.
| Connecting to localhost (localhost)|127.0.0.1|:1234... connected.
| HTTP request sent, awaiting response... 200 OK
`----

Fixes: #79.
Signed-off-by: Otavio Salvador <otavio@ossystems.com.br>
This commit is contained in:
Otavio Salvador 2021-03-09 22:38:50 -03:00
parent db0025ddff
commit e5bbc272d1
2 changed files with 87 additions and 59 deletions

View File

@ -134,72 +134,84 @@ impl HttpClient for H1Client {
));
}
let addr = req
.url()
.socket_addrs(|| match req.url().scheme() {
"http" => Some(80),
#[cfg(any(feature = "native-tls", feature = "rustls"))]
"https" => Some(443),
_ => None,
})?
.into_iter()
.next()
.ok_or_else(|| Error::from_str(StatusCode::BadRequest, "missing valid address"))?;
let addrs = req.url().socket_addrs(|| match req.url().scheme() {
"http" => Some(80),
#[cfg(any(feature = "native-tls", feature = "rustls"))]
"https" => Some(443),
_ => None,
})?;
log::trace!("> Scheme: {}", scheme);
match scheme {
"http" => {
let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) {
pool_ref
} else {
let manager = TcpConnection::new(addr);
let pool = Pool::<TcpStream, std::io::Error>::new(
manager,
self.max_concurrent_connections,
);
self.http_pools.insert(addr, pool);
self.http_pools.get(&addr).unwrap()
};
let max_addrs_idx = addrs.len() - 1;
for (idx, addr) in addrs.into_iter().enumerate() {
let has_another_addr = idx != max_addrs_idx;
// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
let pool = pool_ref.clone();
std::mem::drop(pool_ref);
match scheme {
"http" => {
let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) {
pool_ref
} else {
let manager = TcpConnection::new(addr);
let pool = Pool::<TcpStream, std::io::Error>::new(
manager,
self.max_concurrent_connections,
);
self.http_pools.insert(addr, pool);
self.http_pools.get(&addr).unwrap()
};
let stream = pool.get().await?;
req.set_peer_addr(stream.peer_addr().ok());
req.set_local_addr(stream.local_addr().ok());
client::connect(TcpConnWrapper::new(stream), req).await
// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
let pool = pool_ref.clone();
std::mem::drop(pool_ref);
let stream = match pool.get().await {
Ok(s) => s,
Err(_) if has_another_addr => continue,
Err(e) => return Err(Error::from_str(400, e.to_string()))?,
};
req.set_peer_addr(stream.peer_addr().ok());
req.set_local_addr(stream.local_addr().ok());
return client::connect(TcpConnWrapper::new(stream), req).await;
}
#[cfg(any(feature = "native-tls", feature = "rustls"))]
"https" => {
let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) {
pool_ref
} else {
let manager = TlsConnection::new(host.clone(), addr);
let pool = Pool::<TlsStream<TcpStream>, Error>::new(
manager,
self.max_concurrent_connections,
);
self.https_pools.insert(addr, pool);
self.https_pools.get(&addr).unwrap()
};
// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
let pool = pool_ref.clone();
std::mem::drop(pool_ref);
let stream = match pool.get().await {
Ok(s) => s,
Err(_) if has_another_addr => continue,
Err(e) => return Err(Error::from_str(400, e.to_string()))?,
};
req.set_peer_addr(stream.get_ref().peer_addr().ok());
req.set_local_addr(stream.get_ref().local_addr().ok());
return client::connect(TlsConnWrapper::new(stream), req).await;
}
_ => unreachable!(),
}
#[cfg(any(feature = "native-tls", feature = "rustls"))]
"https" => {
let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) {
pool_ref
} else {
let manager = TlsConnection::new(host.clone(), addr);
let pool = Pool::<TlsStream<TcpStream>, Error>::new(
manager,
self.max_concurrent_connections,
);
self.https_pools.insert(addr, pool);
self.https_pools.get(&addr).unwrap()
};
// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
let pool = pool_ref.clone();
std::mem::drop(pool_ref);
let stream = pool
.get()
.await
.map_err(|e| Error::from_str(400, e.to_string()))?;
req.set_peer_addr(stream.get_ref().peer_addr().ok());
req.set_local_addr(stream.get_ref().local_addr().ok());
client::connect(TlsConnWrapper::new(stream), req).await
}
_ => unreachable!(),
}
Err(Error::from_str(
StatusCode::BadRequest,
"missing valid address",
))
}
}

View File

@ -149,3 +149,19 @@ async fn keep_alive() {
client.send(req.clone()).await.unwrap();
client.send(req.clone()).await.unwrap();
}
#[atest]
async fn fallback_to_ipv4() {
let client = DefaultClient::new();
let _mock_guard = mock("GET", "/")
.with_status(200)
.expect_at_least(2)
.create();
// Kips the initial "http://127.0.0.1:" to get only the port number
let mock_port = &mockito::server_url()[17..];
let url = &format!("http://localhost:{}", mock_port);
let req = Request::new(http_types::Method::Get, Url::parse(url).unwrap());
client.send(req.clone()).await.unwrap();
}