From 5062e4b794b263d8e918cd2bd1df7a135a458a3b Mon Sep 17 00:00:00 2001 From: bochencwx Date: Fri, 28 Jun 2024 13:00:21 +0800 Subject: [PATCH] Bugfix: Fix some HTTP-related issues - Fix the issue where metrics reports as correct when the HTTP client's asynchronous call returns an erroneous HTTP response code - Fix the issue of rpcz not correctly obtaining the size of the HTTP client response packet - Fix the case sensitivity issue in key comparison for HttpHeader.SetIfNotPresent interfac --- trpc/client/http/http_service_proxy.cc | 3 +- trpc/client/http/http_service_proxy_test.cc | 136 ++++++++++++++++++++ trpc/codec/http/http_protocol.cc | 4 + trpc/codec/http/http_protocol.h | 6 + trpc/codec/http/http_protocol_test.cc | 22 ++-- trpc/util/http/field_map.h | 2 +- trpc/util/http/field_map_test.cc | 18 +++ 7 files changed, 179 insertions(+), 12 deletions(-) diff --git a/trpc/client/http/http_service_proxy.cc b/trpc/client/http/http_service_proxy.cc index b9326e79..7f95c050 100644 --- a/trpc/client/http/http_service_proxy.cc +++ b/trpc/client/http/http_service_proxy.cc @@ -68,14 +68,15 @@ Future HttpServiceProxy::AsyncInnerUnaryInvoke(const ClientContextP ProtocolPtr p = rsp.GetValue0(); // When the call is successful, set the response data (for use by the `CLIENT_POST_RPC_INVOKE` filter). context->SetResponseData(&p); - filter_controller_.RunMessageClientFilters(FilterPoint::CLIENT_POST_RPC_INVOKE, context); if (!CheckHttpResponse(context, p)) { std::string error = fmt::format("service name:{},check http reply failed,{}", GetServiceName(), context->GetStatus().ToString()); TRPC_LOG_ERROR(error); + filter_controller_.RunMessageClientFilters(FilterPoint::CLIENT_POST_RPC_INVOKE, context); return MakeExceptionFuture( CommonException(context->GetStatus().ErrorMessage().c_str(), context->GetStatus().GetFuncRetCode())); } + filter_controller_.RunMessageClientFilters(FilterPoint::CLIENT_POST_RPC_INVOKE, context); return MakeReadyFuture(std::move(p)); } diff --git a/trpc/client/http/http_service_proxy_test.cc b/trpc/client/http/http_service_proxy_test.cc index 687882ce..cdac803a 100644 --- a/trpc/client/http/http_service_proxy_test.cc +++ b/trpc/client/http/http_service_proxy_test.cc @@ -20,6 +20,7 @@ #include "trpc/client/make_client_context.h" #include "trpc/client/service_proxy_option_setter.h" #include "trpc/codec/codec_manager.h" +#include "trpc/filter/filter_manager.h" #include "trpc/future/future_utility.h" #include "trpc/naming/trpc_naming_registry.h" #include "trpc/proto/testing/helloworld.pb.h" @@ -28,6 +29,32 @@ namespace trpc::testing { +class TestClientFilter : public trpc::MessageClientFilter { + public: + std::string Name() override { return "test_filter"; } + + std::vector GetFilterPoint() override { + std::vector points = {trpc::FilterPoint::CLIENT_PRE_RPC_INVOKE, + trpc::FilterPoint::CLIENT_POST_RPC_INVOKE}; + return points; + } + + void operator()(trpc::FilterStatus& status, trpc::FilterPoint point, const trpc::ClientContextPtr& context) override { + status = FilterStatus::CONTINUE; + // record the status upon entering the CLIENT_POST_RPC_INVOKE point + if (point == FilterPoint::CLIENT_POST_RPC_INVOKE) { + status_ = context->GetStatus(); + } + } + + void SetStatus(trpc::Status&& status) { status_ = std::move(status); } + + Status GetStatus() { return status_; } + + private: + trpc::Status status_; +}; + class HttpServiceProxyTest : public ::testing::Test { public: static void SetUpTestCase() { @@ -46,6 +73,10 @@ class HttpServiceProxyTest : public ::testing::Test { codec::Init(); serialization::Init(); naming::Init(); + + filter_ = std::make_shared(); + trpc::FilterManager::GetInstance()->AddMessageClientFilter(filter_); + option_->service_filters.push_back(filter_->Name()); } static void TearDownTestCase() { @@ -58,9 +89,11 @@ class HttpServiceProxyTest : public ::testing::Test { protected: static std::shared_ptr option_; + static std::shared_ptr filter_; }; std::shared_ptr HttpServiceProxyTest::option_ = std::make_shared(); +std::shared_ptr HttpServiceProxyTest::filter_; class MockHttpServiceProxy : public trpc::http::HttpServiceProxy { public: @@ -3891,4 +3924,107 @@ TEST_F(HttpServiceProxyTest, ConstructPBRequest) { } } +TEST_F(HttpServiceProxyTest, FilterExecWhenSuccess) { + auto proxy = std::make_shared(); + proxy->SetMockServiceProxyOption(option_); + auto ctx = MakeClientContext(proxy); + ctx->SetStatus(Status()); + ctx->SetAddr("127.0.0.1", 10002); + + trpc::http::HttpResponse reply; + reply.SetVersion("1.1"); + reply.SetStatus(trpc::http::HttpResponse::StatusCode::kOk); + proxy->SetReplyError(false); + proxy->SetReply(reply); + + std::string rspStr; + filter_->SetStatus(Status()); + auto st = proxy->GetString(ctx, "http://127.0.0.1:10002/hello", &rspStr); + ASSERT_TRUE(st.OK()); + // verify that the status is OK upon entering the RPC post-filter point + ASSERT_TRUE(filter_->GetStatus().OK()); + + proxy->SetReplyError(false); + proxy->SetReply(reply); + ctx = MakeClientContext(proxy); + ctx->SetStatus(Status()); + ctx->SetAddr("127.0.0.1", 10002); + filter_->SetStatus(Status()); + auto async_get_string_fut = + proxy->AsyncGetString(ctx, "http://127.0.0.1:10002/hello").Then([](Future&& future) { + EXPECT_TRUE(future.IsReady()); + // verify that the status is OK upon entering the RPC post-filter point + EXPECT_TRUE(filter_->GetStatus().OK()); + + return MakeReadyFuture<>(); + }); + future::BlockingGet(std::move(async_get_string_fut)); +} + +TEST_F(HttpServiceProxyTest, FilterExecWithFrameworkErr) { + auto proxy = std::make_shared(); + proxy->SetMockServiceProxyOption(option_); + auto ctx = MakeClientContext(proxy); + ctx->SetStatus(Status()); + ctx->SetAddr("127.0.0.1", 10002); + + proxy->SetReplyError(true); + std::string rspStr; + filter_->SetStatus(Status()); + auto st = proxy->GetString(ctx, "http://127.0.0.1:10002/hello", &rspStr); + ASSERT_FALSE(st.OK()); + // verify that the status is already failed upon entering the RPC post-filter point + ASSERT_FALSE(filter_->GetStatus().OK()); + + proxy->SetReplyError(true); + ctx = MakeClientContext(proxy); + ctx->SetStatus(Status()); + ctx->SetAddr("127.0.0.1", 10002); + filter_->SetStatus(Status()); + auto async_get_string_fut = + proxy->AsyncGetString(ctx, "http://127.0.0.1:10002/hello").Then([](Future&& future) { + EXPECT_TRUE(future.IsFailed()); + // verify that the status is already failed upon entering the RPC post-filter point + EXPECT_FALSE(filter_->GetStatus().OK()); + return MakeReadyFuture<>(); + }); + future::BlockingGet(std::move(async_get_string_fut)); +} + +TEST_F(HttpServiceProxyTest, FilterExecWithHttpErr) { + auto proxy = std::make_shared(); + proxy->SetMockServiceProxyOption(option_); + auto ctx = MakeClientContext(proxy); + ctx->SetStatus(Status()); + ctx->SetAddr("127.0.0.1", 10002); + + trpc::http::HttpResponse reply; + reply.SetVersion("1.1"); + reply.SetStatus(trpc::http::HttpResponse::StatusCode::kForbidden); + proxy->SetReplyError(false); + proxy->SetReply(reply); + + std::string rspStr; + filter_->SetStatus(Status()); + auto st = proxy->GetString(ctx, "http://127.0.0.1:10002/hello", &rspStr); + ASSERT_FALSE(st.OK()); + // verify that the status is already failed upon entering the RPC post-filter point + ASSERT_FALSE(filter_->GetStatus().OK()); + + proxy->SetReplyError(false); + proxy->SetReply(reply); + ctx = MakeClientContext(proxy); + ctx->SetStatus(Status()); + ctx->SetAddr("127.0.0.1", 10002); + filter_->SetStatus(Status()); + auto async_get_string_fut = + proxy->AsyncGetString(ctx, "http://127.0.0.1:10002/hello").Then([](Future&& future) { + EXPECT_TRUE(future.IsFailed()); + // verify that the status is already failed upon entering the RPC post-filter point + EXPECT_FALSE(filter_->GetStatus().OK()); + return MakeReadyFuture<>(); + }); + future::BlockingGet(std::move(async_get_string_fut)); +} + } // namespace trpc::testing diff --git a/trpc/codec/http/http_protocol.cc b/trpc/codec/http/http_protocol.cc index 92ac2b45..2d5d27b2 100644 --- a/trpc/codec/http/http_protocol.cc +++ b/trpc/codec/http/http_protocol.cc @@ -57,6 +57,10 @@ bool HttpRequestProtocol::WaitForFullRequest() { return false; } +uint32_t HttpRequestProtocol::GetMessageSize() const { return request->ContentLength(); } + +uint32_t HttpResponseProtocol::GetMessageSize() const { return response.ContentLength(); } + namespace internal { const std::string& GetHeaderString(const http::HeaderPairs& header, const std::string& name) { return header.Get(name); diff --git a/trpc/codec/http/http_protocol.h b/trpc/codec/http/http_protocol.h index 80a070ed..a51090c2 100644 --- a/trpc/codec/http/http_protocol.h +++ b/trpc/codec/http/http_protocol.h @@ -72,6 +72,9 @@ class HttpRequestProtocol : public Protocol { void SetFromHttpServiceProxy(bool from_http_service_proxy) { from_http_service_proxy_ = from_http_service_proxy; } bool IsFromHttpServiceProxy() { return from_http_service_proxy_; } + /// @brief Get size of message + uint32_t GetMessageSize() const override; + public: uint32_t request_id{0}; http::RequestPtr request{nullptr}; @@ -101,6 +104,9 @@ class HttpResponseProtocol : public Protocol { return std::move(*response.GetMutableNonContiguousBufferContent()); } + /// @brief Get size of message + uint32_t GetMessageSize() const override; + public: http::Response response; }; diff --git a/trpc/codec/http/http_protocol_test.cc b/trpc/codec/http/http_protocol_test.cc index fabb0bba..c8255e0f 100644 --- a/trpc/codec/http/http_protocol_test.cc +++ b/trpc/codec/http/http_protocol_test.cc @@ -35,8 +35,9 @@ TEST_F(HttpProtoFixture, HttpRequestProtocolTest) { HttpRequestProtocol req = HttpRequestProtocol(std::make_shared()); req.request->SetContent(test); NoncontiguousBuffer buff; - EXPECT_FALSE(req.ZeroCopyDecode(buff)); - EXPECT_FALSE(req.ZeroCopyEncode(buff)); + ASSERT_FALSE(req.ZeroCopyDecode(buff)); + ASSERT_FALSE(req.ZeroCopyEncode(buff)); + ASSERT_NE(0, req.GetMessageSize()); } TEST_F(HttpProtoFixture, HttpResponseProtocolTest) { @@ -47,8 +48,9 @@ TEST_F(HttpProtoFixture, HttpResponseProtocolTest) { reply.SetStatus(trpc::http::HttpResponse::StatusCode::kOk); reply.SetContent("{\"age\":\"18\",\"height\":180}"); NoncontiguousBuffer buff; - EXPECT_FALSE(rsp.ZeroCopyDecode(buff)); - EXPECT_FALSE(rsp.ZeroCopyEncode(buff)); + ASSERT_FALSE(rsp.ZeroCopyDecode(buff)); + ASSERT_FALSE(rsp.ZeroCopyEncode(buff)); + ASSERT_NE(0, rsp.GetMessageSize()); } TEST(HttpRequestProtocolTest, GetOkNonContiguousProtocolBody) { @@ -59,13 +61,13 @@ TEST(HttpRequestProtocolTest, GetOkNonContiguousProtocolBody) { request_protocol.SetNonContiguousProtocolBody(builder.DestructiveGet()); auto body_buffer = request_protocol.GetNonContiguousProtocolBody(); - EXPECT_EQ(greetings.size(), body_buffer.ByteSize()); + ASSERT_EQ(greetings.size(), body_buffer.ByteSize()); } TEST(HttpRequestProtocolTest, GetEmptyNonContiguousProtocolBody) { HttpRequestProtocol request_protocol{std::make_shared()}; auto body_buffer = request_protocol.GetNonContiguousProtocolBody(); - EXPECT_EQ(0, body_buffer.size()); + ASSERT_EQ(0, body_buffer.size()); } TEST(HttpResponseProtocolTest, GetOkNonContiguousProtocolBody) { @@ -74,8 +76,8 @@ TEST(HttpResponseProtocolTest, GetOkNonContiguousProtocolBody) { response_protocol.SetNonContiguousProtocolBody(CreateBufferSlow(greetings)); - EXPECT_EQ(greetings.size(), response_protocol.GetNonContiguousProtocolBody().ByteSize()); - EXPECT_TRUE(response_protocol.response.GetContent().empty()); + ASSERT_EQ(greetings.size(), response_protocol.GetNonContiguousProtocolBody().ByteSize()); + ASSERT_TRUE(response_protocol.response.GetContent().empty()); } TEST(EncodeTypeToMimeTest, EncodeTypeToMime) { @@ -92,7 +94,7 @@ TEST(EncodeTypeToMimeTest, EncodeTypeToMime) { }; for (const auto& t : testings) { - EXPECT_EQ(t.expect, EncodeTypeToMime(t.encode_type)); + ASSERT_EQ(t.expect, EncodeTypeToMime(t.encode_type)); } } @@ -112,7 +114,7 @@ TEST(MimeToEncodeTypeTest, MimeToEncodeType) { }; for (const auto& t : testings) { - EXPECT_EQ(t.expect, MimeToEncodeType(t.mime)); + ASSERT_EQ(t.expect, MimeToEncodeType(t.mime)); } } diff --git a/trpc/util/http/field_map.h b/trpc/util/http/field_map.h index 5f113c57..36115479 100644 --- a/trpc/util/http/field_map.h +++ b/trpc/util/http/field_map.h @@ -79,7 +79,7 @@ class FieldMap { /// @brief Does same thing as Set method, but only works when key does not exist in the map. template void SetIfNotPresent(K&& key, V&& value) { - if (auto it = pairs_.lower_bound(key); it == pairs_.end() || it->first != key) { + if (auto it = pairs_.lower_bound(key); it == pairs_.end() || !CaseInsensitiveEqualTo()(it->first, key)) { pairs_.emplace_hint(it, std::forward(key), std::forward(value)); } } diff --git a/trpc/util/http/field_map_test.cc b/trpc/util/http/field_map_test.cc index a5870f0d..d3ede10b 100644 --- a/trpc/util/http/field_map_test.cc +++ b/trpc/util/http/field_map_test.cc @@ -207,4 +207,22 @@ TEST_F(FieldMapTest, FlatPairsCount) { ASSERT_EQ(17, header_.FlatPairsCount()); } +TEST_F(FieldMapTest, SetIfNotPresentOk) { + EXPECT_EQ(0, header_.Values("User-Defined-Key99").size()); + header_.SetIfNotPresent("User-Defined-Key99", "user-defined-value99"); + EXPECT_EQ(1, header_.Values("User-Defined-Key99").size()); + EXPECT_EQ("user-defined-value99", header_.Get("User-Defined-Key99")); + + EXPECT_EQ(1, header_.Values("User-Defined-Key01").size()); + header_.SetIfNotPresent("User-Defined-Key01", "user-defined-value01-new"); + EXPECT_EQ(1, header_.Values("User-Defined-Key01").size()); + EXPECT_EQ("user-defined-value01", header_.Get("User-Defined-Key01")); + + // case insensitivity test + EXPECT_EQ(1, header_.Values("user-defined-key01").size()); + header_.SetIfNotPresent("user-defined-key01", "user-defined-value01-new"); + EXPECT_EQ(1, header_.Values("user-defined-key01").size()); + EXPECT_EQ("user-defined-value01", header_.Get("user-defined-key01")); +} + } // namespace trpc::http::testing