From 36a5c42ea50dfdc518a2a6475a3b4c45af05daa4 Mon Sep 17 00:00:00 2001 From: Adam Cattermole Date: Wed, 13 Nov 2024 11:18:48 +0000 Subject: [PATCH 1/2] Add response headers in correct phase Signed-off-by: Adam Cattermole --- src/filter/http_context.rs | 8 +++++++- src/service.rs | 9 ++++++--- src/service/rate_limit.rs | 12 ++++-------- tests/rate_limited.rs | 12 ++++++------ 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index 69a8c0c..f23ccb2 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -142,7 +142,13 @@ impl Context for Filter { match op_res { Ok(operation) => { - if GrpcService::process_grpc_response(operation, resp_size).is_ok() { + if GrpcService::process_grpc_response( + operation, + resp_size, + &mut self.response_headers_to_add, + ) + .is_ok() + { // call the next op match self.operation_dispatcher.borrow_mut().next() { Ok(some_op) => { diff --git a/src/service.rs b/src/service.rs index 575281a..617f4f7 100644 --- a/src/service.rs +++ b/src/service.rs @@ -54,6 +54,7 @@ impl GrpcService { pub fn process_grpc_response( operation: Rc, resp_size: usize, + response_headers_to_add: &mut Vec<(String, String)>, ) -> Result<(), StatusCode> { let failure_mode = operation.get_failure_mode(); if let Some(res_body_bytes) = @@ -62,9 +63,11 @@ impl GrpcService { match GrpcMessageResponse::new(operation.get_service_type(), &res_body_bytes) { Ok(res) => match operation.get_service_type() { ServiceType::Auth => AuthService::process_auth_grpc_response(res, failure_mode), - ServiceType::RateLimit => { - RateLimitService::process_ratelimit_grpc_response(res, failure_mode) - } + ServiceType::RateLimit => RateLimitService::process_ratelimit_grpc_response( + res, + failure_mode, + response_headers_to_add, + ), }, Err(e) => { warn!( diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index 2f97f31..275817a 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -7,7 +7,7 @@ use crate::service::GrpcService; use log::warn; use protobuf::{Message, RepeatedField}; use proxy_wasm::hostcalls; -use proxy_wasm::types::{Bytes, MapType}; +use proxy_wasm::types::Bytes; pub const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; pub const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; @@ -38,6 +38,7 @@ impl RateLimitService { pub fn process_ratelimit_grpc_response( rl_resp: GrpcMessageResponse, failure_mode: FailureMode, + response_headers_to_add: &mut Vec<(String, String)>, ) -> Result<(), StatusCode> { match rl_resp { GrpcMessageResponse::RateLimit(RateLimitResponse { @@ -65,14 +66,9 @@ impl RateLimitService { response_headers_to_add: additional_headers, .. }) => { - // TODO: This should not be sent to the upstream! additional_headers.iter().for_each(|header| { - hostcalls::add_map_value( - MapType::HttpResponseHeaders, - header.get_key(), - header.get_value(), - ) - .unwrap() + response_headers_to_add + .push((header.get_key().to_owned(), header.get_value().to_owned())) }); Ok(()) } diff --git a/tests/rate_limited.rs b/tests/rate_limited.rs index 0c0e72b..b88daa6 100644 --- a/tests/rate_limited.rs +++ b/tests/rate_limited.rs @@ -353,6 +353,12 @@ fn it_passes_additional_headers() { ) .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) .returning(Some(&grpc_response)) + .execute_and_expect(ReturnType::None) + .unwrap(); + + module + .call_proxy_on_response_headers(http_context, 0, false) + .expect_log(Some(LogLevel::Debug), Some("#2 on_http_response_headers")) .expect_add_header_map_value( Some(MapType::HttpResponseHeaders), Some("test"), @@ -363,12 +369,6 @@ fn it_passes_additional_headers() { Some("other"), Some("header value"), ) - .execute_and_expect(ReturnType::None) - .unwrap(); - - module - .call_proxy_on_response_headers(http_context, 0, false) - .expect_log(Some(LogLevel::Debug), Some("#2 on_http_response_headers")) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); } From 9f280ccbaeb77acf88ddea38537615aae7bae6b1 Mon Sep 17 00:00:00 2001 From: Adam Cattermole Date: Wed, 13 Nov 2024 12:32:02 +0000 Subject: [PATCH 2/2] Return the response_headers to add at the filter level Signed-off-by: Adam Cattermole --- src/filter/http_context.rs | 10 +++------- src/service.rs | 25 ++++++++++++++++++------- src/service/auth.rs | 6 +++--- src/service/rate_limit.rs | 17 +++++++++-------- 4 files changed, 33 insertions(+), 25 deletions(-) diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index f23ccb2..e9204d1 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -142,13 +142,9 @@ impl Context for Filter { match op_res { Ok(operation) => { - if GrpcService::process_grpc_response( - operation, - resp_size, - &mut self.response_headers_to_add, - ) - .is_ok() - { + if let Ok(result) = GrpcService::process_grpc_response(operation, resp_size) { + // add the response headers + self.response_headers_to_add.extend(result.response_headers); // call the next op match self.operation_dispatcher.borrow_mut().next() { Ok(some_op) => { diff --git a/src/service.rs b/src/service.rs index 617f4f7..404671f 100644 --- a/src/service.rs +++ b/src/service.rs @@ -54,8 +54,7 @@ impl GrpcService { pub fn process_grpc_response( operation: Rc, resp_size: usize, - response_headers_to_add: &mut Vec<(String, String)>, - ) -> Result<(), StatusCode> { + ) -> Result { let failure_mode = operation.get_failure_mode(); if let Some(res_body_bytes) = hostcalls::get_buffer(BufferType::GrpcReceiveBuffer, 0, resp_size).unwrap() @@ -63,11 +62,9 @@ impl GrpcService { match GrpcMessageResponse::new(operation.get_service_type(), &res_body_bytes) { Ok(res) => match operation.get_service_type() { ServiceType::Auth => AuthService::process_auth_grpc_response(res, failure_mode), - ServiceType::RateLimit => RateLimitService::process_ratelimit_grpc_response( - res, - failure_mode, - response_headers_to_add, - ), + ServiceType::RateLimit => { + RateLimitService::process_ratelimit_grpc_response(res, failure_mode) + } }, Err(e) => { warn!( @@ -95,6 +92,20 @@ impl GrpcService { } } +pub struct GrpcResult { + pub response_headers: Vec<(String, String)>, +} +impl GrpcResult { + pub fn default() -> Self { + Self { + response_headers: Vec::new(), + } + } + pub fn new(response_headers: Vec<(String, String)>) -> Self { + Self { response_headers } + } +} + pub type GrpcCallFn = fn( upstream_name: &str, service_name: &str, diff --git a/src/service/auth.rs b/src/service/auth.rs index a1e4e96..33af068 100644 --- a/src/service/auth.rs +++ b/src/service/auth.rs @@ -6,7 +6,7 @@ use crate::envoy::{ SocketAddress, StatusCode, }; use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult}; -use crate::service::GrpcService; +use crate::service::{GrpcResult, GrpcService}; use chrono::{DateTime, FixedOffset}; use log::{debug, warn}; use protobuf::well_known_types::Timestamp; @@ -125,7 +125,7 @@ impl AuthService { pub fn process_auth_grpc_response( auth_resp: GrpcMessageResponse, failure_mode: FailureMode, - ) -> Result<(), StatusCode> { + ) -> Result { if let GrpcMessageResponse::Auth(check_response) = auth_resp { // store dynamic metadata in filter state store_metadata(check_response.get_dynamic_metadata()); @@ -153,7 +153,7 @@ impl AuthService { ) .unwrap() }); - Ok(()) + Ok(GrpcResult::default()) } Some(CheckResponse_oneof_http_response::denied_response(denied_response)) => { debug!("process_auth_grpc_response: received DeniedHttpResponse"); diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index 275817a..4d8f242 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -3,7 +3,7 @@ use crate::envoy::{ RateLimitDescriptor, RateLimitRequest, RateLimitResponse, RateLimitResponse_Code, StatusCode, }; use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult}; -use crate::service::GrpcService; +use crate::service::{GrpcResult, GrpcService}; use log::warn; use protobuf::{Message, RepeatedField}; use proxy_wasm::hostcalls; @@ -38,8 +38,7 @@ impl RateLimitService { pub fn process_ratelimit_grpc_response( rl_resp: GrpcMessageResponse, failure_mode: FailureMode, - response_headers_to_add: &mut Vec<(String, String)>, - ) -> Result<(), StatusCode> { + ) -> Result { match rl_resp { GrpcMessageResponse::RateLimit(RateLimitResponse { overall_code: RateLimitResponse_Code::UNKNOWN, @@ -66,11 +65,13 @@ impl RateLimitService { response_headers_to_add: additional_headers, .. }) => { - additional_headers.iter().for_each(|header| { - response_headers_to_add - .push((header.get_key().to_owned(), header.get_value().to_owned())) - }); - Ok(()) + let result = GrpcResult::new( + additional_headers + .iter() + .map(|header| (header.get_key().to_owned(), header.get_value().to_owned())) + .collect(), + ); + Ok(result) } _ => { warn!("not a valid GrpcMessageResponse::RateLimit(RateLimitResponse)!");