Skip to content

Commit

Permalink
review: fix <Channel<D, E> as Body>::poll_frame()
Browse files Browse the repository at this point in the history
see: #140 (comment)

this commit adds test coverage exposing the bug, and tightens the
pattern used to match frames yielded by the data channel.

now, when the channel is closed, a `None` will flow onwards and poll the
error channel. `None` will be returned when the error channel is closed,
which also indicates that the associated `Sender` has been dropped.
  • Loading branch information
cratelyn committed Jan 10, 2025
1 parent d9fd52b commit 55d0a4c
Showing 1 changed file with 85 additions and 4 deletions.
89 changes: 85 additions & 4 deletions http-body-util/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ where
let this = self.project();

match this.rx_frame.poll_recv(cx) {
Poll::Ready(frame) => return Poll::Ready(frame.map(Ok)),
Poll::Pending => {}
Poll::Ready(frame @ Some(_)) => return Poll::Ready(frame.map(Ok)),
Poll::Ready(None) | Poll::Pending => {}
}

use core::future::Future;
match this.rx_error.poll(cx) {
Poll::Ready(err) => return Poll::Ready(err.ok().map(Err)),
Poll::Ready(Ok(error)) => return Poll::Ready(Some(Err(error))),
Poll::Ready(Err(_)) => return Poll::Ready(None),
Poll::Pending => {}
}

Expand Down Expand Up @@ -131,13 +132,54 @@ mod tests {
use super::*;

#[tokio::test]
async fn works() {
async fn empty() {
let (tx, body) = Channel::<Bytes>::new(1024);
drop(tx);

let collected = body.collect().await.unwrap();
assert!(collected.trailers().is_none());
assert!(collected.to_bytes().is_empty());
}

#[tokio::test]
async fn can_send_data() {
let (mut tx, body) = Channel::<Bytes>::new(1024);

tokio::spawn(async move {
tx.send_data(Bytes::from("Hel")).await.unwrap();
tx.send_data(Bytes::from("lo!")).await.unwrap();
});

let collected = body.collect().await.unwrap();
assert!(collected.trailers().is_none());
assert_eq!(collected.to_bytes(), "Hello!");
}

#[tokio::test]
async fn can_send_trailers() {
let (mut tx, body) = Channel::<Bytes>::new(1024);

tokio::spawn(async move {
let mut trailers = HeaderMap::new();
trailers.insert(
HeaderName::from_static("foo"),
HeaderValue::from_static("bar"),
);
tx.send_trailers(trailers).await.unwrap();
});

let collected = body.collect().await.unwrap();
assert_eq!(collected.trailers().unwrap()["foo"], "bar");
assert!(collected.to_bytes().is_empty());
}

#[tokio::test]
async fn can_send_both_data_and_trailers() {
let (mut tx, body) = Channel::<Bytes>::new(1024);

tokio::spawn(async move {
tx.send_data(Bytes::from("Hel")).await.unwrap();
tx.send_data(Bytes::from("lo!")).await.unwrap();
let mut trailers = HeaderMap::new();
trailers.insert(
HeaderName::from_static("foo"),
Expand All @@ -150,4 +192,43 @@ mod tests {
assert_eq!(collected.trailers().unwrap()["foo"], "bar");
assert_eq!(collected.to_bytes(), "Hello!");
}

/// A stand-in for an error type, for unit tests.
type Error = &'static str;
/// An example error message.
const MSG: Error = "oh no";

#[tokio::test]
async fn aborts_before_trailers() {
let (mut tx, body) = Channel::<Bytes, Error>::new(1024);

tokio::spawn(async move {
tx.send_data(Bytes::from("Hel")).await.unwrap();
tx.send_data(Bytes::from("lo!")).await.unwrap();
tx.abort(MSG);
});

let err = body.collect().await.unwrap_err();
assert_eq!(err, MSG);
}

#[tokio::test]
async fn aborts_after_trailers() {
let (mut tx, body) = Channel::<Bytes, Error>::new(1024);

tokio::spawn(async move {
tx.send_data(Bytes::from("Hel")).await.unwrap();
tx.send_data(Bytes::from("lo!")).await.unwrap();
let mut trailers = HeaderMap::new();
trailers.insert(
HeaderName::from_static("foo"),
HeaderValue::from_static("bar"),
);
tx.send_trailers(trailers).await.unwrap();
tx.abort(MSG);
});

let err = body.collect().await.unwrap_err();
assert_eq!(err, MSG);
}
}

0 comments on commit 55d0a4c

Please sign in to comment.