Skip to content

Commit

Permalink
feat userver: create middleware exception, make it possible to throw …
Browse files Browse the repository at this point in the history
…in response hooks

Tests: CI
commit_hash:e0f55f096013ee6dba2da88a84922fdb6f3cd7d7
  • Loading branch information
abramov-alex committed Oct 9, 2024
1 parent ead0775 commit be7b0e1
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 18 deletions.
2 changes: 2 additions & 0 deletions .mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -1883,6 +1883,7 @@
"grpc/include/userver/ugrpc/server/impl/codegen_definitions.hpp":"taxi/uservices/userver/grpc/include/userver/ugrpc/server/impl/codegen_definitions.hpp",
"grpc/include/userver/ugrpc/server/impl/completion_queue_pool.hpp":"taxi/uservices/userver/grpc/include/userver/ugrpc/server/impl/completion_queue_pool.hpp",
"grpc/include/userver/ugrpc/server/impl/error_code.hpp":"taxi/uservices/userver/grpc/include/userver/ugrpc/server/impl/error_code.hpp",
"grpc/include/userver/ugrpc/server/impl/exceptions.hpp":"taxi/uservices/userver/grpc/include/userver/ugrpc/server/impl/exceptions.hpp",
"grpc/include/userver/ugrpc/server/impl/service_worker.hpp":"taxi/uservices/userver/grpc/include/userver/ugrpc/server/impl/service_worker.hpp",
"grpc/include/userver/ugrpc/server/impl/service_worker_impl.hpp":"taxi/uservices/userver/grpc/include/userver/ugrpc/server/impl/service_worker_impl.hpp",
"grpc/include/userver/ugrpc/server/middlewares/baggage/component.hpp":"taxi/uservices/userver/grpc/include/userver/ugrpc/server/middlewares/baggage/component.hpp",
Expand Down Expand Up @@ -2021,6 +2022,7 @@
"grpc/tests/middlewares_test.cpp":"taxi/uservices/userver/grpc/tests/middlewares_test.cpp",
"grpc/tests/secret_fields_test.cpp":"taxi/uservices/userver/grpc/tests/secret_fields_test.cpp",
"grpc/tests/serialization_test.cpp":"taxi/uservices/userver/grpc/tests/serialization_test.cpp",
"grpc/tests/server_middlewares_error_test.cpp":"taxi/uservices/userver/grpc/tests/server_middlewares_error_test.cpp",
"grpc/tests/service_config_test.cpp":"taxi/uservices/userver/grpc/tests/service_config_test.cpp",
"grpc/tests/statistics_test.cpp":"taxi/uservices/userver/grpc/tests/statistics_test.cpp",
"grpc/tests/stream_test.cpp":"taxi/uservices/userver/grpc/tests/stream_test.cpp",
Expand Down
17 changes: 17 additions & 0 deletions grpc/include/userver/ugrpc/server/impl/exceptions.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once

#include <exception>

USERVER_NAMESPACE_BEGIN

namespace ugrpc::server::impl {

/// @brief Base class for userver-internal rpc errors
class BaseInternalRpcError : public std::exception {};

/// @brief Middleware interruption
class MiddlewareRpcInterruptionError : public BaseInternalRpcError {};

} // namespace ugrpc::server::impl

USERVER_NAMESPACE_END
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
#pragma once

#include <chrono>
#include <exception>
#include <functional>
#include <optional>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>
Expand Down Expand Up @@ -32,6 +30,7 @@
#include <userver/ugrpc/server/impl/call_params.hpp>
#include <userver/ugrpc/server/impl/call_traits.hpp>
#include <userver/ugrpc/server/impl/error_code.hpp>
#include <userver/ugrpc/server/impl/exceptions.hpp>
#include <userver/ugrpc/server/impl/service_worker.hpp>
#include <userver/ugrpc/server/middlewares/base.hpp>
#include <userver/ugrpc/server/rpc.hpp>
Expand Down Expand Up @@ -212,6 +211,10 @@ class CallData final {
initial_request);
responder.RunMiddlewarePipeline(utils::impl::InternalTag{},
middleware_context);
} catch (const BaseInternalRpcError&) {
// The status has already been reported by user code with FinishWithError.
// The exception is required to rollback the call stack of the handler.
// Thus, we should just ignore the error.
} catch (
const USERVER_NAMESPACE::server::handlers::CustomHandlerException& ex) {
ReportCustomError(ex, responder, span_->Get());
Expand Down
38 changes: 22 additions & 16 deletions grpc/include/userver/ugrpc/server/rpc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,12 @@ void UnaryCall<Response>::Finish(Response&& response) {
template <typename Response>
void UnaryCall<Response>::Finish(Response& response) {
UINVARIANT(!is_finished_, "'Finish' called on a finished call");
is_finished_ = true;

ApplyResponseHook(&response);

// It is important to set is_finished_ after ApplyResponseHook.
// Otherwise, there would be no way to call FinishWithError there.
is_finished_ = true;

LogFinish(grpc::Status::OK);
impl::Finish(stream_, response, grpc::Status::OK, GetCallName());
GetStatistics().OnExplicitFinish(grpc::StatusCode::OK);
Expand Down Expand Up @@ -359,13 +361,15 @@ template <typename Request, typename Response>
void InputStream<Request, Response>::Finish(Response& response) {
UINVARIANT(state_ != State::kFinished,
"'Finish' called on a finished stream");
ApplyResponseHook(&response);

// It is important to set the state_ after ApplyResponseHook.
// Otherwise, there would be no way to call FinishWithError there.
state_ = State::kFinished;

const auto& status = grpc::Status::OK;
LogFinish(status);

ApplyResponseHook(&response);

impl::Finish(stream_, response, status, GetCallName());
GetStatistics().OnExplicitFinish(status.error_code());
ugrpc::impl::UpdateSpanWithStatus(GetSpan(), status);
Expand Down Expand Up @@ -411,6 +415,7 @@ void OutputStream<Response>::Write(Response&& response) {
template <typename Response>
void OutputStream<Response>::Write(Response& response) {
UINVARIANT(state_ != State::kFinished, "'Write' called on a finished stream");
ApplyResponseHook(&response);

// For some reason, gRPC requires explicit 'SendInitialMetadata' in output
// streams
Expand All @@ -420,8 +425,6 @@ void OutputStream<Response>::Write(Response& response) {
// may never actually be delivered
grpc::WriteOptions write_options{};

ApplyResponseHook(&response);

impl::Write(stream_, response, write_options, GetCallName());
}

Expand Down Expand Up @@ -458,6 +461,10 @@ template <typename Response>
void OutputStream<Response>::WriteAndFinish(Response& response) {
UINVARIANT(state_ != State::kFinished,
"'WriteAndFinish' called on a finished stream");
ApplyResponseHook(&response);

// It is important to set the state_ after ApplyResponseHook.
// Otherwise, there would be no way to call FinishWithError there.
state_ = State::kFinished;

// Don't buffer writes, otherwise in an event subscription scenario, events
Expand All @@ -467,8 +474,6 @@ void OutputStream<Response>::WriteAndFinish(Response& response) {
const auto& status = grpc::Status::OK;
LogFinish(status);

ApplyResponseHook(&response);

impl::WriteAndFinish(stream_, response, write_options, status, GetCallName());
GetStatistics().OnExplicitFinish(grpc::StatusCode::OK);
ugrpc::impl::UpdateSpanWithStatus(GetSpan(), status);
Expand Down Expand Up @@ -518,14 +523,13 @@ void BidirectionalStream<Request, Response>::Write(Response&& response) {
template <typename Request, typename Response>
void BidirectionalStream<Request, Response>::Write(Response& response) {
UINVARIANT(!is_finished_, "'Write' called on a finished stream");

// Don't buffer writes, optimize for ping-pong-style interaction
grpc::WriteOptions write_options{};

if constexpr (std::is_base_of_v<google::protobuf::Message, Response>) {
ApplyResponseHook(&response);
}

// Don't buffer writes, optimize for ping-pong-style interaction
grpc::WriteOptions write_options{};

try {
impl::Write(stream_, response, write_options, GetCallName());
} catch (const RpcInterruptedError&) {
Expand Down Expand Up @@ -568,6 +572,12 @@ template <typename Request, typename Response>
void BidirectionalStream<Request, Response>::WriteAndFinish(
Response& response) {
UINVARIANT(!is_finished_, "'WriteAndFinish' called on a finished stream");
if constexpr (std::is_base_of_v<google::protobuf::Message, Response>) {
ApplyResponseHook(&response);
}

// It is important to set is_finished_ after ApplyResponseHook.
// Otherwise, there would be no way to call FinishWithError there.
is_finished_ = true;

// Don't buffer writes, optimize for ping-pong-style interaction
Expand All @@ -576,10 +586,6 @@ void BidirectionalStream<Request, Response>::WriteAndFinish(
const auto& status = grpc::Status::OK;
LogFinish(status);

if constexpr (std::is_base_of_v<google::protobuf::Message, Response>) {
ApplyResponseHook(&response);
}

impl::WriteAndFinish(stream_, response, write_options, status, GetCallName());
GetStatistics().OnExplicitFinish(status.error_code());
ugrpc::impl::UpdateSpanWithStatus(GetSpan(), status);
Expand Down
3 changes: 3 additions & 0 deletions grpc/src/ugrpc/server/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <userver/logging/impl/logger_base.hpp>
#include <userver/logging/logger.hpp>
#include <userver/ugrpc/impl/statistics_storage.hpp>
#include <userver/ugrpc/server/impl/exceptions.hpp>
#include <userver/ugrpc/server/middlewares/base.hpp>
#include <userver/utils/algo.hpp>

Expand Down Expand Up @@ -50,6 +51,7 @@ void CallAnyBase::ApplyRequestHook(google::protobuf::Message* request) {
if (request) {
for (const auto& middleware : params_.middlewares) {
middleware->CallRequestHook(*middleware_call_context_, *request);
if (IsFinished()) throw impl::MiddlewareRpcInterruptionError();
}
}
}
Expand All @@ -60,6 +62,7 @@ void CallAnyBase::ApplyResponseHook(google::protobuf::Message* response) {
for (const auto& middleware :
boost::adaptors::reverse(params_.middlewares)) {
middleware->CallResponseHook(*middleware_call_context_, *response);
if (IsFinished()) throw impl::MiddlewareRpcInterruptionError();
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions grpc/src/ugrpc/server/middlewares/base.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include <userver/ugrpc/server/middlewares/base.hpp>

#include <userver/ugrpc/server/impl/exceptions.hpp>

USERVER_NAMESPACE_BEGIN

namespace ugrpc::server {
Expand All @@ -25,6 +27,7 @@ void MiddlewareCallContext::Next() {
// It is important for non-stream calls
if (request_) {
(*middleware_)->CallRequestHook(*this, *request_);
if (call_.IsFinished()) throw impl::MiddlewareRpcInterruptionError();
}
++middleware_;
}
Expand Down
112 changes: 112 additions & 0 deletions grpc/tests/server_middlewares_error_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#include <userver/utest/utest.hpp>

#include <grpcpp/support/status.h>

#include <userver/ugrpc/server/middlewares/base.hpp>
#include <userver/ugrpc/tests/service_fixtures.hpp>
#include <userver/utest/log_capture_fixture.hpp>
#include <userver/utils/flags.hpp>

#include <tests/unit_test_client.usrv.pb.hpp>
#include <tests/unit_test_service.usrv.pb.hpp>

USERVER_NAMESPACE_BEGIN

namespace {

class Messenger final : public sample::ugrpc::UnitTestServiceBase {
public:
void SayHello(SayHelloCall& call,
sample::ugrpc::GreetingRequest&& /*request*/) override {
call.Finish(sample::ugrpc::GreetingResponse{});
}
};

enum class MiddlewareFlag {
kNone = 0,
kErrorInRequestHook = 1 << 0,
kErrorInResponseHook = 1 << 1
};

using MiddlewareFlags = utils::Flags<MiddlewareFlag>;

class Middleware final : public ugrpc::server::MiddlewareBase {
public:
Middleware(MiddlewareFlag settings) : settings_(settings) {}

void Handle(ugrpc::server::MiddlewareCallContext& context) const override {
context.Next();
}

void CallRequestHook(const ugrpc::server::MiddlewareCallContext& context,
google::protobuf::Message&) override {
if (settings_ == MiddlewareFlag::kErrorInRequestHook) {
context.GetCall().FinishWithError(::grpc::Status(
::grpc::StatusCode::DATA_LOSS, "Data loss error in request hook"));
}
}

void CallResponseHook(const ugrpc::server::MiddlewareCallContext& context,
google::protobuf::Message&) override {
if (settings_ == MiddlewareFlag::kErrorInResponseHook) {
context.GetCall().FinishWithError(
::grpc::Status(::grpc::StatusCode::OUT_OF_RANGE,
"Out of range error in response hook"));
}
}

private:
MiddlewareFlag settings_;
};

class MockMessengerServiceFixture
: public ugrpc::tests::ServiceFixtureBase,
public testing::WithParamInterface<MiddlewareFlags> {
protected:
MockMessengerServiceFixture() {
SetServerMiddlewares({std::make_shared<Middleware>(
static_cast<MiddlewareFlag>(GetParam().GetValue()))});
RegisterService(service_);
StartServer();
}

private:
Messenger service_;
};

} // namespace

UTEST_P(MockMessengerServiceFixture, MiddlewareInterruption) {
const auto client = MakeClient<sample::ugrpc::UnitTestServiceClient>();
try {
client.SayHello(sample::ugrpc::GreetingRequest()).Finish();
FAIL(); // Should not execute. The method must throw.
} catch (const ugrpc::client::ErrorWithStatus& error) {
switch (static_cast<MiddlewareFlag>(GetParam().GetValue())) {
case MiddlewareFlag::kErrorInRequestHook: {
EXPECT_EQ(error.GetStatus().error_code(),
::grpc::StatusCode::DATA_LOSS);
EXPECT_EQ(error.GetStatus().error_message(),
"Data loss error in request hook");
break;
}
case MiddlewareFlag::kErrorInResponseHook: {
EXPECT_EQ(error.GetStatus().error_code(),
::grpc::StatusCode::OUT_OF_RANGE);
EXPECT_EQ(error.GetStatus().error_message(),
"Out of range error in response hook");
break;
}
default: {
FAIL(); // Should not happen
}
}
}
}

INSTANTIATE_UTEST_SUITE_P(
/*no prefix*/, MockMessengerServiceFixture,
testing::Values(MiddlewareFlags{MiddlewareFlag::kErrorInRequestHook},
MiddlewareFlags{MiddlewareFlag::kErrorInResponseHook}));

USERVER_NAMESPACE_END

0 comments on commit be7b0e1

Please sign in to comment.