From e5bbc272d1bf107f2ee9814c772e7a0d30710193 Mon Sep 17 00:00:00 2001 From: Otavio Salvador Date: Tue, 9 Mar 2021 22:38:50 -0300 Subject: [PATCH] h1: Fix connection with multiple IPs for a hostname When trying to connect to multiple IPs for a hostname (e.g. IPv4 and IPv6) we ought to try all prior returning error. Running a wget to the running mockito server has this output: ,---- | $ wget -O- http://localhost:1234/report | --2021-03-08 16:13:12-- http://localhost:1234/report | Resolving localhost (localhost)... ::1, 127.0.0.1 | Connecting to localhost (localhost)|::1|:1234... failed: Connection refused. | Connecting to localhost (localhost)|127.0.0.1|:1234... connected. | HTTP request sent, awaiting response... 200 OK `---- Fixes: #79. Signed-off-by: Otavio Salvador --- src/h1/mod.rs | 130 +++++++++++++++++++++++++++----------------------- tests/test.rs | 16 +++++++ 2 files changed, 87 insertions(+), 59 deletions(-) diff --git a/src/h1/mod.rs b/src/h1/mod.rs index f0a4df9..a94a204 100644 --- a/src/h1/mod.rs +++ b/src/h1/mod.rs @@ -134,72 +134,84 @@ impl HttpClient for H1Client { )); } - let addr = req - .url() - .socket_addrs(|| match req.url().scheme() { - "http" => Some(80), - #[cfg(any(feature = "native-tls", feature = "rustls"))] - "https" => Some(443), - _ => None, - })? - .into_iter() - .next() - .ok_or_else(|| Error::from_str(StatusCode::BadRequest, "missing valid address"))?; + let addrs = req.url().socket_addrs(|| match req.url().scheme() { + "http" => Some(80), + #[cfg(any(feature = "native-tls", feature = "rustls"))] + "https" => Some(443), + _ => None, + })?; log::trace!("> Scheme: {}", scheme); - match scheme { - "http" => { - let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) { - pool_ref - } else { - let manager = TcpConnection::new(addr); - let pool = Pool::::new( - manager, - self.max_concurrent_connections, - ); - self.http_pools.insert(addr, pool); - self.http_pools.get(&addr).unwrap() - }; + let max_addrs_idx = addrs.len() - 1; + for (idx, addr) in addrs.into_iter().enumerate() { + let has_another_addr = idx != max_addrs_idx; - // Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await. - let pool = pool_ref.clone(); - std::mem::drop(pool_ref); + match scheme { + "http" => { + let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) { + pool_ref + } else { + let manager = TcpConnection::new(addr); + let pool = Pool::::new( + manager, + self.max_concurrent_connections, + ); + self.http_pools.insert(addr, pool); + self.http_pools.get(&addr).unwrap() + }; - let stream = pool.get().await?; - req.set_peer_addr(stream.peer_addr().ok()); - req.set_local_addr(stream.local_addr().ok()); - client::connect(TcpConnWrapper::new(stream), req).await + // Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await. + let pool = pool_ref.clone(); + std::mem::drop(pool_ref); + + let stream = match pool.get().await { + Ok(s) => s, + Err(_) if has_another_addr => continue, + Err(e) => return Err(Error::from_str(400, e.to_string()))?, + }; + + req.set_peer_addr(stream.peer_addr().ok()); + req.set_local_addr(stream.local_addr().ok()); + return client::connect(TcpConnWrapper::new(stream), req).await; + } + #[cfg(any(feature = "native-tls", feature = "rustls"))] + "https" => { + let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) { + pool_ref + } else { + let manager = TlsConnection::new(host.clone(), addr); + let pool = Pool::, Error>::new( + manager, + self.max_concurrent_connections, + ); + self.https_pools.insert(addr, pool); + self.https_pools.get(&addr).unwrap() + }; + + // Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await. + let pool = pool_ref.clone(); + std::mem::drop(pool_ref); + + let stream = match pool.get().await { + Ok(s) => s, + Err(_) if has_another_addr => continue, + Err(e) => return Err(Error::from_str(400, e.to_string()))?, + }; + + req.set_peer_addr(stream.get_ref().peer_addr().ok()); + req.set_local_addr(stream.get_ref().local_addr().ok()); + + return client::connect(TlsConnWrapper::new(stream), req).await; + } + _ => unreachable!(), } - #[cfg(any(feature = "native-tls", feature = "rustls"))] - "https" => { - let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) { - pool_ref - } else { - let manager = TlsConnection::new(host.clone(), addr); - let pool = Pool::, Error>::new( - manager, - self.max_concurrent_connections, - ); - self.https_pools.insert(addr, pool); - self.https_pools.get(&addr).unwrap() - }; - - // Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await. - let pool = pool_ref.clone(); - std::mem::drop(pool_ref); - - let stream = pool - .get() - .await - .map_err(|e| Error::from_str(400, e.to_string()))?; - req.set_peer_addr(stream.get_ref().peer_addr().ok()); - req.set_local_addr(stream.get_ref().local_addr().ok()); - - client::connect(TlsConnWrapper::new(stream), req).await - } - _ => unreachable!(), } + + Err(Error::from_str( + StatusCode::BadRequest, + "missing valid address", + )) } } diff --git a/tests/test.rs b/tests/test.rs index d54579c..c813d0d 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -149,3 +149,19 @@ async fn keep_alive() { client.send(req.clone()).await.unwrap(); client.send(req.clone()).await.unwrap(); } + +#[atest] +async fn fallback_to_ipv4() { + let client = DefaultClient::new(); + let _mock_guard = mock("GET", "/") + .with_status(200) + .expect_at_least(2) + .create(); + + // Kips the initial "http://127.0.0.1:" to get only the port number + let mock_port = &mockito::server_url()[17..]; + + let url = &format!("http://localhost:{}", mock_port); + let req = Request::new(http_types::Method::Get, Url::parse(url).unwrap()); + client.send(req.clone()).await.unwrap(); +}