From 2252e64829192c2f4901dcf0b21fe75b06e5db04 Mon Sep 17 00:00:00 2001 From: Paul Nienaber Date: Tue, 13 Feb 2024 14:45:14 -0800 Subject: [PATCH] Fix DoPut, DoExchange, and cover case of no result stream writes by handler --- .../flight/transport/grpc/grpc_server.cc | 71 ++++++++++++------- .../flight/transport/grpc/util_internal.h | 1 + 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc index 97d02f96a32f7..443cdf6bb05cc 100644 --- a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc +++ b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc @@ -59,6 +59,12 @@ using ServerWriter = ::grpc::ServerWriter; const auto& __s = (STATUS); \ return CONTEXT.FinishRequest(__s); \ } while (false) +#define RETURN_WITH_CALL_AND_MIDDLEWARE(CONTEXT, CALL, STATUS) \ + do { \ + const auto& __s = (STATUS); \ + (CALL); \ + return CONTEXT.FinishRequest(__s); \ + } while (false) #define CHECK_ARG_NOT_NULL(CONTEXT, VAL, MESSAGE) \ if (VAL == nullptr) { \ RETURN_WITH_MIDDLEWARE( \ @@ -238,6 +244,7 @@ class GetDataStream : public internal::ServerDataStream { return WritePayload(payload, writer_); } + void TryCallOnceBeforeWrite() { once_before_write_(); } private: ServerWriter* writer_; @@ -268,6 +275,7 @@ class PutDataStream final : public internal::ServerDataStream { } return Status::IOError("Unknown error writing metadata."); } + void TryCallOnceBeforeWrite() { once_before_write_(); } private: ::grpc::ServerReaderWriter* stream_; @@ -291,6 +299,7 @@ class ExchangeDataStream final : public internal::ServerDataStream { arrow::Result WriteData(const FlightPayload& payload) override { return WritePayload(payload, stream_); } + void TryCallOnceBeforeWrite() { once_before_write_(); } private: ::grpc::ServerReaderWriter* stream_; @@ -451,12 +460,13 @@ class GrpcServiceHandler final : public FlightService::Service { Criteria criteria; if (request) { - SERVICE_CALL_THEN_RETURN_NOT_OK(flight_context, addMiddlewareHeaders(context, flight_context), + SERVICE_CALL_THEN_RETURN_NOT_OK(flight_context, + addMiddlewareHeaders(context, flight_context), internal::FromProto(*request, &criteria)); } - auto res = impl_->base()->ListFlights(flight_context, &criteria, &listing); - addMiddlewareHeaders(context, flight_context); - SERVICE_RETURN_NOT_OK(flight_context, res); + SERVICE_CALL_THEN_RETURN_NOT_OK( + flight_context, addMiddlewareHeaders(context, flight_context), + impl_->base()->ListFlights(flight_context, &criteria, &listing)); if (!listing) { // Treat null listing as no flights available RETURN_WITH_MIDDLEWARE(flight_context, ::grpc::Status::OK); @@ -473,7 +483,8 @@ class GrpcServiceHandler final : public FlightService::Service { addMiddlewareHeaders(context, flight_context), CheckAuth(FlightMethod::GetFlightInfo, context, flight_context, true)); - CHECK_ARG_NOT_NULL_WITH_CALL(flight_context, request, "FlightDescriptor cannot be null", + CHECK_ARG_NOT_NULL_WITH_CALL(flight_context, request, + "FlightDescriptor cannot be null", addMiddlewareHeaders(context, flight_context)); FlightDescriptor descr; @@ -482,9 +493,9 @@ class GrpcServiceHandler final : public FlightService::Service { internal::FromProto(*request, &descr)); std::unique_ptr info; - auto res = impl_->base()->GetFlightInfo(flight_context, descr, &info); - addMiddlewareHeaders(context, flight_context); - SERVICE_RETURN_NOT_OK(flight_context, res); + SERVICE_CALL_THEN_RETURN_NOT_OK( + flight_context, addMiddlewareHeaders(context, flight_context), + impl_->base()->GetFlightInfo(flight_context, descr, &info)); if (!info) { // Treat null listing as no flights available @@ -514,9 +525,9 @@ class GrpcServiceHandler final : public FlightService::Service { internal::FromProto(*request, &descr)); std::unique_ptr info; - auto res = impl_->base()->PollFlightInfo(flight_context, descr, &info); - addMiddlewareHeaders(context, flight_context); - SERVICE_RETURN_NOT_OK(flight_context, res); + SERVICE_CALL_THEN_RETURN_NOT_OK( + flight_context, addMiddlewareHeaders(context, flight_context), + impl_->base()->PollFlightInfo(flight_context, descr, &info)); if (!info) { // Treat null listing as no flights available @@ -545,9 +556,9 @@ class GrpcServiceHandler final : public FlightService::Service { internal::FromProto(*request, &descr)); std::unique_ptr result; - auto res = impl_->base()->GetSchema(flight_context, descr, &result); - addMiddlewareHeaders(context, flight_context); - SERVICE_RETURN_NOT_OK(flight_context, res); + SERVICE_CALL_THEN_RETURN_NOT_OK( + flight_context, addMiddlewareHeaders(context, flight_context), + impl_->base()->GetSchema(flight_context, descr, &result)); if (!result) { // Treat null listing as no flights available @@ -576,29 +587,37 @@ class GrpcServiceHandler final : public FlightService::Service { GetDataStream stream(writer, [&]() { addMiddlewareHeaders(context, flight_context); }); - RETURN_WITH_MIDDLEWARE(flight_context, - impl_->DoGet(flight_context, std::move(ticket), &stream)); + RETURN_WITH_CALL_AND_MIDDLEWARE( + flight_context, stream.TryCallOnceBeforeWrite(), + impl_->DoGet(flight_context, std::move(ticket), &stream)); } ::grpc::Status DoPut( ServerContext* context, ::grpc::ServerReaderWriter* reader) { GrpcServerCallContext flight_context(context); - GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoPut, context, flight_context)); + GRPC_CALL_THEN_RETURN_NOT_GRPC_OK( + addMiddlewareHeaders(context, flight_context), + CheckAuth(FlightMethod::DoPut, context, flight_context, true)); - PutDataStream stream(reader); - RETURN_WITH_MIDDLEWARE(flight_context, impl_->DoPut(flight_context, &stream)); + PutDataStream stream(reader, + [&]() { addMiddlewareHeaders(context, flight_context); }); + RETURN_WITH_CALL_AND_MIDDLEWARE(flight_context, stream.TryCallOnceBeforeWrite(), + impl_->DoPut(flight_context, &stream)); } ::grpc::Status DoExchange( ServerContext* context, ::grpc::ServerReaderWriter* stream) { GrpcServerCallContext flight_context(context); - GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoExchange, context, flight_context)); + GRPC_CALL_THEN_RETURN_NOT_GRPC_OK( + addMiddlewareHeaders(context, flight_context), + CheckAuth(FlightMethod::DoExchange, context, flight_context, true)); - ExchangeDataStream data_stream(stream); - RETURN_WITH_MIDDLEWARE(flight_context, - impl_->DoExchange(flight_context, &data_stream)); + ExchangeDataStream data_stream( + stream, [&]() { addMiddlewareHeaders(context, flight_context); }); + RETURN_WITH_CALL_AND_MIDDLEWARE(flight_context, data_stream.TryCallOnceBeforeWrite(), + impl_->DoExchange(flight_context, &data_stream)); } ::grpc::Status ListActions(ServerContext* context, const pb::Empty* request, @@ -629,9 +648,9 @@ class GrpcServiceHandler final : public FlightService::Service { internal::FromProto(*request, &action)); std::unique_ptr results; - auto res = impl_->base()->DoAction(flight_context, action, &results); - addMiddlewareHeaders(context, flight_context); - SERVICE_RETURN_NOT_OK(flight_context, res); + SERVICE_CALL_THEN_RETURN_NOT_OK( + flight_context, addMiddlewareHeaders(context, flight_context), + impl_->base()->DoAction(flight_context, action, &results)); if (!results) { RETURN_WITH_MIDDLEWARE(flight_context, ::grpc::Status::CANCELLED); diff --git a/cpp/src/arrow/flight/transport/grpc/util_internal.h b/cpp/src/arrow/flight/transport/grpc/util_internal.h index 31bcfe7edee6f..44fcd82ccea39 100644 --- a/cpp/src/arrow/flight/transport/grpc/util_internal.h +++ b/cpp/src/arrow/flight/transport/grpc/util_internal.h @@ -40,6 +40,7 @@ namespace flight { if (ARROW_PREDICT_FALSE(!_s.ok())) { \ (call); \ return _s; \ + } \ } while (0) #define GRPC_CALL_THEN_RETURN_NOT_GRPC_OK(call, expr) \