mirror of https://github.com/http-rs/async-h1
keep-alive: always read body to end
This commit is contained in:
parent
69bed38fd7
commit
7df79f1d5d
|
@ -27,12 +27,8 @@ futures-core = "0.3.8"
|
|||
log = "0.4.11"
|
||||
pin-project = "1.0.2"
|
||||
async-channel = "1.5.1"
|
||||
async-dup = "1.2.2"
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = "0.6.1"
|
||||
async-std = { version = "1.7.0", features = ["attributes"] }
|
||||
tempfile = "3.1.0"
|
||||
async-test = "1.0.0"
|
||||
duplexify = "1.2.2"
|
||||
async-dup = "1.2.2"
|
||||
async-channel = "1.5.1"
|
||||
|
|
|
@ -19,7 +19,8 @@ lazy_static::lazy_static! {
|
|||
|
||||
/// Decodes a chunked body according to
|
||||
/// https://tools.ietf.org/html/rfc7230#section-4.1
|
||||
pub(crate) struct ChunkedDecoder<R: Read> {
|
||||
#[derive(Debug)]
|
||||
pub struct ChunkedDecoder<R: Read> {
|
||||
/// The underlying stream
|
||||
inner: R,
|
||||
/// Buffer for the already read, but not yet parsed data.
|
||||
|
|
|
@ -26,7 +26,7 @@ impl<B> fmt::Debug for ReadNotifier<B> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<B: BufRead> ReadNotifier<B> {
|
||||
impl<B: Read> ReadNotifier<B> {
|
||||
pub(crate) fn new(reader: B, sender: Sender<()>) -> Self {
|
||||
Self {
|
||||
reader,
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
use crate::chunked::ChunkedDecoder;
|
||||
use async_dup::{Arc, Mutex};
|
||||
use async_std::io::{BufReader, Read, Take};
|
||||
use async_std::task::{Context, Poll};
|
||||
use std::{fmt::Debug, io, pin::Pin};
|
||||
|
||||
pub enum BodyReader<IO: Read + Unpin> {
|
||||
Chunked(Arc<Mutex<ChunkedDecoder<BufReader<IO>>>>),
|
||||
Fixed(Arc<Mutex<Take<BufReader<IO>>>>),
|
||||
None,
|
||||
}
|
||||
|
||||
impl<IO: Read + Unpin> Debug for BodyReader<IO> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
BodyReader::Chunked(_) => f.write_str("BodyReader::Chunked"),
|
||||
BodyReader::Fixed(_) => f.write_str("BodyReader::Fixed"),
|
||||
BodyReader::None => f.write_str("BodyReader::None"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO: Read + Unpin> Read for BodyReader<IO> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
match &*self {
|
||||
BodyReader::Chunked(r) => Pin::new(&mut *r.lock()).poll_read(cx, buf),
|
||||
BodyReader::Fixed(r) => Pin::new(&mut *r.lock()).poll_read(cx, buf),
|
||||
BodyReader::None => Poll::Ready(Ok(0)),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -2,12 +2,14 @@
|
|||
|
||||
use std::str::FromStr;
|
||||
|
||||
use async_dup::{Arc, Mutex};
|
||||
use async_std::io::{BufReader, Read, Write};
|
||||
use async_std::{prelude::*, task};
|
||||
use http_types::headers::{CONTENT_LENGTH, EXPECT, TRANSFER_ENCODING};
|
||||
use http_types::{ensure, ensure_eq, format_err};
|
||||
use http_types::{Body, Method, Request, Url};
|
||||
|
||||
use super::body_reader::BodyReader;
|
||||
use crate::chunked::ChunkedDecoder;
|
||||
use crate::read_notifier::ReadNotifier;
|
||||
use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};
|
||||
|
@ -21,7 +23,7 @@ const CONTINUE_HEADER_VALUE: &str = "100-continue";
|
|||
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
|
||||
|
||||
/// Decode an HTTP request on the server.
|
||||
pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<Request>>
|
||||
pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<(Request, BodyReader<IO>)>>
|
||||
where
|
||||
IO: Read + Write + Clone + Send + Sync + Unpin + 'static,
|
||||
{
|
||||
|
@ -108,26 +110,29 @@ where
|
|||
}
|
||||
|
||||
// Check for Transfer-Encoding
|
||||
if let Some(encoding) = transfer_encoding {
|
||||
if encoding.last().as_str() == "chunked" {
|
||||
let trailer_sender = req.send_trailers();
|
||||
let reader = ChunkedDecoder::new(reader, trailer_sender);
|
||||
let reader = BufReader::new(reader);
|
||||
let reader = ReadNotifier::new(reader, body_read_sender);
|
||||
req.set_body(Body::from_reader(reader, None));
|
||||
return Ok(Some(req));
|
||||
}
|
||||
// Fall through to Content-Length
|
||||
}
|
||||
|
||||
// Check for Content-Length.
|
||||
if let Some(len) = content_length {
|
||||
if transfer_encoding
|
||||
.map(|te| te.as_str().eq_ignore_ascii_case("chunked"))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let trailer_sender = req.send_trailers();
|
||||
let reader = ChunkedDecoder::new(reader, trailer_sender);
|
||||
let reader = Arc::new(Mutex::new(reader));
|
||||
let reader_clone = reader.clone();
|
||||
let reader = ReadNotifier::new(reader, body_read_sender);
|
||||
let reader = BufReader::new(reader);
|
||||
req.set_body(Body::from_reader(reader, None));
|
||||
return Ok(Some((req, BodyReader::Chunked(reader_clone))));
|
||||
} else if let Some(len) = content_length {
|
||||
let len = len.last().as_str().parse::<usize>()?;
|
||||
let reader = ReadNotifier::new(reader.take(len as u64), body_read_sender);
|
||||
req.set_body(Body::from_reader(reader, Some(len)));
|
||||
let reader = Arc::new(Mutex::new(reader.take(len as u64)));
|
||||
req.set_body(Body::from_reader(
|
||||
BufReader::new(ReadNotifier::new(reader.clone(), body_read_sender)),
|
||||
Some(len),
|
||||
));
|
||||
Ok(Some((req, BodyReader::Fixed(reader))))
|
||||
} else {
|
||||
Ok(Some((req, BodyReader::None)))
|
||||
}
|
||||
|
||||
Ok(Some(req))
|
||||
}
|
||||
|
||||
fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result<Url> {
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
//! Process HTTP connections on the server.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use async_std::future::{timeout, Future, TimeoutError};
|
||||
use async_std::io::{self, Read, Write};
|
||||
use http_types::headers::{CONNECTION, UPGRADE};
|
||||
use http_types::upgrade::Connection;
|
||||
use http_types::{Request, Response, StatusCode};
|
||||
|
||||
use std::{marker::PhantomData, time::Duration};
|
||||
mod body_reader;
|
||||
mod decode;
|
||||
mod encode;
|
||||
|
||||
|
@ -38,14 +37,14 @@ where
|
|||
F: Fn(Request) -> Fut,
|
||||
Fut: Future<Output = http_types::Result<Response>>,
|
||||
{
|
||||
accept_with_opts(io, endpoint, Default::default()).await
|
||||
Server::new(io, endpoint).accept().await
|
||||
}
|
||||
|
||||
/// Accept a new incoming HTTP/1.1 connection.
|
||||
///
|
||||
/// Supports `KeepAlive` requests by default.
|
||||
pub async fn accept_with_opts<RW, F, Fut>(
|
||||
mut io: RW,
|
||||
io: RW,
|
||||
endpoint: F,
|
||||
opts: ServerOptions,
|
||||
) -> http_types::Result<()>
|
||||
|
@ -54,35 +53,99 @@ where
|
|||
F: Fn(Request) -> Fut,
|
||||
Fut: Future<Output = http_types::Result<Response>>,
|
||||
{
|
||||
loop {
|
||||
// Decode a new request, timing out if this takes longer than the timeout duration.
|
||||
let fut = decode(io.clone());
|
||||
Server::new(io, endpoint).with_opts(opts).accept().await
|
||||
}
|
||||
|
||||
let req = if let Some(timeout_duration) = opts.headers_timeout {
|
||||
/// struct for server
|
||||
#[derive(Debug)]
|
||||
pub struct Server<RW, F, Fut> {
|
||||
io: RW,
|
||||
endpoint: F,
|
||||
opts: ServerOptions,
|
||||
_phantom: PhantomData<Fut>,
|
||||
}
|
||||
|
||||
/// An enum that represents whether the server should accept a subsequent request
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
||||
pub enum ConnectionStatus {
|
||||
/// The server should not accept another request
|
||||
Close,
|
||||
|
||||
/// The server may accept another request
|
||||
KeepAlive,
|
||||
}
|
||||
|
||||
impl<RW, F, Fut> Server<RW, F, Fut>
|
||||
where
|
||||
RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
|
||||
F: Fn(Request) -> Fut,
|
||||
Fut: Future<Output = http_types::Result<Response>>,
|
||||
{
|
||||
/// builds a new server
|
||||
pub fn new(io: RW, endpoint: F) -> Self {
|
||||
Self {
|
||||
io,
|
||||
endpoint,
|
||||
opts: Default::default(),
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// with opts
|
||||
pub fn with_opts(mut self, opts: ServerOptions) -> Self {
|
||||
self.opts = opts;
|
||||
self
|
||||
}
|
||||
|
||||
/// accept in a loop
|
||||
pub async fn accept(&mut self) -> http_types::Result<()> {
|
||||
while ConnectionStatus::KeepAlive == self.accept_one().await? {}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// accept one request
|
||||
pub async fn accept_one(&mut self) -> http_types::Result<ConnectionStatus>
|
||||
where
|
||||
RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
|
||||
F: Fn(Request) -> Fut,
|
||||
Fut: Future<Output = http_types::Result<Response>>,
|
||||
{
|
||||
// Decode a new request, timing out if this takes longer than the timeout duration.
|
||||
let fut = decode(self.io.clone());
|
||||
|
||||
let (req, mut body) = if let Some(timeout_duration) = self.opts.headers_timeout {
|
||||
match timeout(timeout_duration, fut).await {
|
||||
Ok(Ok(Some(r))) => r,
|
||||
Ok(Ok(None)) | Err(TimeoutError { .. }) => break, /* EOF or timeout */
|
||||
Ok(Ok(None)) | Err(TimeoutError { .. }) => return Ok(ConnectionStatus::Close), /* EOF or timeout */
|
||||
Ok(Err(e)) => return Err(e),
|
||||
}
|
||||
} else {
|
||||
match fut.await? {
|
||||
Some(r) => r,
|
||||
None => break, /* EOF */
|
||||
None => return Ok(ConnectionStatus::Close), /* EOF */
|
||||
}
|
||||
};
|
||||
|
||||
let has_upgrade_header = req.header(UPGRADE).is_some();
|
||||
let connection_header_is_upgrade = req
|
||||
let connection_header_as_str = req
|
||||
.header(CONNECTION)
|
||||
.map(|connection| connection.as_str().eq_ignore_ascii_case("upgrade"))
|
||||
.unwrap_or(false);
|
||||
.map(|connection| connection.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let connection_header_is_upgrade = connection_header_as_str.eq_ignore_ascii_case("upgrade");
|
||||
let mut close_connection = connection_header_as_str.eq_ignore_ascii_case("close");
|
||||
|
||||
let upgrade_requested = has_upgrade_header && connection_header_is_upgrade;
|
||||
|
||||
let method = req.method();
|
||||
|
||||
// Pass the request to the endpoint and encode the response.
|
||||
let mut res = endpoint(req).await?;
|
||||
let mut res = (self.endpoint)(req).await?;
|
||||
|
||||
close_connection |= res
|
||||
.header(CONNECTION)
|
||||
.map(|c| c.as_str().eq_ignore_ascii_case("close"))
|
||||
.unwrap_or(false);
|
||||
|
||||
let upgrade_provided = res.status() == StatusCode::SwitchingProtocols && res.has_upgrade();
|
||||
|
||||
|
@ -94,14 +157,22 @@ where
|
|||
|
||||
let mut encoder = Encoder::new(res, method);
|
||||
|
||||
// Stream the response to the writer.
|
||||
io::copy(&mut encoder, &mut io).await?;
|
||||
let bytes_written = io::copy(&mut encoder, &mut self.io).await?;
|
||||
log::trace!("wrote {} response bytes", bytes_written);
|
||||
|
||||
let body_bytes_discarded = io::copy(&mut body, &mut io::sink()).await?;
|
||||
log::trace!(
|
||||
"discarded {} unread request body bytes",
|
||||
body_bytes_discarded
|
||||
);
|
||||
|
||||
if let Some(upgrade_sender) = upgrade_sender {
|
||||
upgrade_sender.send(Connection::new(io.clone())).await;
|
||||
return Ok(());
|
||||
upgrade_sender.send(Connection::new(self.io.clone())).await;
|
||||
return Ok(ConnectionStatus::Close);
|
||||
} else if close_connection {
|
||||
Ok(ConnectionStatus::Close)
|
||||
} else {
|
||||
Ok(ConnectionStatus::KeepAlive)
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -0,0 +1,181 @@
|
|||
mod test_utils;
|
||||
mod accept {
|
||||
use super::test_utils::TestServer;
|
||||
use async_h1::{client::Encoder, server::ConnectionStatus};
|
||||
use async_std::io::{self, prelude::WriteExt, Cursor};
|
||||
use http_types::{headers::CONNECTION, Body, Request, Response, Result};
|
||||
|
||||
#[async_std::test]
|
||||
async fn basic() -> Result<()> {
|
||||
let mut server = TestServer::new(|req| async {
|
||||
let mut response = Response::new(200);
|
||||
let len = req.len();
|
||||
response.set_body(Body::from_reader(req, len));
|
||||
Ok(response)
|
||||
});
|
||||
|
||||
let content_length = 10;
|
||||
|
||||
let request_str = format!(
|
||||
"POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n",
|
||||
content_length,
|
||||
std::str::from_utf8(&vec![b'|'; content_length]).unwrap()
|
||||
);
|
||||
|
||||
server.write_all(request_str.as_bytes()).await?;
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
|
||||
|
||||
server.close();
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
|
||||
|
||||
assert!(server.all_read());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
async fn request_close() -> Result<()> {
|
||||
let mut server = TestServer::new(|_| async { Ok(Response::new(200)) });
|
||||
|
||||
server
|
||||
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\nConnection: Close\r\n\r\n")
|
||||
.await?;
|
||||
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
|
||||
|
||||
assert!(server.all_read());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
async fn response_close() -> Result<()> {
|
||||
let mut server = TestServer::new(|_| async {
|
||||
let mut response = Response::new(200);
|
||||
response.insert_header(CONNECTION, "close");
|
||||
Ok(response)
|
||||
});
|
||||
|
||||
server
|
||||
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
.await?;
|
||||
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
|
||||
|
||||
assert!(server.all_read());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
async fn keep_alive_short_fixed_length_unread_body() -> Result<()> {
|
||||
let mut server = TestServer::new(|_| async { Ok(Response::new(200)) });
|
||||
|
||||
let content_length = 10;
|
||||
|
||||
let request_str = format!(
|
||||
"POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n",
|
||||
content_length,
|
||||
std::str::from_utf8(&vec![b'|'; content_length]).unwrap()
|
||||
);
|
||||
|
||||
server.write_all(request_str.as_bytes()).await?;
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
|
||||
|
||||
server.write_all(request_str.as_bytes()).await?;
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
|
||||
|
||||
server.close();
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
|
||||
|
||||
assert!(server.all_read());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
async fn keep_alive_short_chunked_unread_body() -> Result<()> {
|
||||
let mut server = TestServer::new(|_| async { Ok(Response::new(200)) });
|
||||
|
||||
let content_length = 100;
|
||||
|
||||
let mut request = Request::post("http://example.com/");
|
||||
request.set_body(Body::from_reader(
|
||||
Cursor::new(vec![b'|'; content_length]),
|
||||
None,
|
||||
));
|
||||
|
||||
io::copy(&mut Encoder::new(request), &mut server).await?;
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
|
||||
|
||||
server
|
||||
.write_fmt(format_args!(
|
||||
"GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 0\r\n\r\n"
|
||||
))
|
||||
.await?;
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
|
||||
|
||||
server.close();
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
|
||||
|
||||
assert!(server.all_read());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
async fn keep_alive_long_fixed_length_unread_body() -> Result<()> {
|
||||
let mut server = TestServer::new(|_| async { Ok(Response::new(200)) });
|
||||
|
||||
let content_length = 10000;
|
||||
|
||||
let request_str = format!(
|
||||
"POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n",
|
||||
content_length,
|
||||
std::str::from_utf8(&vec![b'|'; content_length]).unwrap()
|
||||
);
|
||||
|
||||
server.write_all(request_str.as_bytes()).await?;
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
|
||||
|
||||
server.write_all(request_str.as_bytes()).await?;
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
|
||||
|
||||
server.close();
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
|
||||
|
||||
assert!(server.all_read());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
async fn keep_alive_long_chunked_unread_body() -> Result<()> {
|
||||
let mut server = TestServer::new(|_| async { Ok(Response::new(200)) });
|
||||
|
||||
let content_length = 10000;
|
||||
|
||||
let mut request = Request::post("http://example.com/");
|
||||
request.set_body(Body::from_reader(
|
||||
Cursor::new(vec![b'|'; content_length]),
|
||||
None,
|
||||
));
|
||||
|
||||
server.write_request(request).await?;
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
|
||||
|
||||
server
|
||||
.write_fmt(format_args!(
|
||||
"GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 0\r\n\r\n"
|
||||
))
|
||||
.await?;
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
|
||||
|
||||
server.close();
|
||||
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
|
||||
|
||||
assert!(server.all_read());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -94,7 +94,6 @@ mod client_encode {
|
|||
)
|
||||
}
|
||||
|
||||
#[ignore = "this does not work yet"]
|
||||
#[async_std::test]
|
||||
async fn client_encode_chunked_body() -> Result<()> {
|
||||
let url = Url::parse("http://example.com/path?query").unwrap();
|
||||
|
@ -118,6 +117,7 @@ mod client_encode {
|
|||
"d",
|
||||
"0",
|
||||
"",
|
||||
"",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
@ -138,6 +138,7 @@ mod client_encode {
|
|||
"hello world",
|
||||
"0",
|
||||
"",
|
||||
"",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
@ -169,6 +170,7 @@ mod client_encode {
|
|||
"t",
|
||||
"0",
|
||||
"",
|
||||
"",
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use async_dup::{Arc, Mutex};
|
||||
use async_std::io::{Cursor, SeekFrom};
|
||||
use async_std::{prelude::*, task};
|
||||
use duplexify::Duplex;
|
||||
mod test_utils;
|
||||
|
||||
use async_std::{io, prelude::*, task};
|
||||
use http_types::Result;
|
||||
use std::time::Duration;
|
||||
use test_utils::TestIO;
|
||||
|
||||
const REQUEST_WITH_EXPECT: &[u8] = b"POST / HTTP/1.1\r\n\
|
||||
Host: example.com\r\n\
|
||||
|
@ -13,63 +13,42 @@ Expect: 100-continue\r\n\r\n";
|
|||
const SLEEP_DURATION: Duration = std::time::Duration::from_millis(100);
|
||||
#[async_std::test]
|
||||
async fn test_with_expect_when_reading_body() -> Result<()> {
|
||||
let client_str: Vec<u8> = REQUEST_WITH_EXPECT.to_vec();
|
||||
let server_str: Vec<u8> = vec![];
|
||||
let (mut client, server) = TestIO::new();
|
||||
client.write_all(REQUEST_WITH_EXPECT).await?;
|
||||
|
||||
let mut client = Arc::new(Mutex::new(Cursor::new(client_str)));
|
||||
let server = Arc::new(Mutex::new(Cursor::new(server_str)));
|
||||
|
||||
let mut request = async_h1::server::decode(Duplex::new(client.clone(), server.clone()))
|
||||
.await?
|
||||
.unwrap();
|
||||
let (mut request, _) = async_h1::server::decode(server).await?.unwrap();
|
||||
|
||||
task::sleep(SLEEP_DURATION).await; //prove we're not just testing before we've written
|
||||
|
||||
{
|
||||
let lock = server.lock();
|
||||
assert_eq!("", std::str::from_utf8(lock.get_ref())?); //we haven't written yet
|
||||
};
|
||||
assert_eq!("", &client.read.to_string()); // we haven't written yet
|
||||
|
||||
let mut buf = vec![0u8; 1];
|
||||
let bytes = request.read(&mut buf).await?; //this triggers the 100-continue even though there's nothing to read yet
|
||||
assert_eq!(bytes, 0); // normally we'd actually be waiting for the end of the buffer, but this lets us test this sequentially
|
||||
let join_handle = task::spawn(async move {
|
||||
let mut string = String::new();
|
||||
request.read_to_string(&mut string).await?; //this triggers the 100-continue even though there's nothing to read yet
|
||||
io::Result::Ok(string)
|
||||
});
|
||||
|
||||
task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel and io
|
||||
|
||||
{
|
||||
let lock = server.lock();
|
||||
assert_eq!(
|
||||
"HTTP/1.1 100 Continue\r\n\r\n",
|
||||
std::str::from_utf8(lock.get_ref())?
|
||||
);
|
||||
};
|
||||
assert_eq!("HTTP/1.1 100 Continue\r\n\r\n", &client.read.to_string());
|
||||
|
||||
client.write_all(b"0123456789").await?;
|
||||
client
|
||||
.seek(SeekFrom::Start(REQUEST_WITH_EXPECT.len() as u64))
|
||||
.await?;
|
||||
|
||||
assert_eq!("0123456789", request.body_string().await?);
|
||||
assert_eq!("0123456789", &join_handle.await?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
async fn test_without_expect_when_not_reading_body() -> Result<()> {
|
||||
let client_str: Vec<u8> = REQUEST_WITH_EXPECT.to_vec();
|
||||
let server_str: Vec<u8> = vec![];
|
||||
let (mut client, server) = TestIO::new();
|
||||
client.write_all(REQUEST_WITH_EXPECT).await?;
|
||||
|
||||
let client = Arc::new(Mutex::new(Cursor::new(client_str)));
|
||||
let server = Arc::new(Mutex::new(Cursor::new(server_str)));
|
||||
|
||||
async_h1::server::decode(Duplex::new(client.clone(), server.clone()))
|
||||
.await?
|
||||
.unwrap();
|
||||
let (_, _) = async_h1::server::decode(server).await?.unwrap();
|
||||
|
||||
task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel
|
||||
|
||||
let server_lock = server.lock();
|
||||
assert_eq!("", std::str::from_utf8(server_lock.get_ref())?); // we haven't written 100-continue
|
||||
assert_eq!("", &client.read.to_string()); // we haven't written 100-continue
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
use async_dup::{Arc, Mutex};
|
||||
use async_std::io::Cursor;
|
||||
use duplexify::Duplex;
|
||||
use async_h1::{client, server};
|
||||
use http_types::Body;
|
||||
use http_types::Method;
|
||||
use http_types::Request;
|
||||
use http_types::Url;
|
||||
use http_types::{Response, Result};
|
||||
use pretty_assertions::assert_eq;
|
||||
mod test_utils;
|
||||
use test_utils::TestIO;
|
||||
|
||||
const BODY: &str = concat![
|
||||
"Et provident reprehenderit accusamus dolores et voluptates sed quia. Repellendus odit porro ut et hic molestiae. Sit autem reiciendis animi fugiat deleniti vel iste. Laborum id odio ullam ut impedit dolores. Vel aperiam dolorem voluptatibus dignissimos maxime.",
|
||||
|
@ -72,21 +72,18 @@ async fn server_chunked_large() -> Result<()> {
|
|||
let mut request = Request::new(Method::Post, Url::parse("http://domain.com").unwrap());
|
||||
// request.set_body(Body::from_reader(Cursor::new(BODY), None));
|
||||
request.set_body(Body::from_string(String::from(BODY)));
|
||||
let request_encoder = async_h1::client::Encoder::new(request);
|
||||
|
||||
let request = async_h1::server::decode(Duplex::new(
|
||||
Arc::new(Mutex::new(request_encoder)),
|
||||
Arc::new(Mutex::new(Cursor::new(vec![]))),
|
||||
))
|
||||
.await?
|
||||
.unwrap();
|
||||
let (mut client, server) = TestIO::new();
|
||||
async_std::io::copy(&mut client::Encoder::new(request), &mut client).await?;
|
||||
|
||||
let (request, _) = server::decode(server).await?.unwrap();
|
||||
|
||||
let mut response = Response::new(200);
|
||||
response.set_body(Body::from_reader(request, None));
|
||||
|
||||
let response_encoder = async_h1::server::Encoder::new(response, Method::Get);
|
||||
let response_encoder = server::Encoder::new(response, Method::Get);
|
||||
|
||||
let mut response = async_h1::client::decode(response_encoder).await?;
|
||||
let mut response = client::decode(response_encoder).await?;
|
||||
|
||||
assert_eq!(response.body_string().await?, BODY);
|
||||
Ok(())
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
mod test_utils;
|
||||
mod server_decode {
|
||||
use async_dup::{Arc, Mutex};
|
||||
use async_std::io::{Cursor, ReadExt};
|
||||
use duplexify::Duplex;
|
||||
use super::test_utils::TestIO;
|
||||
use async_std::io::prelude::*;
|
||||
use http_types::headers::TRANSFER_ENCODING;
|
||||
use http_types::Request;
|
||||
use http_types::Result;
|
||||
|
@ -10,11 +10,12 @@ mod server_decode {
|
|||
|
||||
async fn decode_lines(lines: Vec<&str>) -> Result<Option<Request>> {
|
||||
let s = lines.join("\r\n");
|
||||
async_h1::server::decode(Duplex::new(
|
||||
Arc::new(Mutex::new(Cursor::new(s))),
|
||||
Arc::new(Mutex::new(Cursor::new(vec![]))),
|
||||
))
|
||||
.await
|
||||
let (mut client, server) = TestIO::new();
|
||||
client.write_all(s.as_bytes()).await?;
|
||||
client.close();
|
||||
async_h1::server::decode(server)
|
||||
.await
|
||||
.map(|r| r.map(|(r, _)| r))
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
|
|
|
@ -0,0 +1,254 @@
|
|||
use async_h1::{
|
||||
client::Encoder,
|
||||
server::{ConnectionStatus, Server},
|
||||
};
|
||||
use async_std::io::{Read, Write};
|
||||
use http_types::{Request, Response, Result};
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
future::Future,
|
||||
io,
|
||||
pin::Pin,
|
||||
sync::RwLock,
|
||||
task::{Context, Poll, Waker},
|
||||
};
|
||||
|
||||
use async_dup::Arc;
|
||||
|
||||
#[pin_project::pin_project]
|
||||
pub struct TestServer<F, Fut> {
|
||||
server: Server<TestIO, F, Fut>,
|
||||
#[pin]
|
||||
client: TestIO,
|
||||
}
|
||||
|
||||
impl<F, Fut> TestServer<F, Fut>
|
||||
where
|
||||
F: Fn(Request) -> Fut,
|
||||
Fut: Future<Output = Result<Response>>,
|
||||
{
|
||||
#[allow(dead_code)]
|
||||
pub fn new(f: F) -> Self {
|
||||
let (client, server) = TestIO::new();
|
||||
Self {
|
||||
server: Server::new(server, f),
|
||||
client,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn accept_one(&mut self) -> http_types::Result<ConnectionStatus> {
|
||||
self.server.accept_one().await
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn close(&mut self) {
|
||||
self.client.close();
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn all_read(&self) -> bool {
|
||||
self.client.all_read()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn write_request(&mut self, request: Request) -> io::Result<()> {
|
||||
async_std::io::copy(&mut Encoder::new(request), self).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, Fut> Read for TestServer<F, Fut>
|
||||
where
|
||||
F: Fn(Request) -> Fut,
|
||||
Fut: Future<Output = Result<Response>>,
|
||||
{
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.project().client.poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, Fut> Write for TestServer<F, Fut>
|
||||
where
|
||||
F: Fn(Request) -> Fut,
|
||||
Fut: Future<Output = Result<Response>>,
|
||||
{
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.project().client.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.project().client.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.project().client.poll_close(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// a Test IO
|
||||
#[derive(Default, Clone, Debug)]
|
||||
pub struct TestIO {
|
||||
pub read: Arc<CloseableCursor>,
|
||||
pub write: Arc<CloseableCursor>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct CloseableCursor {
|
||||
data: RwLock<Vec<u8>>,
|
||||
cursor: RwLock<usize>,
|
||||
waker: RwLock<Option<Waker>>,
|
||||
closed: RwLock<bool>,
|
||||
}
|
||||
|
||||
impl CloseableCursor {
|
||||
fn len(&self) -> usize {
|
||||
self.data.read().unwrap().len()
|
||||
}
|
||||
|
||||
fn cursor(&self) -> usize {
|
||||
*self.cursor.read().unwrap()
|
||||
}
|
||||
|
||||
fn current(&self) -> bool {
|
||||
self.len() == self.cursor()
|
||||
}
|
||||
|
||||
pub fn to_string(&self) -> String {
|
||||
std::str::from_utf8(&*self.data.read().unwrap())
|
||||
.unwrap_or("not utf8")
|
||||
.to_owned()
|
||||
}
|
||||
|
||||
fn close(&self) {
|
||||
*self.closed.write().unwrap() = true;
|
||||
}
|
||||
}
|
||||
|
||||
impl TestIO {
|
||||
pub fn new() -> (TestIO, TestIO) {
|
||||
let client = Arc::new(CloseableCursor::default());
|
||||
let server = Arc::new(CloseableCursor::default());
|
||||
|
||||
(
|
||||
TestIO {
|
||||
read: client.clone(),
|
||||
write: server.clone(),
|
||||
},
|
||||
TestIO {
|
||||
read: server,
|
||||
write: client,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub fn all_read(&self) -> bool {
|
||||
self.write.current()
|
||||
}
|
||||
|
||||
pub fn close(&mut self) {
|
||||
self.write.close();
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for CloseableCursor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("CloseableCursor")
|
||||
.field(
|
||||
"data",
|
||||
&std::str::from_utf8(&self.data.read().unwrap()).unwrap_or("not utf8"),
|
||||
)
|
||||
.field("closed", &*self.closed.read().unwrap())
|
||||
.field("cursor", &*self.cursor.read().unwrap())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for &CloseableCursor {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let len = self.len();
|
||||
let cursor = self.cursor();
|
||||
if cursor < len {
|
||||
let data = &*self.data.read().unwrap();
|
||||
let bytes_to_copy = buf.len().min(len - cursor);
|
||||
buf[..bytes_to_copy].copy_from_slice(&data[cursor..cursor + bytes_to_copy]);
|
||||
*self.cursor.write().unwrap() += bytes_to_copy;
|
||||
Poll::Ready(Ok(bytes_to_copy))
|
||||
} else if *self.closed.read().unwrap() {
|
||||
Poll::Ready(Ok(0))
|
||||
} else {
|
||||
*self.waker.write().unwrap() = Some(cx.waker().clone());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for &CloseableCursor {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
if *self.closed.read().unwrap() {
|
||||
Poll::Ready(Ok(0))
|
||||
} else {
|
||||
self.data.write().unwrap().extend_from_slice(buf);
|
||||
if let Some(waker) = self.waker.write().unwrap().take() {
|
||||
waker.wake();
|
||||
}
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
if let Some(waker) = self.waker.write().unwrap().take() {
|
||||
waker.wake();
|
||||
}
|
||||
*self.closed.write().unwrap() = true;
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for TestIO {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut &*self.read).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for TestIO {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut &*self.write).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut &*self.write).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut &*self.write).poll_close(cx)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue