diff --git a/Cargo.toml b/Cargo.toml index 4341d55931..0c31fd1a1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ members = [ "examples/content_types", "examples/ranking", "examples/testing", + "examples/request_state_local_cache", "examples/request_guard", "examples/stream", "examples/json", diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index 7c0a1986d7..f6eed6ab86 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -1,3 +1,4 @@ +use std::rc::Rc; use std::cell::{Cell, RefCell}; use std::net::{IpAddr, SocketAddr}; use std::fmt; @@ -26,6 +27,7 @@ struct RequestState<'r> { cookies: RefCell, accept: Storage>, content_type: Storage>, + cache: Rc, } /// The type of an incoming web request. @@ -41,7 +43,7 @@ pub struct Request<'r> { uri: Uri<'r>, headers: HeaderMap<'r>, remote: Option, - state: RequestState<'r> + state: RequestState<'r>, } impl<'r> Request<'r> { @@ -67,6 +69,7 @@ impl<'r> Request<'r> { cookies: RefCell::new(CookieJar::new()), accept: Storage::new(), content_type: Storage::new(), + cache: Rc::new(Container::new()), } } } @@ -78,6 +81,39 @@ impl<'r> Request<'r> { f(&mut request); } + /// Retrieves the cached value for type `T` from the request-local cached + /// state of `self`. If no such value has previously been cached for + /// this request, `f` is called to produce the value which is subsequently + /// returned. + /// + /// # Example + /// + /// ```rust + /// # use rocket::http::Method; + /// # use rocket::Request; + /// # struct User; + /// fn current_user() -> User { + /// // Load user... + /// # User + /// } + /// + /// # Request::example(Method::Get, "/uri", |request| { + /// let user = request.local_cache(current_user); + /// # }); + /// ``` + pub fn local_cache(&self, f: F) -> &T + where T: Send + Sync + 'static, + F: FnOnce() -> T { + + match self.state.cache.try_get() { + Some(cached) => cached, + None => { + self.state.cache.set(f()); + self.state.cache.get() + } + } + } + /// Retrieve the method from `self`. /// /// # Example diff --git a/examples/request_state_local_cache/Cargo.toml b/examples/request_state_local_cache/Cargo.toml new file mode 100644 index 0000000000..0c720023b9 --- /dev/null +++ b/examples/request_state_local_cache/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "request_state_local_cache" +version = "0.0.0" +workspace = "../../" +publish = false + +[dependencies] +rocket = { path = "../../core/lib" } +rocket_codegen = { path = "../../core/codegen" } diff --git a/examples/request_state_local_cache/src/main.rs b/examples/request_state_local_cache/src/main.rs new file mode 100644 index 0000000000..c51cefa1c6 --- /dev/null +++ b/examples/request_state_local_cache/src/main.rs @@ -0,0 +1,55 @@ +#![feature(plugin, decl_macro)] +#![plugin(rocket_codegen)] +extern crate rocket; + +use std::sync::atomic::{AtomicUsize, Ordering}; + +use rocket::request::{self, Request, FromRequest, State}; +use rocket::outcome::Outcome::*; + +#[cfg(test)] mod tests; + +#[derive(Default)] +struct Atomics { + uncached: AtomicUsize, + cached: AtomicUsize, +} + +struct Guard1; +struct Guard2; + +impl<'a, 'r> FromRequest<'a, 'r> for Guard1 { + type Error = (); + + fn from_request(req: &'a Request<'r>) -> request::Outcome { + let atomics = req.guard::>()?; + atomics.uncached.fetch_add(1, Ordering::Relaxed); + req.local_cache(|| atomics.cached.fetch_add(1, Ordering::Relaxed)); + + Success(Guard1) + } +} + +impl<'a, 'r> FromRequest<'a, 'r> for Guard2 { + type Error = (); + + fn from_request(req: &'a Request<'r>) -> request::Outcome { + req.guard::()?; + Success(Guard2) + } +} + +#[get("/")] +fn index(_g1: Guard1, _g2: Guard2) { + // This exists only to run the request guards. +} + +fn rocket() -> rocket::Rocket { + rocket::ignite() + .manage(Atomics::default()) + .mount("/", routes!(index)) +} + +fn main() { + rocket().launch(); +} diff --git a/examples/request_state_local_cache/src/tests.rs b/examples/request_state_local_cache/src/tests.rs new file mode 100644 index 0000000000..1e16f7d858 --- /dev/null +++ b/examples/request_state_local_cache/src/tests.rs @@ -0,0 +1,15 @@ +use std::sync::atomic::{Ordering}; + +use ::Atomics; +use super::rocket; +use rocket::local::Client; + +#[test] +fn test() { + let client = Client::new(rocket()).unwrap(); + client.get("/").dispatch(); + + let atomics = client.rocket().state::().unwrap(); + assert_eq!(atomics.uncached.load(Ordering::Relaxed), 2); + assert_eq!(atomics.cached.load(Ordering::Relaxed), 1); +}