Skip to content

Commit

Permalink
Fix DoPut, DoExchange, and cover case of no result stream writes by h…
Browse files Browse the repository at this point in the history
…andler
  • Loading branch information
Paul Nienaber committed Feb 14, 2024
1 parent 49782b9 commit 2252e64
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 26 deletions.
71 changes: 45 additions & 26 deletions cpp/src/arrow/flight/transport/grpc/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ using ServerWriter = ::grpc::ServerWriter<T>;
const auto& __s = (STATUS); \
return CONTEXT.FinishRequest(__s); \
} while (false)
#define RETURN_WITH_CALL_AND_MIDDLEWARE(CONTEXT, CALL, STATUS) \
do { \
const auto& __s = (STATUS); \
(CALL); \
return CONTEXT.FinishRequest(__s); \
} while (false)
#define CHECK_ARG_NOT_NULL(CONTEXT, VAL, MESSAGE) \
if (VAL == nullptr) { \
RETURN_WITH_MIDDLEWARE( \
Expand Down Expand Up @@ -238,6 +244,7 @@ class GetDataStream : public internal::ServerDataStream {

return WritePayload(payload, writer_);
}
void TryCallOnceBeforeWrite() { once_before_write_(); }

private:
ServerWriter<pb::FlightData>* writer_;
Expand Down Expand Up @@ -268,6 +275,7 @@ class PutDataStream final : public internal::ServerDataStream {
}
return Status::IOError("Unknown error writing metadata.");
}
void TryCallOnceBeforeWrite() { once_before_write_(); }

private:
::grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* stream_;
Expand All @@ -291,6 +299,7 @@ class ExchangeDataStream final : public internal::ServerDataStream {
arrow::Result<bool> WriteData(const FlightPayload& payload) override {
return WritePayload(payload, stream_);
}
void TryCallOnceBeforeWrite() { once_before_write_(); }

private:
::grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* stream_;
Expand Down Expand Up @@ -451,12 +460,13 @@ class GrpcServiceHandler final : public FlightService::Service {

Criteria criteria;
if (request) {
SERVICE_CALL_THEN_RETURN_NOT_OK(flight_context, addMiddlewareHeaders(context, flight_context),
SERVICE_CALL_THEN_RETURN_NOT_OK(flight_context,
addMiddlewareHeaders(context, flight_context),
internal::FromProto(*request, &criteria));
}
auto res = impl_->base()->ListFlights(flight_context, &criteria, &listing);
addMiddlewareHeaders(context, flight_context);
SERVICE_RETURN_NOT_OK(flight_context, res);
SERVICE_CALL_THEN_RETURN_NOT_OK(
flight_context, addMiddlewareHeaders(context, flight_context),
impl_->base()->ListFlights(flight_context, &criteria, &listing));
if (!listing) {
// Treat null listing as no flights available
RETURN_WITH_MIDDLEWARE(flight_context, ::grpc::Status::OK);
Expand All @@ -473,7 +483,8 @@ class GrpcServiceHandler final : public FlightService::Service {
addMiddlewareHeaders(context, flight_context),
CheckAuth(FlightMethod::GetFlightInfo, context, flight_context, true));

CHECK_ARG_NOT_NULL_WITH_CALL(flight_context, request, "FlightDescriptor cannot be null",
CHECK_ARG_NOT_NULL_WITH_CALL(flight_context, request,
"FlightDescriptor cannot be null",
addMiddlewareHeaders(context, flight_context));

FlightDescriptor descr;
Expand All @@ -482,9 +493,9 @@ class GrpcServiceHandler final : public FlightService::Service {
internal::FromProto(*request, &descr));

std::unique_ptr<FlightInfo> info;
auto res = impl_->base()->GetFlightInfo(flight_context, descr, &info);
addMiddlewareHeaders(context, flight_context);
SERVICE_RETURN_NOT_OK(flight_context, res);
SERVICE_CALL_THEN_RETURN_NOT_OK(
flight_context, addMiddlewareHeaders(context, flight_context),
impl_->base()->GetFlightInfo(flight_context, descr, &info));

if (!info) {
// Treat null listing as no flights available
Expand Down Expand Up @@ -514,9 +525,9 @@ class GrpcServiceHandler final : public FlightService::Service {
internal::FromProto(*request, &descr));

std::unique_ptr<PollInfo> info;
auto res = impl_->base()->PollFlightInfo(flight_context, descr, &info);
addMiddlewareHeaders(context, flight_context);
SERVICE_RETURN_NOT_OK(flight_context, res);
SERVICE_CALL_THEN_RETURN_NOT_OK(
flight_context, addMiddlewareHeaders(context, flight_context),
impl_->base()->PollFlightInfo(flight_context, descr, &info));

if (!info) {
// Treat null listing as no flights available
Expand Down Expand Up @@ -545,9 +556,9 @@ class GrpcServiceHandler final : public FlightService::Service {
internal::FromProto(*request, &descr));

std::unique_ptr<SchemaResult> result;
auto res = impl_->base()->GetSchema(flight_context, descr, &result);
addMiddlewareHeaders(context, flight_context);
SERVICE_RETURN_NOT_OK(flight_context, res);
SERVICE_CALL_THEN_RETURN_NOT_OK(
flight_context, addMiddlewareHeaders(context, flight_context),
impl_->base()->GetSchema(flight_context, descr, &result));

if (!result) {
// Treat null listing as no flights available
Expand Down Expand Up @@ -576,29 +587,37 @@ class GrpcServiceHandler final : public FlightService::Service {

GetDataStream stream(writer,
[&]() { addMiddlewareHeaders(context, flight_context); });
RETURN_WITH_MIDDLEWARE(flight_context,
impl_->DoGet(flight_context, std::move(ticket), &stream));
RETURN_WITH_CALL_AND_MIDDLEWARE(
flight_context, stream.TryCallOnceBeforeWrite(),
impl_->DoGet(flight_context, std::move(ticket), &stream));
}

::grpc::Status DoPut(
ServerContext* context,
::grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* reader) {
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoPut, context, flight_context));
GRPC_CALL_THEN_RETURN_NOT_GRPC_OK(
addMiddlewareHeaders(context, flight_context),
CheckAuth(FlightMethod::DoPut, context, flight_context, true));

PutDataStream stream(reader);
RETURN_WITH_MIDDLEWARE(flight_context, impl_->DoPut(flight_context, &stream));
PutDataStream stream(reader,
[&]() { addMiddlewareHeaders(context, flight_context); });
RETURN_WITH_CALL_AND_MIDDLEWARE(flight_context, stream.TryCallOnceBeforeWrite(),
impl_->DoPut(flight_context, &stream));
}

::grpc::Status DoExchange(
ServerContext* context,
::grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* stream) {
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoExchange, context, flight_context));
GRPC_CALL_THEN_RETURN_NOT_GRPC_OK(
addMiddlewareHeaders(context, flight_context),
CheckAuth(FlightMethod::DoExchange, context, flight_context, true));

ExchangeDataStream data_stream(stream);
RETURN_WITH_MIDDLEWARE(flight_context,
impl_->DoExchange(flight_context, &data_stream));
ExchangeDataStream data_stream(
stream, [&]() { addMiddlewareHeaders(context, flight_context); });
RETURN_WITH_CALL_AND_MIDDLEWARE(flight_context, data_stream.TryCallOnceBeforeWrite(),
impl_->DoExchange(flight_context, &data_stream));
}

::grpc::Status ListActions(ServerContext* context, const pb::Empty* request,
Expand Down Expand Up @@ -629,9 +648,9 @@ class GrpcServiceHandler final : public FlightService::Service {
internal::FromProto(*request, &action));

std::unique_ptr<ResultStream> results;
auto res = impl_->base()->DoAction(flight_context, action, &results);
addMiddlewareHeaders(context, flight_context);
SERVICE_RETURN_NOT_OK(flight_context, res);
SERVICE_CALL_THEN_RETURN_NOT_OK(
flight_context, addMiddlewareHeaders(context, flight_context),
impl_->base()->DoAction(flight_context, action, &results));

if (!results) {
RETURN_WITH_MIDDLEWARE(flight_context, ::grpc::Status::CANCELLED);
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/flight/transport/grpc/util_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ namespace flight {
if (ARROW_PREDICT_FALSE(!_s.ok())) { \
(call); \
return _s; \
} \
} while (0)

#define GRPC_CALL_THEN_RETURN_NOT_GRPC_OK(call, expr) \
Expand Down

0 comments on commit 2252e64

Please sign in to comment.