Skip to content

Commit

Permalink
Add Middlware::combine method to combine two middlewares.
Browse files Browse the repository at this point in the history
  • Loading branch information
sunli829 committed Oct 21, 2024
1 parent 0cbbabc commit cf762d9
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
4 changes: 4 additions & 0 deletions poem/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

# [3.1.2] 2024-10-21

- Add `Middlware::combine` method to combine two middlewares.

# [3.1.1] 2024-10-02

- Add `WebSocket::config` method to set the WebSocket configuration.
Expand Down
78 changes: 78 additions & 0 deletions poem/src/middleware/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ mod tokio_metrics_mw;
mod tower_compat;
mod tracing_mw;

use std::marker::PhantomData;

#[cfg(feature = "compression")]
pub use self::compression::{Compression, CompressionEndpoint};
#[cfg(feature = "cookie")]
Expand Down Expand Up @@ -179,6 +181,82 @@ pub trait Middleware<E: Endpoint> {

/// Transform the input [`Endpoint`] to another one.
fn transform(&self, ep: E) -> Self::Output;

/// Create a new middleware by combining two middlewares.
///
/// # Example
///
/// ```
/// use poem::{
/// handler, middleware::SetHeader, Endpoint, EndpointExt, Middleware, Request, Result,
/// };
///
/// #[handler]
/// fn index() -> &'static str {
/// "hello"
/// }
///
/// #[tokio::main]
/// async fn main() -> Result<(), std::io::Error> {
/// let ep = index.with(
/// SetHeader::new()
/// .appending("myheader", "a")
/// .combine(SetHeader::new().appending("myheader", "b")),
/// );
///
/// let resp = ep.call(Request::default()).await.unwrap();
/// assert_eq!(
/// resp.headers()
/// .get_all("myheader")
/// .iter()
/// .flat_map(|value| value.to_str().ok())
/// .collect::<Vec<_>>(),
/// vec!["a", "b"]
/// );
/// Ok(())
/// }
/// ```
fn combine<T>(self, other: T) -> CombineMiddleware<Self, T, E>
where
T: Middleware<Self::Output> + Sized,
Self: Sized,
{
CombineMiddleware {
a: self,
b: other,
_mark: PhantomData,
}
}
}

impl<E: Endpoint> Middleware<E> for () {
type Output = E;

#[inline]
fn transform(&self, ep: E) -> Self::Output {
ep
}
}

/// A middleware that combines two middlewares.
pub struct CombineMiddleware<A, B, E> {
a: A,
b: B,
_mark: PhantomData<E>,
}

impl<A, B, E> Middleware<E> for CombineMiddleware<A, B, E>
where
A: Middleware<E>,
B: Middleware<A::Output>,
E: Endpoint,
{
type Output = B::Output;

#[inline]
fn transform(&self, ep: E) -> Self::Output {
self.b.transform(self.a.transform(ep))
}
}

poem_derive::generate_implement_middlewares!();
Expand Down

0 comments on commit cf762d9

Please sign in to comment.