From 55d0a4cd59850a4eee85e0a56f3c84e0d72229ec Mon Sep 17 00:00:00 2001 From: katelyn martin Date: Fri, 10 Jan 2025 00:00:00 +0000 Subject: [PATCH] review: fix ` as Body>::poll_frame()` see: https://github.com/hyperium/http-body/pull/140#discussion_r1910881095 this commit adds test coverage exposing the bug, and tightens the pattern used to match frames yielded by the data channel. now, when the channel is closed, a `None` will flow onwards and poll the error channel. `None` will be returned when the error channel is closed, which also indicates that the associated `Sender` has been dropped. --- http-body-util/src/channel.rs | 89 +++++++++++++++++++++++++++++++++-- 1 file changed, 85 insertions(+), 4 deletions(-) diff --git a/http-body-util/src/channel.rs b/http-body-util/src/channel.rs index a918efa..e963d19 100644 --- a/http-body-util/src/channel.rs +++ b/http-body-util/src/channel.rs @@ -48,13 +48,14 @@ where let this = self.project(); match this.rx_frame.poll_recv(cx) { - Poll::Ready(frame) => return Poll::Ready(frame.map(Ok)), - Poll::Pending => {} + Poll::Ready(frame @ Some(_)) => return Poll::Ready(frame.map(Ok)), + Poll::Ready(None) | Poll::Pending => {} } use core::future::Future; match this.rx_error.poll(cx) { - Poll::Ready(err) => return Poll::Ready(err.ok().map(Err)), + Poll::Ready(Ok(error)) => return Poll::Ready(Some(Err(error))), + Poll::Ready(Err(_)) => return Poll::Ready(None), Poll::Pending => {} } @@ -131,13 +132,54 @@ mod tests { use super::*; #[tokio::test] - async fn works() { + async fn empty() { + let (tx, body) = Channel::::new(1024); + drop(tx); + + let collected = body.collect().await.unwrap(); + assert!(collected.trailers().is_none()); + assert!(collected.to_bytes().is_empty()); + } + + #[tokio::test] + async fn can_send_data() { let (mut tx, body) = Channel::::new(1024); tokio::spawn(async move { tx.send_data(Bytes::from("Hel")).await.unwrap(); tx.send_data(Bytes::from("lo!")).await.unwrap(); + }); + + let collected = body.collect().await.unwrap(); + assert!(collected.trailers().is_none()); + assert_eq!(collected.to_bytes(), "Hello!"); + } + + #[tokio::test] + async fn can_send_trailers() { + let (mut tx, body) = Channel::::new(1024); + + tokio::spawn(async move { + let mut trailers = HeaderMap::new(); + trailers.insert( + HeaderName::from_static("foo"), + HeaderValue::from_static("bar"), + ); + tx.send_trailers(trailers).await.unwrap(); + }); + + let collected = body.collect().await.unwrap(); + assert_eq!(collected.trailers().unwrap()["foo"], "bar"); + assert!(collected.to_bytes().is_empty()); + } + + #[tokio::test] + async fn can_send_both_data_and_trailers() { + let (mut tx, body) = Channel::::new(1024); + tokio::spawn(async move { + tx.send_data(Bytes::from("Hel")).await.unwrap(); + tx.send_data(Bytes::from("lo!")).await.unwrap(); let mut trailers = HeaderMap::new(); trailers.insert( HeaderName::from_static("foo"), @@ -150,4 +192,43 @@ mod tests { assert_eq!(collected.trailers().unwrap()["foo"], "bar"); assert_eq!(collected.to_bytes(), "Hello!"); } + + /// A stand-in for an error type, for unit tests. + type Error = &'static str; + /// An example error message. + const MSG: Error = "oh no"; + + #[tokio::test] + async fn aborts_before_trailers() { + let (mut tx, body) = Channel::::new(1024); + + tokio::spawn(async move { + tx.send_data(Bytes::from("Hel")).await.unwrap(); + tx.send_data(Bytes::from("lo!")).await.unwrap(); + tx.abort(MSG); + }); + + let err = body.collect().await.unwrap_err(); + assert_eq!(err, MSG); + } + + #[tokio::test] + async fn aborts_after_trailers() { + let (mut tx, body) = Channel::::new(1024); + + tokio::spawn(async move { + tx.send_data(Bytes::from("Hel")).await.unwrap(); + tx.send_data(Bytes::from("lo!")).await.unwrap(); + let mut trailers = HeaderMap::new(); + trailers.insert( + HeaderName::from_static("foo"), + HeaderValue::from_static("bar"), + ); + tx.send_trailers(trailers).await.unwrap(); + tx.abort(MSG); + }); + + let err = body.collect().await.unwrap_err(); + assert_eq!(err, MSG); + } }