keep-alive: always read body to end

This commit is contained in:
Jacob Rothstein 2020-12-13 00:37:30 -08:00
parent 69bed38fd7
commit 7df79f1d5d
No known key found for this signature in database
GPG Key ID: C38BA18C6CFE15A5
12 changed files with 630 additions and 108 deletions

View File

@ -27,12 +27,8 @@ futures-core = "0.3.8"
log = "0.4.11"
pin-project = "1.0.2"
async-channel = "1.5.1"
async-dup = "1.2.2"
[dev-dependencies]
pretty_assertions = "0.6.1"
async-std = { version = "1.7.0", features = ["attributes"] }
tempfile = "3.1.0"
async-test = "1.0.0"
duplexify = "1.2.2"
async-dup = "1.2.2"
async-channel = "1.5.1"

View File

@ -19,7 +19,8 @@ lazy_static::lazy_static! {
/// Decodes a chunked body according to
/// https://tools.ietf.org/html/rfc7230#section-4.1
pub(crate) struct ChunkedDecoder<R: Read> {
#[derive(Debug)]
pub struct ChunkedDecoder<R: Read> {
/// The underlying stream
inner: R,
/// Buffer for the already read, but not yet parsed data.

View File

@ -26,7 +26,7 @@ impl<B> fmt::Debug for ReadNotifier<B> {
}
}
impl<B: BufRead> ReadNotifier<B> {
impl<B: Read> ReadNotifier<B> {
pub(crate) fn new(reader: B, sender: Sender<()>) -> Self {
Self {
reader,

35
src/server/body_reader.rs Normal file
View File

@ -0,0 +1,35 @@
use crate::chunked::ChunkedDecoder;
use async_dup::{Arc, Mutex};
use async_std::io::{BufReader, Read, Take};
use async_std::task::{Context, Poll};
use std::{fmt::Debug, io, pin::Pin};
pub enum BodyReader<IO: Read + Unpin> {
Chunked(Arc<Mutex<ChunkedDecoder<BufReader<IO>>>>),
Fixed(Arc<Mutex<Take<BufReader<IO>>>>),
None,
}
impl<IO: Read + Unpin> Debug for BodyReader<IO> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BodyReader::Chunked(_) => f.write_str("BodyReader::Chunked"),
BodyReader::Fixed(_) => f.write_str("BodyReader::Fixed"),
BodyReader::None => f.write_str("BodyReader::None"),
}
}
}
impl<IO: Read + Unpin> Read for BodyReader<IO> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match &*self {
BodyReader::Chunked(r) => Pin::new(&mut *r.lock()).poll_read(cx, buf),
BodyReader::Fixed(r) => Pin::new(&mut *r.lock()).poll_read(cx, buf),
BodyReader::None => Poll::Ready(Ok(0)),
}
}
}

View File

@ -2,12 +2,14 @@
use std::str::FromStr;
use async_dup::{Arc, Mutex};
use async_std::io::{BufReader, Read, Write};
use async_std::{prelude::*, task};
use http_types::headers::{CONTENT_LENGTH, EXPECT, TRANSFER_ENCODING};
use http_types::{ensure, ensure_eq, format_err};
use http_types::{Body, Method, Request, Url};
use super::body_reader::BodyReader;
use crate::chunked::ChunkedDecoder;
use crate::read_notifier::ReadNotifier;
use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};
@ -21,7 +23,7 @@ const CONTINUE_HEADER_VALUE: &str = "100-continue";
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
/// Decode an HTTP request on the server.
pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<Request>>
pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<(Request, BodyReader<IO>)>>
where
IO: Read + Write + Clone + Send + Sync + Unpin + 'static,
{
@ -108,26 +110,29 @@ where
}
// Check for Transfer-Encoding
if let Some(encoding) = transfer_encoding {
if encoding.last().as_str() == "chunked" {
let trailer_sender = req.send_trailers();
let reader = ChunkedDecoder::new(reader, trailer_sender);
let reader = BufReader::new(reader);
let reader = ReadNotifier::new(reader, body_read_sender);
req.set_body(Body::from_reader(reader, None));
return Ok(Some(req));
}
// Fall through to Content-Length
}
// Check for Content-Length.
if let Some(len) = content_length {
if transfer_encoding
.map(|te| te.as_str().eq_ignore_ascii_case("chunked"))
.unwrap_or(false)
{
let trailer_sender = req.send_trailers();
let reader = ChunkedDecoder::new(reader, trailer_sender);
let reader = Arc::new(Mutex::new(reader));
let reader_clone = reader.clone();
let reader = ReadNotifier::new(reader, body_read_sender);
let reader = BufReader::new(reader);
req.set_body(Body::from_reader(reader, None));
return Ok(Some((req, BodyReader::Chunked(reader_clone))));
} else if let Some(len) = content_length {
let len = len.last().as_str().parse::<usize>()?;
let reader = ReadNotifier::new(reader.take(len as u64), body_read_sender);
req.set_body(Body::from_reader(reader, Some(len)));
let reader = Arc::new(Mutex::new(reader.take(len as u64)));
req.set_body(Body::from_reader(
BufReader::new(ReadNotifier::new(reader.clone(), body_read_sender)),
Some(len),
));
Ok(Some((req, BodyReader::Fixed(reader))))
} else {
Ok(Some((req, BodyReader::None)))
}
Ok(Some(req))
}
fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result<Url> {

View File

@ -1,13 +1,12 @@
//! Process HTTP connections on the server.
use std::time::Duration;
use async_std::future::{timeout, Future, TimeoutError};
use async_std::io::{self, Read, Write};
use http_types::headers::{CONNECTION, UPGRADE};
use http_types::upgrade::Connection;
use http_types::{Request, Response, StatusCode};
use std::{marker::PhantomData, time::Duration};
mod body_reader;
mod decode;
mod encode;
@ -38,14 +37,14 @@ where
F: Fn(Request) -> Fut,
Fut: Future<Output = http_types::Result<Response>>,
{
accept_with_opts(io, endpoint, Default::default()).await
Server::new(io, endpoint).accept().await
}
/// Accept a new incoming HTTP/1.1 connection.
///
/// Supports `KeepAlive` requests by default.
pub async fn accept_with_opts<RW, F, Fut>(
mut io: RW,
io: RW,
endpoint: F,
opts: ServerOptions,
) -> http_types::Result<()>
@ -54,35 +53,99 @@ where
F: Fn(Request) -> Fut,
Fut: Future<Output = http_types::Result<Response>>,
{
loop {
// Decode a new request, timing out if this takes longer than the timeout duration.
let fut = decode(io.clone());
Server::new(io, endpoint).with_opts(opts).accept().await
}
let req = if let Some(timeout_duration) = opts.headers_timeout {
/// struct for server
#[derive(Debug)]
pub struct Server<RW, F, Fut> {
io: RW,
endpoint: F,
opts: ServerOptions,
_phantom: PhantomData<Fut>,
}
/// An enum that represents whether the server should accept a subsequent request
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum ConnectionStatus {
/// The server should not accept another request
Close,
/// The server may accept another request
KeepAlive,
}
impl<RW, F, Fut> Server<RW, F, Fut>
where
RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
F: Fn(Request) -> Fut,
Fut: Future<Output = http_types::Result<Response>>,
{
/// builds a new server
pub fn new(io: RW, endpoint: F) -> Self {
Self {
io,
endpoint,
opts: Default::default(),
_phantom: PhantomData,
}
}
/// with opts
pub fn with_opts(mut self, opts: ServerOptions) -> Self {
self.opts = opts;
self
}
/// accept in a loop
pub async fn accept(&mut self) -> http_types::Result<()> {
while ConnectionStatus::KeepAlive == self.accept_one().await? {}
Ok(())
}
/// accept one request
pub async fn accept_one(&mut self) -> http_types::Result<ConnectionStatus>
where
RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
F: Fn(Request) -> Fut,
Fut: Future<Output = http_types::Result<Response>>,
{
// Decode a new request, timing out if this takes longer than the timeout duration.
let fut = decode(self.io.clone());
let (req, mut body) = if let Some(timeout_duration) = self.opts.headers_timeout {
match timeout(timeout_duration, fut).await {
Ok(Ok(Some(r))) => r,
Ok(Ok(None)) | Err(TimeoutError { .. }) => break, /* EOF or timeout */
Ok(Ok(None)) | Err(TimeoutError { .. }) => return Ok(ConnectionStatus::Close), /* EOF or timeout */
Ok(Err(e)) => return Err(e),
}
} else {
match fut.await? {
Some(r) => r,
None => break, /* EOF */
None => return Ok(ConnectionStatus::Close), /* EOF */
}
};
let has_upgrade_header = req.header(UPGRADE).is_some();
let connection_header_is_upgrade = req
let connection_header_as_str = req
.header(CONNECTION)
.map(|connection| connection.as_str().eq_ignore_ascii_case("upgrade"))
.unwrap_or(false);
.map(|connection| connection.as_str())
.unwrap_or("");
let connection_header_is_upgrade = connection_header_as_str.eq_ignore_ascii_case("upgrade");
let mut close_connection = connection_header_as_str.eq_ignore_ascii_case("close");
let upgrade_requested = has_upgrade_header && connection_header_is_upgrade;
let method = req.method();
// Pass the request to the endpoint and encode the response.
let mut res = endpoint(req).await?;
let mut res = (self.endpoint)(req).await?;
close_connection |= res
.header(CONNECTION)
.map(|c| c.as_str().eq_ignore_ascii_case("close"))
.unwrap_or(false);
let upgrade_provided = res.status() == StatusCode::SwitchingProtocols && res.has_upgrade();
@ -94,14 +157,22 @@ where
let mut encoder = Encoder::new(res, method);
// Stream the response to the writer.
io::copy(&mut encoder, &mut io).await?;
let bytes_written = io::copy(&mut encoder, &mut self.io).await?;
log::trace!("wrote {} response bytes", bytes_written);
let body_bytes_discarded = io::copy(&mut body, &mut io::sink()).await?;
log::trace!(
"discarded {} unread request body bytes",
body_bytes_discarded
);
if let Some(upgrade_sender) = upgrade_sender {
upgrade_sender.send(Connection::new(io.clone())).await;
return Ok(());
upgrade_sender.send(Connection::new(self.io.clone())).await;
return Ok(ConnectionStatus::Close);
} else if close_connection {
Ok(ConnectionStatus::Close)
} else {
Ok(ConnectionStatus::KeepAlive)
}
}
Ok(())
}

181
tests/accept.rs Normal file
View File

@ -0,0 +1,181 @@
mod test_utils;
mod accept {
use super::test_utils::TestServer;
use async_h1::{client::Encoder, server::ConnectionStatus};
use async_std::io::{self, prelude::WriteExt, Cursor};
use http_types::{headers::CONNECTION, Body, Request, Response, Result};
#[async_std::test]
async fn basic() -> Result<()> {
let mut server = TestServer::new(|req| async {
let mut response = Response::new(200);
let len = req.len();
response.set_body(Body::from_reader(req, len));
Ok(response)
});
let content_length = 10;
let request_str = format!(
"POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n",
content_length,
std::str::from_utf8(&vec![b'|'; content_length]).unwrap()
);
server.write_all(request_str.as_bytes()).await?;
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
server.close();
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
assert!(server.all_read());
Ok(())
}
#[async_std::test]
async fn request_close() -> Result<()> {
let mut server = TestServer::new(|_| async { Ok(Response::new(200)) });
server
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\nConnection: Close\r\n\r\n")
.await?;
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
assert!(server.all_read());
Ok(())
}
#[async_std::test]
async fn response_close() -> Result<()> {
let mut server = TestServer::new(|_| async {
let mut response = Response::new(200);
response.insert_header(CONNECTION, "close");
Ok(response)
});
server
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
.await?;
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
assert!(server.all_read());
Ok(())
}
#[async_std::test]
async fn keep_alive_short_fixed_length_unread_body() -> Result<()> {
let mut server = TestServer::new(|_| async { Ok(Response::new(200)) });
let content_length = 10;
let request_str = format!(
"POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n",
content_length,
std::str::from_utf8(&vec![b'|'; content_length]).unwrap()
);
server.write_all(request_str.as_bytes()).await?;
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
server.write_all(request_str.as_bytes()).await?;
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
server.close();
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
assert!(server.all_read());
Ok(())
}
#[async_std::test]
async fn keep_alive_short_chunked_unread_body() -> Result<()> {
let mut server = TestServer::new(|_| async { Ok(Response::new(200)) });
let content_length = 100;
let mut request = Request::post("http://example.com/");
request.set_body(Body::from_reader(
Cursor::new(vec![b'|'; content_length]),
None,
));
io::copy(&mut Encoder::new(request), &mut server).await?;
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
server
.write_fmt(format_args!(
"GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 0\r\n\r\n"
))
.await?;
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
server.close();
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
assert!(server.all_read());
Ok(())
}
#[async_std::test]
async fn keep_alive_long_fixed_length_unread_body() -> Result<()> {
let mut server = TestServer::new(|_| async { Ok(Response::new(200)) });
let content_length = 10000;
let request_str = format!(
"POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n",
content_length,
std::str::from_utf8(&vec![b'|'; content_length]).unwrap()
);
server.write_all(request_str.as_bytes()).await?;
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
server.write_all(request_str.as_bytes()).await?;
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
server.close();
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
assert!(server.all_read());
Ok(())
}
#[async_std::test]
async fn keep_alive_long_chunked_unread_body() -> Result<()> {
let mut server = TestServer::new(|_| async { Ok(Response::new(200)) });
let content_length = 10000;
let mut request = Request::post("http://example.com/");
request.set_body(Body::from_reader(
Cursor::new(vec![b'|'; content_length]),
None,
));
server.write_request(request).await?;
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
server
.write_fmt(format_args!(
"GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 0\r\n\r\n"
))
.await?;
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
server.close();
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
assert!(server.all_read());
Ok(())
}
}

View File

@ -94,7 +94,6 @@ mod client_encode {
)
}
#[ignore = "this does not work yet"]
#[async_std::test]
async fn client_encode_chunked_body() -> Result<()> {
let url = Url::parse("http://example.com/path?query").unwrap();
@ -118,6 +117,7 @@ mod client_encode {
"d",
"0",
"",
"",
],
)
.await;
@ -138,6 +138,7 @@ mod client_encode {
"hello world",
"0",
"",
"",
],
)
.await;
@ -169,6 +170,7 @@ mod client_encode {
"t",
"0",
"",
"",
],
)
.await;

View File

@ -1,9 +1,9 @@
use async_dup::{Arc, Mutex};
use async_std::io::{Cursor, SeekFrom};
use async_std::{prelude::*, task};
use duplexify::Duplex;
mod test_utils;
use async_std::{io, prelude::*, task};
use http_types::Result;
use std::time::Duration;
use test_utils::TestIO;
const REQUEST_WITH_EXPECT: &[u8] = b"POST / HTTP/1.1\r\n\
Host: example.com\r\n\
@ -13,63 +13,42 @@ Expect: 100-continue\r\n\r\n";
const SLEEP_DURATION: Duration = std::time::Duration::from_millis(100);
#[async_std::test]
async fn test_with_expect_when_reading_body() -> Result<()> {
let client_str: Vec<u8> = REQUEST_WITH_EXPECT.to_vec();
let server_str: Vec<u8> = vec![];
let (mut client, server) = TestIO::new();
client.write_all(REQUEST_WITH_EXPECT).await?;
let mut client = Arc::new(Mutex::new(Cursor::new(client_str)));
let server = Arc::new(Mutex::new(Cursor::new(server_str)));
let mut request = async_h1::server::decode(Duplex::new(client.clone(), server.clone()))
.await?
.unwrap();
let (mut request, _) = async_h1::server::decode(server).await?.unwrap();
task::sleep(SLEEP_DURATION).await; //prove we're not just testing before we've written
{
let lock = server.lock();
assert_eq!("", std::str::from_utf8(lock.get_ref())?); //we haven't written yet
};
assert_eq!("", &client.read.to_string()); // we haven't written yet
let mut buf = vec![0u8; 1];
let bytes = request.read(&mut buf).await?; //this triggers the 100-continue even though there's nothing to read yet
assert_eq!(bytes, 0); // normally we'd actually be waiting for the end of the buffer, but this lets us test this sequentially
let join_handle = task::spawn(async move {
let mut string = String::new();
request.read_to_string(&mut string).await?; //this triggers the 100-continue even though there's nothing to read yet
io::Result::Ok(string)
});
task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel and io
{
let lock = server.lock();
assert_eq!(
"HTTP/1.1 100 Continue\r\n\r\n",
std::str::from_utf8(lock.get_ref())?
);
};
assert_eq!("HTTP/1.1 100 Continue\r\n\r\n", &client.read.to_string());
client.write_all(b"0123456789").await?;
client
.seek(SeekFrom::Start(REQUEST_WITH_EXPECT.len() as u64))
.await?;
assert_eq!("0123456789", request.body_string().await?);
assert_eq!("0123456789", &join_handle.await?);
Ok(())
}
#[async_std::test]
async fn test_without_expect_when_not_reading_body() -> Result<()> {
let client_str: Vec<u8> = REQUEST_WITH_EXPECT.to_vec();
let server_str: Vec<u8> = vec![];
let (mut client, server) = TestIO::new();
client.write_all(REQUEST_WITH_EXPECT).await?;
let client = Arc::new(Mutex::new(Cursor::new(client_str)));
let server = Arc::new(Mutex::new(Cursor::new(server_str)));
async_h1::server::decode(Duplex::new(client.clone(), server.clone()))
.await?
.unwrap();
let (_, _) = async_h1::server::decode(server).await?.unwrap();
task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel
let server_lock = server.lock();
assert_eq!("", std::str::from_utf8(server_lock.get_ref())?); // we haven't written 100-continue
assert_eq!("", &client.read.to_string()); // we haven't written 100-continue
Ok(())
}

View File

@ -1,12 +1,12 @@
use async_dup::{Arc, Mutex};
use async_std::io::Cursor;
use duplexify::Duplex;
use async_h1::{client, server};
use http_types::Body;
use http_types::Method;
use http_types::Request;
use http_types::Url;
use http_types::{Response, Result};
use pretty_assertions::assert_eq;
mod test_utils;
use test_utils::TestIO;
const BODY: &str = concat![
"Et provident reprehenderit accusamus dolores et voluptates sed quia. Repellendus odit porro ut et hic molestiae. Sit autem reiciendis animi fugiat deleniti vel iste. Laborum id odio ullam ut impedit dolores. Vel aperiam dolorem voluptatibus dignissimos maxime.",
@ -72,21 +72,18 @@ async fn server_chunked_large() -> Result<()> {
let mut request = Request::new(Method::Post, Url::parse("http://domain.com").unwrap());
// request.set_body(Body::from_reader(Cursor::new(BODY), None));
request.set_body(Body::from_string(String::from(BODY)));
let request_encoder = async_h1::client::Encoder::new(request);
let request = async_h1::server::decode(Duplex::new(
Arc::new(Mutex::new(request_encoder)),
Arc::new(Mutex::new(Cursor::new(vec![]))),
))
.await?
.unwrap();
let (mut client, server) = TestIO::new();
async_std::io::copy(&mut client::Encoder::new(request), &mut client).await?;
let (request, _) = server::decode(server).await?.unwrap();
let mut response = Response::new(200);
response.set_body(Body::from_reader(request, None));
let response_encoder = async_h1::server::Encoder::new(response, Method::Get);
let response_encoder = server::Encoder::new(response, Method::Get);
let mut response = async_h1::client::decode(response_encoder).await?;
let mut response = client::decode(response_encoder).await?;
assert_eq!(response.body_string().await?, BODY);
Ok(())

View File

@ -1,7 +1,7 @@
mod test_utils;
mod server_decode {
use async_dup::{Arc, Mutex};
use async_std::io::{Cursor, ReadExt};
use duplexify::Duplex;
use super::test_utils::TestIO;
use async_std::io::prelude::*;
use http_types::headers::TRANSFER_ENCODING;
use http_types::Request;
use http_types::Result;
@ -10,11 +10,12 @@ mod server_decode {
async fn decode_lines(lines: Vec<&str>) -> Result<Option<Request>> {
let s = lines.join("\r\n");
async_h1::server::decode(Duplex::new(
Arc::new(Mutex::new(Cursor::new(s))),
Arc::new(Mutex::new(Cursor::new(vec![]))),
))
.await
let (mut client, server) = TestIO::new();
client.write_all(s.as_bytes()).await?;
client.close();
async_h1::server::decode(server)
.await
.map(|r| r.map(|(r, _)| r))
}
#[async_std::test]

254
tests/test_utils.rs Normal file
View File

@ -0,0 +1,254 @@
use async_h1::{
client::Encoder,
server::{ConnectionStatus, Server},
};
use async_std::io::{Read, Write};
use http_types::{Request, Response, Result};
use std::{
fmt::Debug,
future::Future,
io,
pin::Pin,
sync::RwLock,
task::{Context, Poll, Waker},
};
use async_dup::Arc;
#[pin_project::pin_project]
pub struct TestServer<F, Fut> {
server: Server<TestIO, F, Fut>,
#[pin]
client: TestIO,
}
impl<F, Fut> TestServer<F, Fut>
where
F: Fn(Request) -> Fut,
Fut: Future<Output = Result<Response>>,
{
#[allow(dead_code)]
pub fn new(f: F) -> Self {
let (client, server) = TestIO::new();
Self {
server: Server::new(server, f),
client,
}
}
#[allow(dead_code)]
pub async fn accept_one(&mut self) -> http_types::Result<ConnectionStatus> {
self.server.accept_one().await
}
#[allow(dead_code)]
pub fn close(&mut self) {
self.client.close();
}
#[allow(dead_code)]
pub fn all_read(&self) -> bool {
self.client.all_read()
}
#[allow(dead_code)]
pub async fn write_request(&mut self, request: Request) -> io::Result<()> {
async_std::io::copy(&mut Encoder::new(request), self).await?;
Ok(())
}
}
impl<F, Fut> Read for TestServer<F, Fut>
where
F: Fn(Request) -> Fut,
Fut: Future<Output = Result<Response>>,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
self.project().client.poll_read(cx, buf)
}
}
impl<F, Fut> Write for TestServer<F, Fut>
where
F: Fn(Request) -> Fut,
Fut: Future<Output = Result<Response>>,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.project().client.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().client.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().client.poll_close(cx)
}
}
/// a Test IO
#[derive(Default, Clone, Debug)]
pub struct TestIO {
pub read: Arc<CloseableCursor>,
pub write: Arc<CloseableCursor>,
}
#[derive(Default)]
pub struct CloseableCursor {
data: RwLock<Vec<u8>>,
cursor: RwLock<usize>,
waker: RwLock<Option<Waker>>,
closed: RwLock<bool>,
}
impl CloseableCursor {
fn len(&self) -> usize {
self.data.read().unwrap().len()
}
fn cursor(&self) -> usize {
*self.cursor.read().unwrap()
}
fn current(&self) -> bool {
self.len() == self.cursor()
}
pub fn to_string(&self) -> String {
std::str::from_utf8(&*self.data.read().unwrap())
.unwrap_or("not utf8")
.to_owned()
}
fn close(&self) {
*self.closed.write().unwrap() = true;
}
}
impl TestIO {
pub fn new() -> (TestIO, TestIO) {
let client = Arc::new(CloseableCursor::default());
let server = Arc::new(CloseableCursor::default());
(
TestIO {
read: client.clone(),
write: server.clone(),
},
TestIO {
read: server,
write: client,
},
)
}
pub fn all_read(&self) -> bool {
self.write.current()
}
pub fn close(&mut self) {
self.write.close();
}
}
impl Debug for CloseableCursor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CloseableCursor")
.field(
"data",
&std::str::from_utf8(&self.data.read().unwrap()).unwrap_or("not utf8"),
)
.field("closed", &*self.closed.read().unwrap())
.field("cursor", &*self.cursor.read().unwrap())
.finish()
}
}
impl Read for &CloseableCursor {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let len = self.len();
let cursor = self.cursor();
if cursor < len {
let data = &*self.data.read().unwrap();
let bytes_to_copy = buf.len().min(len - cursor);
buf[..bytes_to_copy].copy_from_slice(&data[cursor..cursor + bytes_to_copy]);
*self.cursor.write().unwrap() += bytes_to_copy;
Poll::Ready(Ok(bytes_to_copy))
} else if *self.closed.read().unwrap() {
Poll::Ready(Ok(0))
} else {
*self.waker.write().unwrap() = Some(cx.waker().clone());
Poll::Pending
}
}
}
impl Write for &CloseableCursor {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
if *self.closed.read().unwrap() {
Poll::Ready(Ok(0))
} else {
self.data.write().unwrap().extend_from_slice(buf);
if let Some(waker) = self.waker.write().unwrap().take() {
waker.wake();
}
Poll::Ready(Ok(buf.len()))
}
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if let Some(waker) = self.waker.write().unwrap().take() {
waker.wake();
}
*self.closed.write().unwrap() = true;
Poll::Ready(Ok(()))
}
}
impl Read for TestIO {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut &*self.read).poll_read(cx, buf)
}
}
impl Write for TestIO {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut &*self.write).poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut &*self.write).poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut &*self.write).poll_close(cx)
}
}