Skip to content

Commit

Permalink
ARROW-15974: [C++] Migrate flight/types.h header definitions to use R…
Browse files Browse the repository at this point in the history
…esult<>

Closes #12669 from zagto/flight-api-result-types

Authored-by: Tobias Zagorni <tobias@zagorni.eu>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
zagto authored and pitrou committed Mar 29, 2022
1 parent be45ec6 commit d214455
Show file tree
Hide file tree
Showing 37 changed files with 518 additions and 421 deletions.
2 changes: 1 addition & 1 deletion c_glib/arrow-flight-glib/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ gaflight_client_list_flights(GAFlightClient *client,
GList *listing = NULL;
std::unique_ptr<arrow::flight::FlightInfo> flight_info;
while (true) {
status = flight_listing->Next(&flight_info);
status = flight_listing->Next().Value(&flight_info);
if (!garrow::check(error,
status,
"[flight-client][list-flights]")) {
Expand Down
10 changes: 5 additions & 5 deletions c_glib/arrow-flight-glib/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ gaflight_location_new(const gchar *uri,
auto location = GAFLIGHT_LOCATION(g_object_new(GAFLIGHT_TYPE_LOCATION, NULL));
auto flight_location = gaflight_location_get_raw(location);
if (garrow::check(error,
arrow::flight::Location::Parse(uri, flight_location),
arrow::flight::Location::Parse(uri).Value(flight_location),
"[flight-location][new]")) {
return location;
} else {
Expand Down Expand Up @@ -1018,10 +1018,10 @@ gaflight_info_get_schema(GAFlightInfo *info,
std::shared_ptr<arrow::Schema> arrow_schema;
if (options) {
auto arrow_memo = garrow_read_options_get_dictionary_memo_raw(options);
status = flight_info->GetSchema(arrow_memo, &arrow_schema);
status = flight_info->GetSchema(arrow_memo).Value(&arrow_schema);
} else {
arrow::ipc::DictionaryMemo arrow_memo;
status = flight_info->GetSchema(&arrow_memo, &arrow_schema);
status = flight_info->GetSchema(&arrow_memo).Value(&arrow_schema);
}
if (garrow::check(error, status, "[flight-info][get-schema]")) {
return garrow_schema_new_raw(&arrow_schema);
Expand Down Expand Up @@ -1287,7 +1287,7 @@ gaflight_record_batch_reader_read_next(GAFlightRecordBatchReader *reader,
{
auto flight_reader = gaflight_record_batch_reader_get_raw(reader);
arrow::flight::FlightStreamChunk flight_chunk;
auto status = flight_reader->Next(&flight_chunk);
auto status = flight_reader->Next().Value(&flight_chunk);
if (garrow::check(error, status, "[flight-record-batch-reader][read-next]")) {
if (flight_chunk.data) {
return gaflight_stream_chunk_new_raw(&flight_chunk);
Expand All @@ -1314,7 +1314,7 @@ gaflight_record_batch_reader_read_all(GAFlightRecordBatchReader *reader,
{
auto flight_reader = gaflight_record_batch_reader_get_raw(reader);
std::shared_ptr<arrow::Table> arrow_table;
auto status = flight_reader->ReadAll(&arrow_table);
auto status = flight_reader->ToTable().Value(&arrow_table);
if (garrow::check(error, status, "[flight-record-batch-reader][read-all]")) {
return garrow_table_new_raw(&arrow_table);
} else {
Expand Down
4 changes: 2 additions & 2 deletions c_glib/arrow-flight-glib/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,9 @@ namespace gaflight {
return stream->GetSchemaPayload(payload);
}

arrow::Status Next(arrow::flight::FlightPayload *payload) override {
arrow::Result<arrow::flight::FlightPayload> Next() override {
auto stream = gaflight_data_stream_get_raw(gastream_);
return stream->Next(payload);
return stream->Next();
}

private:
Expand Down
3 changes: 2 additions & 1 deletion cpp/examples/arrow/flight_grpc_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ int main(int argc, char** argv) {
server.reset(new SimpleFlightServer());

flight::Location bind_location;
ABORT_ON_FAILURE(flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &bind_location));
ABORT_ON_FAILURE(
flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port).Value(&bind_location));
flight::FlightServerOptions options(bind_location);

HelloWorldServiceImpl grpc_service;
Expand Down
7 changes: 3 additions & 4 deletions cpp/examples/arrow/flight_sql_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ DEFINE_int32(port, 31337, "The port of the Flight SQL server.");
DEFINE_string(query, "SELECT * FROM intTable WHERE value >= 0", "The query to execute.");

arrow::Status Main() {
flight::Location location;
ARROW_RETURN_NOT_OK(flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port, &location));
ARROW_ASSIGN_OR_RAISE(auto location,
flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port));
std::cout << "Connecting to " << location.ToString() << std::endl;

// Set up the Flight SQL client
Expand All @@ -66,8 +66,7 @@ arrow::Status Main() {
ARROW_ASSIGN_OR_RAISE(auto stream, client->DoGet(call_options, endpoint.ticket));
// Read all results into an Arrow Table, though we can iteratively process record
// batches as they arrive as well
std::shared_ptr<arrow::Table> table;
ARROW_RETURN_NOT_OK(stream->ReadAll(&table));
ARROW_ASSIGN_OR_RAISE(auto table, stream->ToTable());
std::cout << "Read one chunk:" << std::endl;
std::cout << table->ToString() << std::endl;
}
Expand Down
71 changes: 44 additions & 27 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,21 @@ std::shared_ptr<FlightWriteSizeStatusDetail> FlightWriteSizeStatusDetail::Unwrap

FlightClientOptions FlightClientOptions::Defaults() { return FlightClientOptions(); }

arrow::Result<std::shared_ptr<Table>> FlightStreamReader::ToTable(
const StopToken& stop_token) {
ARROW_ASSIGN_OR_RAISE(auto batches, ToRecordBatches(stop_token));
ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema());
return Table::FromRecordBatches(schema, std::move(batches));
}

Status FlightStreamReader::ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches,
const StopToken& stop_token) {
return ToRecordBatches(stop_token).Value(batches);
}

Status FlightStreamReader::ReadAll(std::shared_ptr<Table>* table,
const StopToken& stop_token) {
std::vector<std::shared_ptr<RecordBatch>> batches;
RETURN_NOT_OK(ReadAll(&batches, stop_token));
ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema());
return Table::FromRecordBatches(schema, std::move(batches)).Value(table);
return ToTable(stop_token).Value(table);
}

/// \brief An ipc::MessageReader adapting the Flight ClientDataStream interface.
Expand Down Expand Up @@ -169,57 +178,60 @@ class ClientStreamReader : public FlightStreamReader {
RETURN_NOT_OK(EnsureDataStarted());
return batch_reader_->schema();
}
Status Next(FlightStreamChunk* out) override {
arrow::Result<FlightStreamChunk> Next() override {
FlightStreamChunk out;
internal::FlightData* data;
peekable_reader_->Peek(&data);
if (!data) {
out->app_metadata = nullptr;
out->data = nullptr;
return stream_->Finish(Status::OK());
out.app_metadata = nullptr;
out.data = nullptr;
RETURN_NOT_OK(stream_->Finish(Status::OK()));
return out;
}

if (!data->metadata) {
// Metadata-only (data->metadata is the IPC header)
out->app_metadata = data->app_metadata;
out->data = nullptr;
out.app_metadata = data->app_metadata;
out.data = nullptr;
peekable_reader_->Next(&data);
return Status::OK();
return out;
}

if (!batch_reader_) {
RETURN_NOT_OK(EnsureDataStarted());
// Re-peek here since EnsureDataStarted() advances the stream
return Next(out);
return Next();
}
auto status = batch_reader_->ReadNext(&out->data);
auto status = batch_reader_->ReadNext(&out.data);
if (ARROW_PREDICT_FALSE(!status.ok())) {
return stream_->Finish(std::move(status));
}
out->app_metadata = std::move(app_metadata_);
return Status::OK();
out.app_metadata = std::move(app_metadata_);
return out;
}
Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches) override {
return ReadAll(batches, stop_token_);
arrow::Result<std::vector<std::shared_ptr<RecordBatch>>> ToRecordBatches() override {
return ToRecordBatches(stop_token_);
}
Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches,
const StopToken& stop_token) override {
arrow::Result<std::vector<std::shared_ptr<RecordBatch>>> ToRecordBatches(
const StopToken& stop_token) override {
std::vector<std::shared_ptr<RecordBatch>> batches;
FlightStreamChunk chunk;

while (true) {
if (stop_token.IsStopRequested()) {
Cancel();
return stop_token.Poll();
}
RETURN_NOT_OK(Next(&chunk));
ARROW_ASSIGN_OR_RAISE(chunk, Next());
if (!chunk.data) break;
batches->emplace_back(std::move(chunk.data));
batches.emplace_back(std::move(chunk.data));
}
return Status::OK();
return batches;
}
Status ReadAll(std::shared_ptr<Table>* table) override {
return ReadAll(table, stop_token_);
arrow::Result<std::shared_ptr<Table>> ToTable() override {
return ToTable(stop_token_);
}
using FlightStreamReader::ReadAll;
using FlightStreamReader::ToTable;
void Cancel() override { stream_->TryCancel(); }

private:
Expand Down Expand Up @@ -526,11 +538,16 @@ Status FlightClient::GetFlightInfo(const FlightCallOptions& options,
return transport_->GetFlightInfo(options, descriptor, info);
}

arrow::Result<std::unique_ptr<SchemaResult>> FlightClient::GetSchema(
const FlightCallOptions& options, const FlightDescriptor& descriptor) {
RETURN_NOT_OK(CheckOpen());
return transport_->GetSchema(options, descriptor);
}

Status FlightClient::GetSchema(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
std::unique_ptr<SchemaResult>* schema_result) {
RETURN_NOT_OK(CheckOpen());
return transport_->GetSchema(options, descriptor, schema_result);
return GetSchema(options, descriptor).Value(schema_result);
}

Status FlightClient::ListFlights(std::unique_ptr<FlightListing>* listing) {
Expand Down
32 changes: 26 additions & 6 deletions cpp/src/arrow/flight/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,22 @@ class ARROW_FLIGHT_EXPORT FlightStreamReader : public MetadataRecordBatchReader
public:
/// \brief Try to cancel the call.
virtual void Cancel() = 0;
using MetadataRecordBatchReader::ReadAll;

using MetadataRecordBatchReader::ToRecordBatches;
/// \brief Consume entire stream as a vector of record batches
virtual Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches,
const StopToken& stop_token) = 0;
virtual arrow::Result<std::vector<std::shared_ptr<RecordBatch>>> ToRecordBatches(
const StopToken& stop_token) = 0;

using MetadataRecordBatchReader::ReadAll;
ARROW_DEPRECATED("Deprecated in 8.0.0. Use ToRecordBatches instead.")
Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches,
const StopToken& stop_token);

using MetadataRecordBatchReader::ToTable;
/// \brief Consume entire stream as a Table
arrow::Result<std::shared_ptr<Table>> ToTable(const StopToken& stop_token);

ARROW_DEPRECATED("Deprecated in 8.0.0. Use ToTable instead.")
Status ReadAll(std::shared_ptr<Table>* table, const StopToken& stop_token);
};

Expand Down Expand Up @@ -253,13 +264,22 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// \param[in] options Per-RPC options
/// \param[in] descriptor the dataset request, whether a named dataset or
/// command
/// \param[out] schema_result the SchemaResult describing the dataset schema
/// \return Status
/// \return Arrow result with the SchemaResult describing the dataset schema
arrow::Result<std::unique_ptr<SchemaResult>> GetSchema(
const FlightCallOptions& options, const FlightDescriptor& descriptor);

ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status GetSchema(const FlightCallOptions& options, const FlightDescriptor& descriptor,
std::unique_ptr<SchemaResult>* schema_result);

arrow::Result<std::unique_ptr<SchemaResult>> GetSchema(
const FlightDescriptor& descriptor) {
return GetSchema({}, descriptor);
}
ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status GetSchema(const FlightDescriptor& descriptor,
std::unique_ptr<SchemaResult>* schema_result) {
return GetSchema({}, descriptor, schema_result);
return GetSchema({}, descriptor).Value(schema_result);
}

/// \brief List all available flights known to the server
Expand Down
18 changes: 10 additions & 8 deletions cpp/src/arrow/flight/flight_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ arrow::Result<PerformanceResult> RunDoGetTest(FlightClient* client,
StopWatch timer;
while (true) {
timer.Start();
RETURN_NOT_OK(reader->Next(&batch));
ARROW_ASSIGN_OR_RAISE(batch, reader->Next());
stats->AddLatency(timer.Stop());
if (!batch.data) {
break;
Expand Down Expand Up @@ -287,9 +287,8 @@ Status DoSinglePerfRun(FlightClient* client, const FlightClientOptions client_op
RETURN_NOT_OK(client->GetFlightInfo(call_options, descriptor, &plan));

// Read the streams in parallel
std::shared_ptr<Schema> schema;
ipc::DictionaryMemo dict_memo;
RETURN_NOT_OK(plan->GetSchema(&dict_memo, &schema));
ARROW_ASSIGN_OR_RAISE(auto schema, plan->GetSchema(&dict_memo));

int64_t start_total_records = stats->total_records;

Expand Down Expand Up @@ -457,7 +456,8 @@ int main(int argc, char** argv) {
std::cout << "Using standalone Unix server" << std::endl;
}
std::cout << "Server unix socket: " << FLAGS_server_unix << std::endl;
ABORT_NOT_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &location));
ABORT_NOT_OK(
arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix).Value(&location));
} else {
if (FLAGS_server_host == "") {
FLAGS_server_host = "localhost";
Expand Down Expand Up @@ -488,11 +488,13 @@ int main(int argc, char** argv) {
std::cout << "Server host: " << FLAGS_server_host << std::endl
<< "Server port: " << FLAGS_server_port << std::endl;
if (FLAGS_cert_file.empty()) {
ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host,
FLAGS_server_port, &location));
ABORT_NOT_OK(
arrow::flight::Location::ForGrpcTcp(FLAGS_server_host, FLAGS_server_port)
.Value(&location));
} else {
ABORT_NOT_OK(arrow::flight::Location::ForGrpcTls(FLAGS_server_host,
FLAGS_server_port, &location));
ABORT_NOT_OK(
arrow::flight::Location::ForGrpcTls(FLAGS_server_host, FLAGS_server_port)
.Value(&location));
options.disable_server_verification = true;
}
}
Expand Down
49 changes: 42 additions & 7 deletions cpp/src/arrow/flight/flight_internals_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,8 @@ TEST(FlightTypes, FlightDescriptorToFromProto) {
// ARROW-6017: we should be able to construct locations for unknown
// schemes
TEST(FlightTypes, LocationUnknownScheme) {
Location location;
ASSERT_OK(Location::Parse("s3://test", &location));
ASSERT_OK(Location::Parse("https://example.com/foo", &location));
ASSERT_OK(Location::Parse("s3://test"));
ASSERT_OK(Location::Parse("https://example.com/foo"));
}

TEST(FlightTypes, RoundTripTypes) {
Expand All @@ -105,10 +104,9 @@ TEST(FlightTypes, RoundTripTypes) {
std::shared_ptr<Schema> schema =
arrow::schema({field("a", int64()), field("b", int64()), field("c", int64()),
field("d", int64())});
Location location1, location2, location3;
ASSERT_OK(Location::ForGrpcTcp("localhost", 10010, &location1));
ASSERT_OK(Location::ForGrpcTls("localhost", 10010, &location2));
ASSERT_OK(Location::ForGrpcUnix("/tmp/test.sock", &location3));
ASSERT_OK_AND_ASSIGN(auto location1, Location::ForGrpcTcp("localhost", 10010));
ASSERT_OK_AND_ASSIGN(auto location2, Location::ForGrpcTls("localhost", 10010));
ASSERT_OK_AND_ASSIGN(auto location3, Location::ForGrpcUnix("/tmp/test.sock"));
std::vector<FlightEndpoint> endpoints{FlightEndpoint{ticket, {location1, location2}},
FlightEndpoint{ticket, {location3}}};
ASSERT_OK(MakeFlightInfo(*schema, desc, endpoints, -1, -1, &data));
Expand Down Expand Up @@ -177,6 +175,43 @@ TEST(FlightTypes, RoundtripStatus) {
ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
}

TEST(FlightTypes, LocationConstruction) {
ASSERT_RAISES(Invalid, Location::Parse("This is not an URI").status());
ASSERT_RAISES(Invalid, Location::ForGrpcTcp("This is not a hostname", 12345).status());
ASSERT_RAISES(Invalid, Location::ForGrpcTls("This is not a hostname", 12345).status());
ASSERT_RAISES(Invalid, Location::ForGrpcUnix("This is not a filename").status());

ASSERT_OK_AND_ASSIGN(auto location, Location::Parse("s3://test"));
ASSERT_EQ(location.ToString(), "s3://test");
ASSERT_OK_AND_ASSIGN(location, Location::ForGrpcTcp("localhost", 12345));
ASSERT_EQ(location.ToString(), "grpc+tcp://localhost:12345");
ASSERT_OK_AND_ASSIGN(location, Location::ForGrpcTls("localhost", 12345));
ASSERT_EQ(location.ToString(), "grpc+tls://localhost:12345");
ASSERT_OK_AND_ASSIGN(location, Location::ForGrpcUnix("/tmp/test.sock"));
ASSERT_EQ(location.ToString(), "grpc+unix:///tmp/test.sock");
}

ARROW_SUPPRESS_DEPRECATION_WARNING
TEST(FlightTypes, DeprecatedLocationConstruction) {
Location location;
ASSERT_RAISES(Invalid, Location::Parse("This is not an URI", &location));
ASSERT_RAISES(Invalid,
Location::ForGrpcTcp("This is not a hostname", 12345, &location));
ASSERT_RAISES(Invalid,
Location::ForGrpcTls("This is not a hostname", 12345, &location));
ASSERT_RAISES(Invalid, Location::ForGrpcUnix("This is not a filename", &location));

ASSERT_OK(Location::Parse("s3://test", &location));
ASSERT_EQ(location.ToString(), "s3://test");
ASSERT_OK(Location::ForGrpcTcp("localhost", 12345, &location));
ASSERT_EQ(location.ToString(), "grpc+tcp://localhost:12345");
ASSERT_OK(Location::ForGrpcTls("localhost", 12345, &location));
ASSERT_EQ(location.ToString(), "grpc+tls://localhost:12345");
ASSERT_OK(Location::ForGrpcUnix("/tmp/test.sock", &location));
ASSERT_EQ(location.ToString(), "grpc+unix:///tmp/test.sock");
}
ARROW_UNSUPPRESS_DEPRECATION_WARNING

// ----------------------------------------------------------------------
// Cookie authentication/middleware

Expand Down
Loading

0 comments on commit d214455

Please sign in to comment.