From 36b1378b88287ad8d67371fae9bbb096d650eaf8 Mon Sep 17 00:00:00 2001 From: Matt Klein Date: Mon, 5 Sep 2016 15:10:11 -0700 Subject: [PATCH] grpc: allow request header customization before dispatch --- include/envoy/grpc/rpc_channel.h | 7 +++++++ source/common/grpc/rpc_channel_impl.cc | 1 + source/common/ratelimit/ratelimit_impl.h | 1 + test/common/grpc/rpc_channel_impl_test.cc | 18 +++++++++++++++++- test/mocks/grpc/mocks.h | 1 + 5 files changed, 27 insertions(+), 1 deletion(-) diff --git a/include/envoy/grpc/rpc_channel.h b/include/envoy/grpc/rpc_channel.h index 8be4d7ca2fba..ef6a2b269f5a 100644 --- a/include/envoy/grpc/rpc_channel.h +++ b/include/envoy/grpc/rpc_channel.h @@ -2,6 +2,7 @@ #include "envoy/common/optional.h" #include "envoy/common/pure.h" +#include "envoy/http/header_map.h" #include "google/protobuf/service.h" @@ -16,6 +17,12 @@ class RpcChannelCallbacks { public: virtual ~RpcChannelCallbacks() {} + /** + * Called before the channel dispatches an HTTP/2 request. This can be used to customize the + * transport headers for the RPC. + */ + virtual void onPreRequestCustomizeHeaders(Http::HeaderMap& headers) PURE; + /** * Called when the request has succeeded and the response object is populated. */ diff --git a/source/common/grpc/rpc_channel_impl.cc b/source/common/grpc/rpc_channel_impl.cc index 09816ab5737c..a3bf4feca153 100644 --- a/source/common/grpc/rpc_channel_impl.cc +++ b/source/common/grpc/rpc_channel_impl.cc @@ -40,6 +40,7 @@ void RpcChannelImpl::CallMethod(const proto::MethodDescriptor* method, proto::Rp message->headers().addViaCopy(Http::Headers::get().ContentType, Common::GRPC_CONTENT_TYPE); message->body(serializeBody(*grpc_request)); + callbacks_.onPreRequestCustomizeHeaders(message->headers()); http_request_ = cm_.httpAsyncClientForCluster(cluster_).send(std::move(message), *this, timeout_); } diff --git a/source/common/ratelimit/ratelimit_impl.h b/source/common/ratelimit/ratelimit_impl.h index 413dec6a6678..2f565de37984 100644 --- a/source/common/ratelimit/ratelimit_impl.h +++ b/source/common/ratelimit/ratelimit_impl.h @@ -25,6 +25,7 @@ class GrpcClientImpl : public Client, public Grpc::RpcChannelCallbacks { private: // Grpc::RpcChannelCallbacks + void onPreRequestCustomizeHeaders(Http::HeaderMap&) override {} void onSuccess() override; void onFailure(const Optional& grpc_status, const std::string& message) override; diff --git a/test/common/grpc/rpc_channel_impl_test.cc b/test/common/grpc/rpc_channel_impl_test.cc index ed7744e85168..ab7c5fc8516d 100644 --- a/test/common/grpc/rpc_channel_impl_test.cc +++ b/test/common/grpc/rpc_channel_impl_test.cc @@ -44,13 +44,16 @@ TEST_F(GrpcRequestImplTest, NoError) { helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; + EXPECT_CALL(grpc_callbacks_, onPreRequestCustomizeHeaders(_)) + .WillOnce(Invoke([](Http::HeaderMap& headers) -> void { headers.addViaCopy("foo", "bar"); })); service_.SayHello(nullptr, &request, &response, nullptr); Http::HeaderMapImpl expected_request_headers{{":scheme", "http"}, {":method", "POST"}, {":path", "/helloworld.Greeter/SayHello"}, {":authority", "cluster"}, - {"content-type", "application/grpc"}}; + {"content-type", "application/grpc"}, + {"foo", "bar"}}; EXPECT_THAT(http_request_->headers(), HeaderMapEqualRef(expected_request_headers)); @@ -77,6 +80,7 @@ TEST_F(GrpcRequestImplTest, Non200Response) { helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; + EXPECT_CALL(grpc_callbacks_, onPreRequestCustomizeHeaders(_)); service_.SayHello(nullptr, &request, &response, nullptr); Http::MessagePtr response_http_message(new Http::ResponseMessageImpl( @@ -96,6 +100,7 @@ TEST_F(GrpcRequestImplTest, NoResponseTrailers) { helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; + EXPECT_CALL(grpc_callbacks_, onPreRequestCustomizeHeaders(_)); service_.SayHello(nullptr, &request, &response, nullptr); Http::MessagePtr response_http_message(new Http::ResponseMessageImpl( @@ -111,6 +116,7 @@ TEST_F(GrpcRequestImplTest, BadGrpcStatusInHeaderOnlyResponse) { helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; + EXPECT_CALL(grpc_callbacks_, onPreRequestCustomizeHeaders(_)); service_.SayHello(nullptr, &request, &response, nullptr); Http::MessagePtr response_http_message(new Http::ResponseMessageImpl( @@ -126,6 +132,7 @@ TEST_F(GrpcRequestImplTest, HeaderOnlyFailure) { helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; + EXPECT_CALL(grpc_callbacks_, onPreRequestCustomizeHeaders(_)); service_.SayHello(nullptr, &request, &response, nullptr); Http::MessagePtr response_http_message( @@ -142,6 +149,7 @@ TEST_F(GrpcRequestImplTest, BadGrpcStatusInResponse) { helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; + EXPECT_CALL(grpc_callbacks_, onPreRequestCustomizeHeaders(_)); service_.SayHello(nullptr, &request, &response, nullptr); Http::MessagePtr response_http_message(new Http::ResponseMessageImpl( @@ -158,6 +166,7 @@ TEST_F(GrpcRequestImplTest, GrpcStatusNonZeroInResponse) { helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; + EXPECT_CALL(grpc_callbacks_, onPreRequestCustomizeHeaders(_)); service_.SayHello(nullptr, &request, &response, nullptr); Http::MessagePtr response_http_message(new Http::ResponseMessageImpl( @@ -175,6 +184,7 @@ TEST_F(GrpcRequestImplTest, ShortBodyInResponse) { helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; + EXPECT_CALL(grpc_callbacks_, onPreRequestCustomizeHeaders(_)); service_.SayHello(nullptr, &request, &response, nullptr); Http::MessagePtr response_http_message(new Http::ResponseMessageImpl( @@ -193,6 +203,7 @@ TEST_F(GrpcRequestImplTest, BadMessageInResponse) { helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; + EXPECT_CALL(grpc_callbacks_, onPreRequestCustomizeHeaders(_)); service_.SayHello(nullptr, &request, &response, nullptr); Http::MessagePtr response_http_message(new Http::ResponseMessageImpl( @@ -211,6 +222,7 @@ TEST_F(GrpcRequestImplTest, HttpAsyncRequestFailure) { helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; + EXPECT_CALL(grpc_callbacks_, onPreRequestCustomizeHeaders(_)); service_.SayHello(nullptr, &request, &response, nullptr); EXPECT_CALL(grpc_callbacks_, onFailure(Optional(), "stream reset")); @@ -223,6 +235,7 @@ TEST_F(GrpcRequestImplTest, HttpAsyncRequestTimeout) { helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; + EXPECT_CALL(grpc_callbacks_, onPreRequestCustomizeHeaders(_)); service_.SayHello(nullptr, &request, &response, nullptr); EXPECT_CALL(grpc_callbacks_, onFailure(Optional(), "request timeout")); @@ -243,6 +256,7 @@ TEST_F(GrpcRequestImplTest, NoHttpAsyncRequest) { helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; + EXPECT_CALL(grpc_callbacks_, onPreRequestCustomizeHeaders(_)); service_.SayHello(nullptr, &request, &response, nullptr); } @@ -252,6 +266,7 @@ TEST_F(GrpcRequestImplTest, Cancel) { helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; + EXPECT_CALL(grpc_callbacks_, onPreRequestCustomizeHeaders(_)); service_.SayHello(nullptr, &request, &response, nullptr); EXPECT_CALL(http_async_client_request_, cancel()); @@ -267,6 +282,7 @@ TEST_F(GrpcRequestImplTest, RequestTimeoutSet) { helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; + EXPECT_CALL(grpc_callbacks_, onPreRequestCustomizeHeaders(_)); service_timeout.SayHello(nullptr, &request, &response, nullptr); Http::MessagePtr response_http_message(new Http::ResponseMessageImpl( diff --git a/test/mocks/grpc/mocks.h b/test/mocks/grpc/mocks.h index 5a77181d8796..65a3671f2337 100644 --- a/test/mocks/grpc/mocks.h +++ b/test/mocks/grpc/mocks.h @@ -9,6 +9,7 @@ class MockRpcChannelCallbacks : public RpcChannelCallbacks { MockRpcChannelCallbacks(); ~MockRpcChannelCallbacks(); + MOCK_METHOD1(onPreRequestCustomizeHeaders, void(Http::HeaderMap& headers)); MOCK_METHOD0(onSuccess, void()); MOCK_METHOD2(onFailure, void(const Optional& grpc_status, const std::string& message)); };