Skip to content

Commit

Permalink
Add Body::from_stream (#1848)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn committed Apr 21, 2023
1 parent 4e4c291 commit 72c1b7a
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 29 deletions.
2 changes: 2 additions & 0 deletions axum-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ futures-util = { version = "0.3", default-features = false, features = ["alloc"]
http = "0.2.7"
http-body = "0.4.5"
mime = "0.3.16"
pin-project-lite = "0.2.7"
sync_wrapper = "0.1.1"
tower-layer = "0.3"
tower-service = "0.3"

Expand Down
61 changes: 61 additions & 0 deletions axum-core/src/body.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
//! HTTP body utilities.

use crate::response::{IntoResponse, Response};
use crate::{BoxError, Error};
use bytes::Bytes;
use bytes::{Buf, BufMut};
use futures_util::stream::Stream;
use futures_util::TryStream;
use http::HeaderMap;
use http_body::Body as _;
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};
use sync_wrapper::SyncWrapper;

/// A boxed [`Body`] trait object.
///
Expand Down Expand Up @@ -107,6 +111,20 @@ impl Body {
pub fn empty() -> Self {
Self::new(http_body::Empty::new())
}

/// Create a new `Body` from a [`Stream`].
///
/// [`Stream`]: futures_util::stream::Stream
pub fn from_stream<S>(stream: S) -> Self
where
S: TryStream + Send + 'static,
S::Ok: Into<Bytes>,
S::Error: Into<BoxError>,
{
Self::new(StreamBody {
stream: SyncWrapper::new(stream),
})
}
}

impl Default for Body {
Expand Down Expand Up @@ -175,6 +193,49 @@ impl Stream for Body {
}
}

pin_project! {
struct StreamBody<S> {
#[pin]
stream: SyncWrapper<S>,
}
}

impl<S> http_body::Body for StreamBody<S>
where
S: TryStream,
S::Ok: Into<Bytes>,
S::Error: Into<BoxError>,
{
type Data = Bytes;
type Error = Error;

fn poll_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
let stream = self.project().stream.get_pin_mut();
match futures_util::ready!(stream.try_poll_next(cx)) {
Some(Ok(chunk)) => Poll::Ready(Some(Ok(chunk.into()))),
Some(Err(err)) => Poll::Ready(Some(Err(Error::new(err)))),
None => Poll::Ready(None),
}
}

#[inline]
fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
Poll::Ready(Ok(None))
}
}

impl IntoResponse for Body {
fn into_response(self) -> Response {
Response::new(self.0)
}
}

#[test]
fn test_try_downcast() {
assert_eq!(try_downcast::<i32, _>(5_u32), Err(5_u32));
Expand Down
30 changes: 12 additions & 18 deletions axum-extra/src/body/async_read_body.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use axum::{
body::{self, Bytes, HttpBody, StreamBody},
body::{Body, Bytes, HttpBody},
http::HeaderMap,
response::{IntoResponse, Response},
Error,
Expand Down Expand Up @@ -47,51 +47,45 @@ pin_project! {
#[cfg(feature = "async-read-body")]
#[derive(Debug)]
#[must_use]
pub struct AsyncReadBody<R> {
pub struct AsyncReadBody {
#[pin]
read: StreamBody<ReaderStream<R>>,
body: Body,
}
}

impl<R> AsyncReadBody<R> {
impl AsyncReadBody {
/// Create a new `AsyncReadBody`.
pub fn new(read: R) -> Self
pub fn new<R>(read: R) -> Self
where
R: AsyncRead + Send + 'static,
{
Self {
read: StreamBody::new(ReaderStream::new(read)),
body: Body::from_stream(ReaderStream::new(read)),
}
}
}

impl<R> HttpBody for AsyncReadBody<R>
where
R: AsyncRead + Send + 'static,
{
impl HttpBody for AsyncReadBody {
type Data = Bytes;
type Error = Error;

fn poll_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
self.project().read.poll_data(cx)
self.project().body.poll_data(cx)
}

fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
Poll::Ready(Ok(None))
self.project().body.poll_trailers(cx)
}
}

impl<R> IntoResponse for AsyncReadBody<R>
where
R: AsyncRead + Send + 'static,
{
impl IntoResponse for AsyncReadBody {
fn into_response(self) -> Response {
Response::new(body::boxed(self))
self.body.into_response()
}
}
6 changes: 2 additions & 4 deletions axum-extra/src/extract/multipart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use axum::{
body::{Body, Bytes},
extract::FromRequest,
response::{IntoResponse, Response},
BoxError, RequestExt,
RequestExt,
};
use futures_util::stream::Stream;
use http::{
Expand Down Expand Up @@ -410,9 +410,7 @@ impl std::error::Error for InvalidBoundary {}
mod tests {
use super::*;
use crate::test_helpers::*;
use axum::{
body::Body, extract::DefaultBodyLimit, response::IntoResponse, routing::post, Router,
};
use axum::{extract::DefaultBodyLimit, response::IntoResponse, routing::post, Router};

#[tokio::test]
async fn content_type_with_encoding() {
Expand Down
4 changes: 2 additions & 2 deletions axum-extra/src/json_lines.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use axum::{
async_trait,
body::{Body, StreamBody},
body::Body,
extract::FromRequest,
response::{IntoResponse, Response},
BoxError,
Expand Down Expand Up @@ -166,7 +166,7 @@ where
buf.write_all(b"\n")?;
Ok::<_, BoxError>(buf.into_inner().freeze())
});
let stream = StreamBody::new(stream);
let stream = Body::from_stream(stream);

// there is no consensus around mime type yet
// https://github.com/wardi/jsonlines/issues/36
Expand Down
4 changes: 0 additions & 4 deletions axum/src/body/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
//! HTTP body utilities.

mod stream_body;

pub use self::stream_body::StreamBody;

#[doc(no_inline)]
pub use http_body::{Body as HttpBody, Empty, Full};

Expand Down
1 change: 0 additions & 1 deletion axum/src/test_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,5 @@ pub(crate) mod tracing_helpers;

pub(crate) fn assert_send<T: Send>() {}
pub(crate) fn assert_sync<T: Sync>() {}
pub(crate) fn assert_unpin<T: Unpin>() {}

pub(crate) struct NotSendSync(*const ());

0 comments on commit 72c1b7a

Please sign in to comment.