diff --git a/examples/examples/tower.rs b/examples/examples/tower.rs index e789393..33b49d5 100644 --- a/examples/examples/tower.rs +++ b/examples/examples/tower.rs @@ -1,4 +1,6 @@ use hitbox_stretto::StrettoBackend; +use hitbox_redis::RedisBackend; +use hitbox_stretto::builder::StrettoBackendBuilder; use hitbox_tower::Cache; use hyper::{Body, Server}; use std::{convert::Infallible, net::SocketAddr}; @@ -18,10 +20,12 @@ async fn main() { .finish(); tracing::subscriber::set_global_default(subscriber).unwrap(); - let inmemory = StrettoBackend::builder(2 ^ 16).finalize().unwrap(); + let inmemory = StrettoBackend::builder(10_000_000).finalize().unwrap(); + let redis = RedisBackend::builder().build().unwrap(); + let service = tower::ServiceBuilder::new() - .layer(tower_http::trace::TraceLayer::new_for_http()) .layer(Cache::builder().backend(inmemory).build()) + .layer(Cache::builder().backend(redis).build()) .service_fn(handle); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); diff --git a/hitbox-http/Cargo.toml b/hitbox-http/Cargo.toml index e0a8909..0f5c67f 100644 --- a/hitbox-http/Cargo.toml +++ b/hitbox-http/Cargo.toml @@ -19,13 +19,13 @@ http = "0.2" http-body = "0.4" hitbox = { path = "../hitbox", version = "0.1" } hitbox-backend = { path = "../hitbox-backend", version = "0.1" } -serde = "1.0.144" bytes = "1" chrono = "0.4" hyper = { version = "0.14", features = ["stream"] } futures = { version = "0.3", default-features = false } actix-router = "0.5" serde_qs = "0.12" +serde = { version = "1", features = ["derive"] } [dev-dependencies] tokio = { version = "1", features = ["test-util"], default-features = false } diff --git a/hitbox-http/src/extractors/header.rs b/hitbox-http/src/extractors/header.rs new file mode 100644 index 0000000..1ed87ce --- /dev/null +++ b/hitbox-http/src/extractors/header.rs @@ -0,0 +1,50 @@ +use async_trait::async_trait; +use hitbox::cache::{Extractor, KeyPart, KeyParts}; +use http::HeaderValue; + +use crate::CacheableHttpRequest; + +pub struct Header { + inner: E, + name: String, +} + +pub trait HeaderExtractor: Sized { + fn header(self, name: String) -> Header; +} + +impl HeaderExtractor for E +where + E: Extractor, +{ + fn header(self, name: String) -> Header { + Header { inner: self, name } + } +} + +#[async_trait] +impl Extractor for Header +where + ReqBody: Send + 'static, + E: Extractor> + Send + Sync, +{ + type Subject = E::Subject; + + async fn get(&self, subject: Self::Subject) -> KeyParts { + let value = subject + .parts() + .headers + .get(self.name.as_str()) + .map(HeaderValue::to_str) + .transpose() + .ok() + .flatten() + .map(str::to_string); + let mut parts = self.inner.get(subject).await; + parts.push(KeyPart { + key: self.name.clone(), + value, + }); + parts + } +} diff --git a/hitbox-http/src/extractors/method.rs b/hitbox-http/src/extractors/method.rs new file mode 100644 index 0000000..ac17416 --- /dev/null +++ b/hitbox-http/src/extractors/method.rs @@ -0,0 +1,41 @@ +use async_trait::async_trait; +use hitbox::cache::{Extractor, KeyPart, KeyParts}; +use http::HeaderValue; + +use crate::CacheableHttpRequest; + +pub struct Method { + inner: E, +} + +pub trait MethodExtractor: Sized { + fn method(self) -> Method; +} + +impl MethodExtractor for E +where + E: Extractor, +{ + fn method(self) -> Method { + Method { inner: self } + } +} + +#[async_trait] +impl Extractor for Method +where + ReqBody: Send + 'static, + E: Extractor> + Send + Sync, +{ + type Subject = E::Subject; + + async fn get(&self, subject: Self::Subject) -> KeyParts { + let method = subject.parts().method.to_string(); + let mut parts = self.inner.get(subject).await; + parts.push(KeyPart { + key: "method".to_owned(), + value: Some(method), + }); + parts + } +} diff --git a/hitbox-http/src/extractors/mod.rs b/hitbox-http/src/extractors/mod.rs new file mode 100644 index 0000000..b05f2ac --- /dev/null +++ b/hitbox-http/src/extractors/mod.rs @@ -0,0 +1,36 @@ +use std::marker::PhantomData; + +use async_trait::async_trait; +use hitbox::cache::{Extractor, KeyPart, KeyParts}; + +use crate::CacheableHttpRequest; + +pub mod header; +pub mod method; +pub mod path; +pub mod query; + +pub struct NeutralExtractor { + _res: PhantomData ReqBody>, +} + +impl NeutralExtractor { + pub fn new() -> Self { + NeutralExtractor { _res: PhantomData } + } +} + +#[async_trait] +impl Extractor for NeutralExtractor +where + ResBody: Send + 'static, +{ + type Subject = CacheableHttpRequest; + + async fn get(&self, subject: Self::Subject) -> KeyParts { + KeyParts { + subject, + parts: Vec::new(), + } + } +} diff --git a/hitbox-http/src/extractors/path.rs b/hitbox-http/src/extractors/path.rs new file mode 100644 index 0000000..bc8ac76 --- /dev/null +++ b/hitbox-http/src/extractors/path.rs @@ -0,0 +1,51 @@ +use actix_router::{ResourceDef, ResourcePath}; +use async_trait::async_trait; +use hitbox::cache::{CacheableRequest, Extractor, KeyPart, KeyParts}; +use http::HeaderValue; + +use crate::CacheableHttpRequest; + +pub struct Path { + inner: E, + resource: ResourceDef, +} + +pub trait PathExtractor: Sized { + fn path(self, resource: &str) -> Path; +} + +impl PathExtractor for E +where + E: Extractor, +{ + fn path(self, resource: &str) -> Path { + Path { + inner: self, + resource: ResourceDef::try_from(resource).unwrap(), + } + } +} + +#[async_trait] +impl Extractor for Path +where + ReqBody: Send + 'static, + E: Extractor> + Send + Sync, +{ + type Subject = E::Subject; + + async fn get(&self, subject: Self::Subject) -> KeyParts { + let mut path = actix_router::Path::new(subject.parts().uri.path()); + self.resource.capture_match_info(&mut path); + let mut matched_parts = path + .iter() + .map(|(key, value)| KeyPart { + key: key.to_owned(), + value: Some(value.to_owned()), + }) + .collect::>(); + let mut parts = self.inner.get(subject).await; + parts.append(&mut matched_parts); + parts + } +} diff --git a/hitbox-http/src/extractors/query.rs b/hitbox-http/src/extractors/query.rs new file mode 100644 index 0000000..2d2ae0f --- /dev/null +++ b/hitbox-http/src/extractors/query.rs @@ -0,0 +1,50 @@ +use async_trait::async_trait; +use hitbox::cache::{Extractor, KeyPart, KeyParts}; + +use crate::CacheableHttpRequest; + +pub struct Query { + inner: E, + name: String, +} + +pub trait QueryExtractor: Sized { + fn query(self, name: String) -> Query; +} + +impl QueryExtractor for E +where + E: Extractor, +{ + fn query(self, name: String) -> Query { + Query { inner: self, name } + } +} + +#[async_trait] +impl Extractor for Query +where + ReqBody: Send + 'static, + E: Extractor> + Send + Sync, +{ + type Subject = E::Subject; + + async fn get(&self, subject: Self::Subject) -> KeyParts { + let values = subject + .parts() + .uri + .query() + .map(crate::query::parse) + .map(|m| m.get(&self.name).map(crate::query::Value::inner)) + .flatten() + .unwrap_or_default(); + let mut parts = self.inner.get(subject).await; + for value in values { + parts.push(KeyPart { + key: self.name.clone(), + value: Some(value), + }); + } + parts + } +} diff --git a/hitbox-http/src/lib.rs b/hitbox-http/src/lib.rs index a0612fa..3403d5c 100644 --- a/hitbox-http/src/lib.rs +++ b/hitbox-http/src/lib.rs @@ -1,5 +1,7 @@ mod body; +pub mod extractors; pub mod predicates; +mod query; mod request; mod response; diff --git a/hitbox-http/src/predicates/mod.rs b/hitbox-http/src/predicates/mod.rs index e7e09ac..51634fa 100644 --- a/hitbox-http/src/predicates/mod.rs +++ b/hitbox-http/src/predicates/mod.rs @@ -36,7 +36,7 @@ where } pub struct NeutralResponsePredicate { - _res: PhantomData ResBody>, // FIX: HEHE + _res: PhantomData ResBody>, // FIX: HEHE } impl NeutralResponsePredicate { diff --git a/hitbox-http/src/predicates/query.rs b/hitbox-http/src/predicates/query.rs index ee6c066..73b91e4 100644 --- a/hitbox-http/src/predicates/query.rs +++ b/hitbox-http/src/predicates/query.rs @@ -1,19 +1,10 @@ use crate::CacheableHttpRequest; use async_trait::async_trait; use hitbox::predicates::{Operation, Predicate, PredicateResult}; -use serde::Deserialize; -use std::{collections::HashMap, marker::PhantomData}; - -#[derive(Deserialize, PartialEq, Eq)] -#[serde(untagged)] -pub enum QsValue { - Scalar(String), - Array(Vec), -} pub struct Query

{ pub name: String, - pub value: QsValue, + pub value: crate::query::Value, pub operation: Operation, inner: P, } @@ -29,17 +20,13 @@ where fn query(self, name: String, value: String) -> Query

{ Query { name, - value: QsValue::Scalar(value), + value: crate::query::Value::Scalar(value), operation: Operation::Eq, inner: self, } } } -fn parse_query(value: &str) -> HashMap { - serde_qs::from_str(value).unwrap() -} - #[async_trait] impl Predicate for Query

where @@ -52,11 +39,11 @@ where match self.inner.check(request).await { PredicateResult::Cacheable(request) => { let op = match self.operation { - Operation::Eq => QsValue::eq, + Operation::Eq => crate::query::Value::eq, Operation::In => unimplemented!(), }; match request.parts().uri.query() { - Some(query_string) => match parse_query(query_string).get(&self.name) { + Some(query_string) => match crate::query::parse(query_string).get(&self.name) { Some(value) if op(value, &self.value) => { PredicateResult::Cacheable(request) } diff --git a/hitbox-http/src/query.rs b/hitbox-http/src/query.rs new file mode 100644 index 0000000..857700e --- /dev/null +++ b/hitbox-http/src/query.rs @@ -0,0 +1,51 @@ +use serde::Deserialize; +use std::collections::HashMap; + +#[derive(Debug, Deserialize, PartialEq, Eq)] +#[serde(untagged)] +pub enum Value { + Scalar(String), + Array(Vec), +} + +impl Value { + pub fn inner(&self) -> Vec { + match self { + Value::Scalar(value) => vec![value.to_owned()], + Value::Array(values) => values.to_owned(), + } + } +} + +pub fn parse(value: &str) -> HashMap { + serde_qs::from_str(value).expect("Unreachable branch reached") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_valid_one() { + let hash_map = parse("key=value"); + let value = hash_map.get("key").unwrap(); + assert_eq!(value.inner(), vec!["value"]); + } + + #[test] + fn test_parse_valid_multiple() { + let hash_map = parse("key-one=value-one&key-two=value-two&key-three=value-three"); + let value = hash_map.get("key-one").unwrap(); + assert_eq!(value.inner(), vec!["value-one"]); + let value = hash_map.get("key-two").unwrap(); + assert_eq!(value.inner(), vec!["value-two"]); + let value = hash_map.get("key-three").unwrap(); + assert_eq!(value.inner(), vec!["value-three"]); + } + + #[test] + fn test_parse_not_valid() { + let hash_map = parse(" wrong "); + assert_eq!(hash_map.len(), 1); + } +} diff --git a/hitbox-http/src/request.rs b/hitbox-http/src/request.rs index 0c447a0..28ca3d3 100644 --- a/hitbox-http/src/request.rs +++ b/hitbox-http/src/request.rs @@ -1,7 +1,7 @@ use async_trait::async_trait; use futures::{stream, StreamExt}; use hitbox::{ - cache::{CacheKey, CachePolicy, CacheableRequest, Selector}, + cache::{CacheKey, CachePolicy, CacheableRequest, Extractor}, predicates::{Predicate, PredicateResult}, Cacheable, }; @@ -41,13 +41,20 @@ impl CacheableRequest for CacheableHttpRequest where ReqBody: Send + 'static, { - async fn cache_policy

(self, predicates: P) -> hitbox::cache::CachePolicy + async fn cache_policy( + self, + predicates: P, + extractors: E, + ) -> hitbox::cache::CachePolicy where P: Predicate + Send + Sync, + E: Extractor + Send + Sync, { dbg!("CacheableHttpRequest::cache_policy"); - match predicates.check(self).await { - PredicateResult::Cacheable(request) => CachePolicy::Cacheable(request), + let (request, key) = extractors.get(self).await.into_cache_key(); + + match predicates.check(request).await { + PredicateResult::Cacheable(request) => CachePolicy::Cacheable { key, request }, PredicateResult::NonCacheable(request) => CachePolicy::NonCacheable(request), } } diff --git a/hitbox-http/tests/cache_policy/mod.rs b/hitbox-http/tests/cache_policy/mod.rs index 335001c..bd0c0d5 100644 --- a/hitbox-http/tests/cache_policy/mod.rs +++ b/hitbox-http/tests/cache_policy/mod.rs @@ -1 +1 @@ -mod request; +//mod request; diff --git a/hitbox-http/tests/extractors/header.rs b/hitbox-http/tests/extractors/header.rs new file mode 100644 index 0000000..27ad168 --- /dev/null +++ b/hitbox-http/tests/extractors/header.rs @@ -0,0 +1,21 @@ +use hitbox::cache::Extractor; +use hitbox_http::extractors::{header::HeaderExtractor, NeutralExtractor}; +use hitbox_http::CacheableHttpRequest; +use http::Request; +use hyper::Body; + +#[tokio::test] +async fn test_request_header_extractor_some() { + let request = Request::builder() + .header("x-test", "test-value") + .body(Body::empty()) + .unwrap(); + let request = CacheableHttpRequest::from_request(request); + let extractor = NeutralExtractor::new().header("x-test".to_owned()); + let parts = extractor.get(request).await; + dbg!(parts); + // assert!(matches!( + // prediction, + // hitbox::predicates::PredicateResult::Cacheable(_) + // )); +} diff --git a/hitbox-http/tests/extractors/method.rs b/hitbox-http/tests/extractors/method.rs new file mode 100644 index 0000000..ebe18fe --- /dev/null +++ b/hitbox-http/tests/extractors/method.rs @@ -0,0 +1,18 @@ +use hitbox::cache::Extractor; +use hitbox_http::extractors::{method::MethodExtractor, NeutralExtractor}; +use hitbox_http::CacheableHttpRequest; +use http::{Method, Request}; +use hyper::Body; + +#[tokio::test] +async fn test_request_method_extractor_some() { + let request = Request::builder() + .uri("/users/42/books/24/") + .method(Method::POST) + .body(Body::empty()) + .unwrap(); + let request = CacheableHttpRequest::from_request(request); + let extractor = NeutralExtractor::new().method(); + let parts = extractor.get(request).await; + dbg!(parts); +} diff --git a/hitbox-http/tests/extractors/mod.rs b/hitbox-http/tests/extractors/mod.rs new file mode 100644 index 0000000..763857c --- /dev/null +++ b/hitbox-http/tests/extractors/mod.rs @@ -0,0 +1,5 @@ +mod header; +mod method; +mod multiple; +mod path; +mod query; diff --git a/hitbox-http/tests/extractors/multiple.rs b/hitbox-http/tests/extractors/multiple.rs new file mode 100644 index 0000000..1e9679c --- /dev/null +++ b/hitbox-http/tests/extractors/multiple.rs @@ -0,0 +1,24 @@ +use hitbox::cache::Extractor; +use hitbox_http::extractors::{ + header::HeaderExtractor, method::MethodExtractor, path::PathExtractor, NeutralExtractor, +}; +use hitbox_http::CacheableHttpRequest; +use http::{Method, Request}; +use hyper::Body; + +#[tokio::test] +async fn test_request_multiple_extractor_some() { + let request = Request::builder() + .uri("/users/42/books/24/") + .method(Method::PUT) + .header("X-test", "x-test-value") + .body(Body::empty()) + .unwrap(); + let request = CacheableHttpRequest::from_request(request); + let extractor = NeutralExtractor::new() + .path("/users/{user_id}/books/{book_id}/") + .method() + .header("x-test".to_owned()); + let parts = extractor.get(request).await; + dbg!(parts); +} diff --git a/hitbox-http/tests/extractors/path.rs b/hitbox-http/tests/extractors/path.rs new file mode 100644 index 0000000..5d3d874 --- /dev/null +++ b/hitbox-http/tests/extractors/path.rs @@ -0,0 +1,17 @@ +use hitbox::cache::Extractor; +use hitbox_http::extractors::{path::PathExtractor, NeutralExtractor}; +use hitbox_http::CacheableHttpRequest; +use http::Request; +use hyper::Body; + +#[tokio::test] +async fn test_request_path_extractor_some() { + let request = Request::builder() + .uri("/users/42/books/24/") + .body(Body::empty()) + .unwrap(); + let request = CacheableHttpRequest::from_request(request); + let extractor = NeutralExtractor::new().path("/users/{user_id}/books/{book_id}/"); + let parts = extractor.get(request).await; + dbg!(parts); +} diff --git a/hitbox-http/tests/extractors/query.rs b/hitbox-http/tests/extractors/query.rs new file mode 100644 index 0000000..d1d9c23 --- /dev/null +++ b/hitbox-http/tests/extractors/query.rs @@ -0,0 +1,44 @@ +use hitbox::cache::Extractor; +use hitbox_http::extractors::{query::QueryExtractor, NeutralExtractor}; +use hitbox_http::CacheableHttpRequest; +use http::Request; +use hyper::Body; + +#[tokio::test] +async fn test_request_query_extractor_some() { + let uri = http::uri::Uri::builder() + .path_and_query("test-path?key=value") + .build() + .unwrap(); + let request = Request::builder().uri(uri).body(Body::empty()).unwrap(); + let request = CacheableHttpRequest::from_request(request); + let extractor = NeutralExtractor::new().query("key".to_owned()); + let parts = extractor.get(request).await; + dbg!(parts); +} + +#[tokio::test] +async fn test_request_query_extractor_none() { + let uri = http::uri::Uri::builder() + .path_and_query("test-path?key=value") + .build() + .unwrap(); + let request = Request::builder().uri(uri).body(Body::empty()).unwrap(); + let request = CacheableHttpRequest::from_request(request); + let extractor = NeutralExtractor::new().query("non-existent-key".to_owned()); + let parts = extractor.get(request).await; + dbg!(parts); +} + +#[tokio::test] +async fn test_request_query_extractor_multiple() { + let uri = http::uri::Uri::builder() + .path_and_query("test-path?cars[]=Saab&cars[]=Audi") + .build() + .unwrap(); + let request = Request::builder().uri(uri).body(Body::empty()).unwrap(); + let request = CacheableHttpRequest::from_request(request); + let extractor = NeutralExtractor::new().query("cars".to_owned()); + let parts = extractor.get(request).await; + dbg!(parts); +} diff --git a/hitbox-http/tests/mod.rs b/hitbox-http/tests/mod.rs index ab514a3..20be204 100644 --- a/hitbox-http/tests/mod.rs +++ b/hitbox-http/tests/mod.rs @@ -1,2 +1,3 @@ mod cache_policy; +mod extractors; mod predicates; diff --git a/hitbox-http/tests/predicates/request/query.rs b/hitbox-http/tests/predicates/request/query.rs index f6c877f..e516752 100644 --- a/hitbox-http/tests/predicates/request/query.rs +++ b/hitbox-http/tests/predicates/request/query.rs @@ -1,5 +1,5 @@ -use hitbox::predicates::{Operation, Predicate}; -use hitbox_http::predicates::query::{QsValue, QueryPredicate}; +use hitbox::predicates::Predicate; +use hitbox_http::predicates::query::QueryPredicate; use hitbox_http::predicates::NeutralPredicate; use hitbox_http::CacheableHttpRequest; use http::Request; diff --git a/hitbox-tower/src/service.rs b/hitbox-tower/src/service.rs index 9941f29..89e9d87 100644 --- a/hitbox-tower/src/service.rs +++ b/hitbox-tower/src/service.rs @@ -13,6 +13,8 @@ use hitbox::{ }; use hitbox_backend::CacheableResponse; use hitbox_http::{ + extractors::NeutralExtractor, + extractors::{method::MethodExtractor, path::PathExtractor}, predicates::{query::QueryPredicate, NeutralPredicate, NeutralResponsePredicate}, CacheableHttpRequest, CacheableHttpResponse, FromBytes, SerializableHttpResponse, }; @@ -87,8 +89,11 @@ where self.backend.clone(), CacheableHttpRequest::from_request(req), transformer, - Arc::new(NeutralPredicate::new().query("cache".to_owned(), "true".to_owned())), + Arc::new(Box::new( + NeutralPredicate::new().query("cache".to_owned(), "true".to_owned()), + )), Arc::new(NeutralResponsePredicate::new()), + Arc::new(NeutralExtractor::new().method().path("/{path}*")), ) } } diff --git a/hitbox/src/cache.rs b/hitbox/src/cache.rs index c8addb8..ed74239 100644 --- a/hitbox/src/cache.rs +++ b/hitbox/src/cache.rs @@ -1,5 +1,7 @@ //! Cacheable trait and implementation of cache logic. +use std::sync::Arc; + use crate::{predicates::Predicate, CacheError}; use async_trait::async_trait; #[cfg(feature = "derive")] @@ -79,24 +81,96 @@ pub trait Cacheable { #[derive(Debug)] pub enum CachePolicy { - Cacheable(T), + Cacheable { key: CacheKey, request: T }, NonCacheable(T), } +#[derive(Debug)] pub struct CacheKey { - pub key: String, + pub parts: Vec, pub version: u32, pub prefix: String, } -pub struct SelectorPart(T, String); +impl CacheKey { + pub fn serialize(&self) -> String { + self.parts + .iter() + .map(|part| { + format!( + "{}:{}", + part.key, + part.value.clone().unwrap_or("None".to_owned()) + ) + }) + .collect::>() + .join("::") + } +} + +#[derive(Debug)] +pub struct KeyPart { + pub key: String, + pub value: Option, +} + +#[derive(Debug)] +pub struct KeyParts { + pub subject: T, + pub parts: Vec, +} + +impl KeyParts { + pub fn push(&mut self, part: KeyPart) { + self.parts.push(part) + } + + pub fn append(&mut self, parts: &mut Vec) { + self.parts.append(parts) + } + + pub fn into_cache_key(self) -> (T, CacheKey) { + ( + self.subject, + CacheKey { + version: 0, + prefix: String::new(), + parts: self.parts, + }, + ) + } +} + +#[async_trait] +pub trait Extractor { + type Subject; + async fn get(&self, subject: Self::Subject) -> KeyParts; +} #[async_trait] -pub trait Selector +impl Extractor for Box where - Self: Sized, + T: Extractor + ?Sized + Sync, + T::Subject: Send, +{ + type Subject = T::Subject; + + async fn get(&self, subject: T::Subject) -> KeyParts { + self.as_ref().get(subject).await + } +} + +#[async_trait] +impl Extractor for Arc +where + T: Extractor + Send + Sync + ?Sized, + T::Subject: Send, { - async fn part(&self, subject: Self) -> SelectorPart; + type Subject = T::Subject; + + async fn get(&self, subject: T::Subject) -> KeyParts { + self.as_ref().get(subject).await + } } #[async_trait] @@ -104,13 +178,10 @@ pub trait CacheableRequest where Self: Sized, { - async fn cache_policy

( - self, - predicates: P, - // key_selectors: impl Selector, - ) -> CachePolicy + async fn cache_policy(self, predicates: P, extractors: E) -> CachePolicy where - P: Predicate + Send + Sync; + P: Predicate + Send + Sync, + E: Extractor + Send + Sync; } // #[cfg(test)] diff --git a/hitbox/src/fsm/future.rs b/hitbox/src/fsm/future.rs index a4b9e52..6dbe1c6 100644 --- a/hitbox/src/fsm/future.rs +++ b/hitbox/src/fsm/future.rs @@ -15,7 +15,7 @@ use tracing::{instrument, trace, warn}; use crate::{ backend::CacheBackend, - cache::CacheableRequest, + cache::{CacheKey, CacheableRequest, Extractor}, fsm::{states::StateProj, PollCache, State}, predicates::Predicate, Cacheable, @@ -236,12 +236,14 @@ where transformer: T, backend: Arc, request: Option, + cache_key: Option, #[pin] state: State<::Output, Res, Req>, #[pin] poll_cache: Option>, request_predicates: Arc + Send + Sync>, response_predicates: Arc + Send + Sync>, + key_extractors: Arc + Send + Sync>, } impl CacheFuture @@ -258,15 +260,18 @@ where transformer: T, request_predicates: Arc + Send + Sync>, response_predicates: Arc + Send + Sync>, + key_extractors: Arc + Send + Sync>, ) -> Self { CacheFuture { transformer, backend, + cache_key: None, request: Some(request), state: State::Initial, poll_cache: None, request_predicates, response_predicates, + key_extractors, } } } @@ -286,7 +291,7 @@ where { type Output = T::Response; - #[instrument(skip(self, cx), fields(state = ?self.state, request = type_name::(), backend = type_name::()))] + // #[instrument(skip(self, cx), fields(state = ?self.state, request = type_name::(), backend = type_name::()))] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.project(); @@ -295,8 +300,9 @@ where StateProj::Initial => { let request = this.request.take().expect(POLL_AFTER_READY_ERROR); let predicates = this.request_predicates.clone(); + let extractors = this.key_extractors.clone(); let cache_policy_future = - Box::pin(async move { request.cache_policy(predicates).await }); + Box::pin(async move { request.cache_policy(predicates, extractors).await }); State::CheckRequestCachePolicy { cache_policy_future, } @@ -307,9 +313,10 @@ where let policy = ready!(cache_policy_future.poll(cx)); trace!("{policy:?}"); match policy { - crate::cache::CachePolicy::Cacheable(request) => { + crate::cache::CachePolicy::Cacheable { key, request } => { let backend = this.backend.clone(); - let cache_key = "fake::key".to_owned(); + let cache_key = key.serialize(); + this.cache_key.insert(key); let poll_cache = Box::pin(async move { backend.get::(cache_key).await }); State::PollCache { @@ -373,12 +380,13 @@ where StateProj::CheckResponseCachePolicy { cache_policy } => { let policy = ready!(cache_policy.poll(cx)); let backend = this.backend.clone(); - let cache_key = "fake::key".to_owned(); + let cache_key = this.cache_key.take().expect("CacheKey not found"); match policy { CachePolicy::Cacheable(cache_value) => { let update_cache_future = Box::pin(async move { - let update_cache_result = - backend.set::(cache_key, &cache_value, None).await; + let update_cache_result = backend + .set::(cache_key.serialize(), &cache_value, None) + .await; let upstream_result = Res::from_cached(cache_value.into_inner()).await; (update_cache_result, upstream_result) @@ -408,6 +416,7 @@ where return Poll::Ready(response); } }; + dbg!(&state); this.state.set(state); } }