Skip to content

Commit

Permalink
Merge pull request #399 from tirr-c/route-middleware
Browse files Browse the repository at this point in the history
Per-route middleware
  • Loading branch information
yoshuawuyts authored Feb 9, 2020
2 parents dd9d42d + e4f2f2d commit 589abaf
Show file tree
Hide file tree
Showing 4 changed files with 269 additions and 11 deletions.
54 changes: 53 additions & 1 deletion src/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::sync::Arc;

use async_std::future::Future;

use crate::middleware::Next;
use crate::utils::BoxFuture;
use crate::{response::IntoResponse, Request, Response};
use crate::{response::IntoResponse, Middleware, Request, Response};

/// An HTTP request handler.
///
Expand Down Expand Up @@ -63,3 +66,52 @@ where
Box::pin(async move { fut.await.into_response() })
}
}

pub struct MiddlewareEndpoint<E, State> {
endpoint: E,
middleware: Vec<Arc<dyn Middleware<State>>>,
}

impl<E: Clone, State> Clone for MiddlewareEndpoint<E, State> {
fn clone(&self) -> Self {
Self {
endpoint: self.endpoint.clone(),
middleware: self.middleware.clone(),
}
}
}

impl<E, State> std::fmt::Debug for MiddlewareEndpoint<E, State> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
fmt,
"MiddlewareEndpoint (length: {})",
self.middleware.len(),
)
}
}

impl<E, State> MiddlewareEndpoint<E, State>
where
E: Endpoint<State>,
{
pub fn wrap_with_middleware(ep: E, middleware: &[Arc<dyn Middleware<State>>]) -> Self {
Self {
endpoint: ep,
middleware: middleware.to_vec(),
}
}
}

impl<E, State: 'static> Endpoint<State> for MiddlewareEndpoint<E, State>
where
E: Endpoint<State>,
{
fn call<'a>(&'a self, req: Request<State>) -> BoxFuture<'a, Response> {
let next = Next {
endpoint: &self.endpoint,
next_middleware: &self.middleware,
};
next.run(req)
}
}
10 changes: 5 additions & 5 deletions src/router.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use route_recognizer::{Match, Params, Router as MethodRouter};
use std::collections::HashMap;

use crate::endpoint::{DynEndpoint, Endpoint};
use crate::endpoint::DynEndpoint;
use crate::utils::BoxFuture;
use crate::{Request, Response};

Expand Down Expand Up @@ -29,15 +29,15 @@ impl<State: 'static> Router<State> {
}
}

pub(crate) fn add(&mut self, path: &str, method: http::Method, ep: impl Endpoint<State>) {
pub(crate) fn add(&mut self, path: &str, method: http::Method, ep: Box<DynEndpoint<State>>) {
self.method_map
.entry(method)
.or_insert_with(MethodRouter::new)
.add(path, Box::new(ep))
.add(path, ep)
}

pub(crate) fn add_all(&mut self, path: &str, ep: impl Endpoint<State>) {
self.all_method_router.add(path, Box::new(ep))
pub(crate) fn add_all(&mut self, path: &str, ep: Box<DynEndpoint<State>>) {
self.all_method_router.add(path, ep)
}

pub(crate) fn route(&self, path: &str, method: http::Method) -> Selection<'_, State> {
Expand Down
66 changes: 61 additions & 5 deletions src/server/route.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::sync::Arc;

use crate::endpoint::MiddlewareEndpoint;
use crate::utils::BoxFuture;
use crate::{router::Router, Endpoint, Response};
use crate::{router::Router, Endpoint, Middleware, Response};

/// A handle to a route.
///
Expand All @@ -13,6 +16,7 @@ use crate::{router::Router, Endpoint, Response};
pub struct Route<'a, State> {
router: &'a mut Router<State>,
path: String,
middleware: Vec<Arc<dyn Middleware<State>>>,
/// Indicates whether the path of current route is treated as a prefix. Set by
/// [`strip_prefix`].
///
Expand All @@ -25,6 +29,7 @@ impl<'a, State: 'static> Route<'a, State> {
Route {
router,
path,
middleware: Vec::new(),
prefix: false,
}
}
Expand All @@ -44,6 +49,7 @@ impl<'a, State: 'static> Route<'a, State> {
Route {
router: &mut self.router,
path: p,
middleware: self.middleware.clone(),
prefix: false,
}
}
Expand All @@ -60,6 +66,18 @@ impl<'a, State: 'static> Route<'a, State> {
self
}

/// Apply the given middleware to the current route.
pub fn middleware(&mut self, middleware: impl Middleware<State>) -> &mut Self {
self.middleware.push(Arc::new(middleware));
self
}

/// Reset the middleware chain for the current route, if any.
pub fn reset_middleware(&mut self) -> &mut Self {
self.middleware.clear();
self
}

/// Nest a [`Server`] at the current path.
///
/// [`Server`]: struct.Server.html
Expand All @@ -78,10 +96,29 @@ impl<'a, State: 'static> Route<'a, State> {
pub fn method(&mut self, method: http::Method, ep: impl Endpoint<State>) -> &mut Self {
if self.prefix {
let ep = StripPrefixEndpoint::new(ep);
self.router.add(&self.path, method.clone(), ep.clone());
let (ep1, ep2): (Box<dyn Endpoint<_>>, Box<dyn Endpoint<_>>) =
if self.middleware.is_empty() {
let ep = Box::new(ep);
(ep.clone(), ep)
} else {
let ep = Box::new(MiddlewareEndpoint::wrap_with_middleware(
ep,
&self.middleware,
));
(ep.clone(), ep)
};
self.router.add(&self.path, method.clone(), ep1);
let wildcard = self.at("*--tide-path-rest");
wildcard.router.add(&wildcard.path, method, ep);
wildcard.router.add(&wildcard.path, method, ep2);
} else {
let ep: Box<dyn Endpoint<_>> = if self.middleware.is_empty() {
Box::new(ep)
} else {
Box::new(MiddlewareEndpoint::wrap_with_middleware(
ep,
&self.middleware,
))
};
self.router.add(&self.path, method, ep);
}
self
Expand All @@ -93,10 +130,29 @@ impl<'a, State: 'static> Route<'a, State> {
pub fn all(&mut self, ep: impl Endpoint<State>) -> &mut Self {
if self.prefix {
let ep = StripPrefixEndpoint::new(ep);
self.router.add_all(&self.path, ep.clone());
let (ep1, ep2): (Box<dyn Endpoint<_>>, Box<dyn Endpoint<_>>) =
if self.middleware.is_empty() {
let ep = Box::new(ep);
(ep.clone(), ep)
} else {
let ep = Box::new(MiddlewareEndpoint::wrap_with_middleware(
ep,
&self.middleware,
));
(ep.clone(), ep)
};
self.router.add_all(&self.path, ep1);
let wildcard = self.at("*--tide-path-rest");
wildcard.router.add_all(&wildcard.path, ep);
wildcard.router.add_all(&wildcard.path, ep2);
} else {
let ep: Box<dyn Endpoint<_>> = if self.middleware.is_empty() {
Box::new(ep)
} else {
Box::new(MiddlewareEndpoint::wrap_with_middleware(
ep,
&self.middleware,
))
};
self.router.add_all(&self.path, ep);
}
self
Expand Down
150 changes: 150 additions & 0 deletions tests/route_middleware.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
use futures::future::BoxFuture;
use http_service::Body;
use http_service_mock::make_server;
use tide::Middleware;

struct TestMiddleware(&'static str, &'static str);

impl TestMiddleware {
fn with_header_name(name: &'static str, value: &'static str) -> Self {
Self(name, value)
}
}

impl<State: Send + Sync + 'static> Middleware<State> for TestMiddleware {
fn handle<'a>(
&'a self,
req: tide::Request<State>,
next: tide::Next<'a, State>,
) -> BoxFuture<'a, tide::Response> {
Box::pin(async move {
let res = next.run(req).await;
res.set_header(self.0, self.1)
})
}
}

async fn echo_path<State>(req: tide::Request<State>) -> String {
req.uri().path().to_string()
}

#[test]
fn route_middleware() {
let mut app = tide::new();
let mut foo_route = app.at("/foo");
foo_route // /foo
.middleware(TestMiddleware::with_header_name("X-Foo", "foo"))
.get(echo_path);
foo_route
.at("/bar") // nested, /foo/bar
.middleware(TestMiddleware::with_header_name("X-Bar", "bar"))
.get(echo_path);
foo_route // /foo
.post(echo_path)
.reset_middleware()
.put(echo_path);
let mut server = make_server(app.into_http_service()).unwrap();

let req = http::Request::get("/foo").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap()));

let req = http::Request::post("/foo").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap()));

let req = http::Request::put("/foo").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Foo"), None);

let req = http::Request::get("/foo/bar").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap()));
assert_eq!(res.headers().get("X-Bar"), Some(&"bar".parse().unwrap()));
}

#[test]
fn app_and_route_middleware() {
let mut app = tide::new();
app.middleware(TestMiddleware::with_header_name("X-Root", "root"));
app.at("/foo")
.middleware(TestMiddleware::with_header_name("X-Foo", "foo"))
.get(echo_path);
app.at("/bar")
.middleware(TestMiddleware::with_header_name("X-Bar", "bar"))
.get(echo_path);
let mut server = make_server(app.into_http_service()).unwrap();

let req = http::Request::get("/foo").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap()));
assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap()));
assert_eq!(res.headers().get("X-Bar"), None);

let req = http::Request::get("/bar").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap()));
assert_eq!(res.headers().get("X-Foo"), None);
assert_eq!(res.headers().get("X-Bar"), Some(&"bar".parse().unwrap()));
}

#[test]
fn nested_app_with_route_middleware() {
let mut inner = tide::new();
inner.middleware(TestMiddleware::with_header_name("X-Inner", "inner"));
inner
.at("/baz")
.middleware(TestMiddleware::with_header_name("X-Baz", "baz"))
.get(echo_path);

let mut app = tide::new();
app.middleware(TestMiddleware::with_header_name("X-Root", "root"));
app.at("/foo")
.middleware(TestMiddleware::with_header_name("X-Foo", "foo"))
.get(echo_path);
app.at("/bar")
.middleware(TestMiddleware::with_header_name("X-Bar", "bar"))
.nest(inner);
let mut server = make_server(app.into_http_service()).unwrap();

let req = http::Request::get("/foo").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap()));
assert_eq!(res.headers().get("X-Inner"), None);
assert_eq!(res.headers().get("X-Foo"), Some(&"foo".parse().unwrap()));
assert_eq!(res.headers().get("X-Bar"), None);
assert_eq!(res.headers().get("X-Baz"), None);

let req = http::Request::get("/bar/baz").body(Body::empty()).unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Root"), Some(&"root".parse().unwrap()));
assert_eq!(
res.headers().get("X-Inner"),
Some(&"inner".parse().unwrap())
);
assert_eq!(res.headers().get("X-Foo"), None);
assert_eq!(res.headers().get("X-Bar"), Some(&"bar".parse().unwrap()));
assert_eq!(res.headers().get("X-Baz"), Some(&"baz".parse().unwrap()));
}

#[test]
fn subroute_not_nested() {
let mut app = tide::new();
app.at("/parent") // /parent
.middleware(TestMiddleware::with_header_name("X-Parent", "Parent"))
.get(echo_path);
app.at("/parent/child") // /parent/child, not nested
.middleware(TestMiddleware::with_header_name("X-Child", "child"))
.get(echo_path);
let mut server = make_server(app.into_http_service()).unwrap();

let req = http::Request::get("/parent/child")
.body(Body::empty())
.unwrap();
let res = server.simulate(req).unwrap();
assert_eq!(res.headers().get("X-Parent"), None);
assert_eq!(
res.headers().get("X-Child"),
Some(&"child".parse().unwrap())
);
}

0 comments on commit 589abaf

Please sign in to comment.