Skip to content

Commit

Permalink
Merge pull request #545 from jbr/add-middleware-example
Browse files Browse the repository at this point in the history
allow for function middlewares by dropping Debug bound and add example
  • Loading branch information
yoshuawuyts authored May 28, 2020
2 parents a376bc6 + 8ab1f89 commit eb872dd
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 5 deletions.
101 changes: 101 additions & 0 deletions examples/middleware.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tide::{Middleware, Next, Request, Response, Result, StatusCode};

#[derive(Debug)]
struct User {
name: String,
}

#[derive(Default)]
struct UserDatabase;
impl UserDatabase {
async fn find_user(&self) -> Option<User> {
Some(User {
name: "nori".into(),
})
}
}

// This is an example of a function middleware that uses the
// application state. Because it depends on a specific request state,
// it would likely be closely tied to a specific application
fn user_loader<'a>(
mut request: Request<UserDatabase>,
next: Next<'a, UserDatabase>,
) -> Pin<Box<dyn Future<Output = Result> + Send + 'a>> {
Box::pin(async {
if let Some(user) = request.state().find_user().await {
tide::log::trace!("user loaded", {user: user.name});
request.set_ext(user);
next.run(request).await
// this middleware only needs to run before the endpoint, so
// it just passes through the result of Next
} else {
// do not run endpoints, we could not find a user
Ok(Response::new(StatusCode::Unauthorized))
}
})
}

//
//
// this is an example of middleware that keeps its own state and could
// be provided as a third party crate
#[derive(Default)]
struct RequestCounterMiddleware {
requests_counted: Arc<AtomicUsize>,
}

impl RequestCounterMiddleware {
fn new(start: usize) -> Self {
Self {
requests_counted: Arc::new(AtomicUsize::new(start)),
}
}
}

struct RequestCount(usize);

impl<State: Send + Sync + 'static> Middleware<State> for RequestCounterMiddleware {
fn handle<'a>(
&'a self,
mut req: Request<State>,
next: Next<'a, State>,
) -> Pin<Box<dyn Future<Output = Result> + Send + 'a>> {
Box::pin(async move {
let count = self.requests_counted.fetch_add(1, Ordering::Relaxed);
tide::log::trace!("request counter", { count: count });
req.set_ext(RequestCount(count));

let mut response = next.run(req).await?;

response = response.set_header("request-number", count.to_string());
Ok(response)
})
}
}

#[async_std::main]
async fn main() -> Result<()> {
tide::log::start();
let mut app = tide::with_state(UserDatabase::default());

app.middleware(user_loader);
app.middleware(RequestCounterMiddleware::new(0));

app.at("/").get(|req: Request<_>| async move {
let count: &RequestCount = req.ext().unwrap();
let user: &User = req.ext().unwrap();

Ok(format!(
"Hello {}, this was request number {}!",
user.name, count.0
))
});

app.listen("127.0.0.1:8080").await?;
Ok(())
}
7 changes: 6 additions & 1 deletion src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@ pub trait Middleware<State>: 'static + Send + Sync {
/// Asynchronously handle the request, and return a response.
fn handle<'a>(
&'a self,
cx: Request<State>,
request: Request<State>,
next: Next<'a, State>,
) -> BoxFuture<'a, crate::Result>;

/// Set the middleware's name. By default it uses the type signature.
fn name(&self) -> &str {
std::any::type_name::<Self>()
}
}

impl<State, F> Middleware<State> for F
Expand Down
6 changes: 2 additions & 4 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ use async_std::prelude::*;
use async_std::sync::Arc;
use async_std::task;

use std::fmt::Debug;

use crate::cookies;
use crate::log;
use crate::middleware::{Middleware, Next};
Expand Down Expand Up @@ -268,9 +266,9 @@ impl<State: Send + Sync + 'static> Server<State> {
/// order in which it is applied.
pub fn middleware<M>(&mut self, middleware: M) -> &mut Self
where
M: Middleware<State> + Debug,
M: Middleware<State>,
{
log::trace!("Adding middleware {:?}", middleware);
log::trace!("Adding middleware {}", middleware.name());
let m = Arc::get_mut(&mut self.middleware)
.expect("Registering middleware is not possible after the Server has started");
m.push(Arc::new(middleware));
Expand Down

0 comments on commit eb872dd

Please sign in to comment.