diff --git a/c_glib/arrow-flight-glib/client.cpp b/c_glib/arrow-flight-glib/client.cpp index c0be5b8c0ff5d..b4de6468c6514 100644 --- a/c_glib/arrow-flight-glib/client.cpp +++ b/c_glib/arrow-flight-glib/client.cpp @@ -326,7 +326,7 @@ gaflight_client_list_flights(GAFlightClient *client, GList *listing = NULL; std::unique_ptr flight_info; while (true) { - status = flight_listing->Next(&flight_info); + status = flight_listing->Next().Value(&flight_info); if (!garrow::check(error, status, "[flight-client][list-flights]")) { diff --git a/c_glib/arrow-flight-glib/common.cpp b/c_glib/arrow-flight-glib/common.cpp index 81b00f7a36919..0365096af310a 100644 --- a/c_glib/arrow-flight-glib/common.cpp +++ b/c_glib/arrow-flight-glib/common.cpp @@ -245,7 +245,7 @@ gaflight_location_new(const gchar *uri, auto location = GAFLIGHT_LOCATION(g_object_new(GAFLIGHT_TYPE_LOCATION, NULL)); auto flight_location = gaflight_location_get_raw(location); if (garrow::check(error, - arrow::flight::Location::Parse(uri, flight_location), + arrow::flight::Location::Parse(uri).Value(flight_location), "[flight-location][new]")) { return location; } else { @@ -1018,10 +1018,10 @@ gaflight_info_get_schema(GAFlightInfo *info, std::shared_ptr arrow_schema; if (options) { auto arrow_memo = garrow_read_options_get_dictionary_memo_raw(options); - status = flight_info->GetSchema(arrow_memo, &arrow_schema); + status = flight_info->GetSchema(arrow_memo).Value(&arrow_schema); } else { arrow::ipc::DictionaryMemo arrow_memo; - status = flight_info->GetSchema(&arrow_memo, &arrow_schema); + status = flight_info->GetSchema(&arrow_memo).Value(&arrow_schema); } if (garrow::check(error, status, "[flight-info][get-schema]")) { return garrow_schema_new_raw(&arrow_schema); @@ -1287,7 +1287,7 @@ gaflight_record_batch_reader_read_next(GAFlightRecordBatchReader *reader, { auto flight_reader = gaflight_record_batch_reader_get_raw(reader); arrow::flight::FlightStreamChunk flight_chunk; - auto status = flight_reader->Next(&flight_chunk); + auto status = flight_reader->Next().Value(&flight_chunk); if (garrow::check(error, status, "[flight-record-batch-reader][read-next]")) { if (flight_chunk.data) { return gaflight_stream_chunk_new_raw(&flight_chunk); @@ -1314,7 +1314,7 @@ gaflight_record_batch_reader_read_all(GAFlightRecordBatchReader *reader, { auto flight_reader = gaflight_record_batch_reader_get_raw(reader); std::shared_ptr arrow_table; - auto status = flight_reader->ReadAll(&arrow_table); + auto status = flight_reader->ToTable().Value(&arrow_table); if (garrow::check(error, status, "[flight-record-batch-reader][read-all]")) { return garrow_table_new_raw(&arrow_table); } else { diff --git a/c_glib/arrow-flight-glib/server.cpp b/c_glib/arrow-flight-glib/server.cpp index eb05284c14e05..c9820821a4299 100644 --- a/c_glib/arrow-flight-glib/server.cpp +++ b/c_glib/arrow-flight-glib/server.cpp @@ -456,9 +456,9 @@ namespace gaflight { return stream->GetSchemaPayload(payload); } - arrow::Status Next(arrow::flight::FlightPayload *payload) override { + arrow::Result Next() override { auto stream = gaflight_data_stream_get_raw(gastream_); - return stream->Next(payload); + return stream->Next(); } private: diff --git a/cpp/examples/arrow/flight_grpc_example.cc b/cpp/examples/arrow/flight_grpc_example.cc index db9cc177a5f74..596f3f71f9eb2 100644 --- a/cpp/examples/arrow/flight_grpc_example.cc +++ b/cpp/examples/arrow/flight_grpc_example.cc @@ -97,7 +97,8 @@ int main(int argc, char** argv) { server.reset(new SimpleFlightServer()); flight::Location bind_location; - ABORT_ON_FAILURE(flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &bind_location)); + ABORT_ON_FAILURE( + flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port).Value(&bind_location)); flight::FlightServerOptions options(bind_location); HelloWorldServiceImpl grpc_service; diff --git a/cpp/examples/arrow/flight_sql_example.cc b/cpp/examples/arrow/flight_sql_example.cc index f52336a44b0e2..5dfd97dbf1c81 100644 --- a/cpp/examples/arrow/flight_sql_example.cc +++ b/cpp/examples/arrow/flight_sql_example.cc @@ -38,8 +38,8 @@ DEFINE_int32(port, 31337, "The port of the Flight SQL server."); DEFINE_string(query, "SELECT * FROM intTable WHERE value >= 0", "The query to execute."); arrow::Status Main() { - flight::Location location; - ARROW_RETURN_NOT_OK(flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port, &location)); + ARROW_ASSIGN_OR_RAISE(auto location, + flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port)); std::cout << "Connecting to " << location.ToString() << std::endl; // Set up the Flight SQL client @@ -66,8 +66,7 @@ arrow::Status Main() { ARROW_ASSIGN_OR_RAISE(auto stream, client->DoGet(call_options, endpoint.ticket)); // Read all results into an Arrow Table, though we can iteratively process record // batches as they arrive as well - std::shared_ptr table; - ARROW_RETURN_NOT_OK(stream->ReadAll(&table)); + ARROW_ASSIGN_OR_RAISE(auto table, stream->ToTable()); std::cout << "Read one chunk:" << std::endl; std::cout << table->ToString() << std::endl; } diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 3aabe37ebcf7b..160387b1663a8 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -72,12 +72,21 @@ std::shared_ptr FlightWriteSizeStatusDetail::Unwrap FlightClientOptions FlightClientOptions::Defaults() { return FlightClientOptions(); } +arrow::Result> FlightStreamReader::ToTable( + const StopToken& stop_token) { + ARROW_ASSIGN_OR_RAISE(auto batches, ToRecordBatches(stop_token)); + ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema()); + return Table::FromRecordBatches(schema, std::move(batches)); +} + +Status FlightStreamReader::ReadAll(std::vector>* batches, + const StopToken& stop_token) { + return ToRecordBatches(stop_token).Value(batches); +} + Status FlightStreamReader::ReadAll(std::shared_ptr* table, const StopToken& stop_token) { - std::vector> batches; - RETURN_NOT_OK(ReadAll(&batches, stop_token)); - ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema()); - return Table::FromRecordBatches(schema, std::move(batches)).Value(table); + return ToTable(stop_token).Value(table); } /// \brief An ipc::MessageReader adapting the Flight ClientDataStream interface. @@ -169,40 +178,43 @@ class ClientStreamReader : public FlightStreamReader { RETURN_NOT_OK(EnsureDataStarted()); return batch_reader_->schema(); } - Status Next(FlightStreamChunk* out) override { + arrow::Result Next() override { + FlightStreamChunk out; internal::FlightData* data; peekable_reader_->Peek(&data); if (!data) { - out->app_metadata = nullptr; - out->data = nullptr; - return stream_->Finish(Status::OK()); + out.app_metadata = nullptr; + out.data = nullptr; + RETURN_NOT_OK(stream_->Finish(Status::OK())); + return out; } if (!data->metadata) { // Metadata-only (data->metadata is the IPC header) - out->app_metadata = data->app_metadata; - out->data = nullptr; + out.app_metadata = data->app_metadata; + out.data = nullptr; peekable_reader_->Next(&data); - return Status::OK(); + return out; } if (!batch_reader_) { RETURN_NOT_OK(EnsureDataStarted()); // Re-peek here since EnsureDataStarted() advances the stream - return Next(out); + return Next(); } - auto status = batch_reader_->ReadNext(&out->data); + auto status = batch_reader_->ReadNext(&out.data); if (ARROW_PREDICT_FALSE(!status.ok())) { return stream_->Finish(std::move(status)); } - out->app_metadata = std::move(app_metadata_); - return Status::OK(); + out.app_metadata = std::move(app_metadata_); + return out; } - Status ReadAll(std::vector>* batches) override { - return ReadAll(batches, stop_token_); + arrow::Result>> ToRecordBatches() override { + return ToRecordBatches(stop_token_); } - Status ReadAll(std::vector>* batches, - const StopToken& stop_token) override { + arrow::Result>> ToRecordBatches( + const StopToken& stop_token) override { + std::vector> batches; FlightStreamChunk chunk; while (true) { @@ -210,16 +222,16 @@ class ClientStreamReader : public FlightStreamReader { Cancel(); return stop_token.Poll(); } - RETURN_NOT_OK(Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(chunk, Next()); if (!chunk.data) break; - batches->emplace_back(std::move(chunk.data)); + batches.emplace_back(std::move(chunk.data)); } - return Status::OK(); + return batches; } - Status ReadAll(std::shared_ptr
* table) override { - return ReadAll(table, stop_token_); + arrow::Result> ToTable() override { + return ToTable(stop_token_); } - using FlightStreamReader::ReadAll; + using FlightStreamReader::ToTable; void Cancel() override { stream_->TryCancel(); } private: @@ -526,11 +538,16 @@ Status FlightClient::GetFlightInfo(const FlightCallOptions& options, return transport_->GetFlightInfo(options, descriptor, info); } +arrow::Result> FlightClient::GetSchema( + const FlightCallOptions& options, const FlightDescriptor& descriptor) { + RETURN_NOT_OK(CheckOpen()); + return transport_->GetSchema(options, descriptor); +} + Status FlightClient::GetSchema(const FlightCallOptions& options, const FlightDescriptor& descriptor, std::unique_ptr* schema_result) { - RETURN_NOT_OK(CheckOpen()); - return transport_->GetSchema(options, descriptor, schema_result); + return GetSchema(options, descriptor).Value(schema_result); } Status FlightClient::ListFlights(std::unique_ptr* listing) { diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index c5ed60a6c42ac..06d87bb9aebb3 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -133,11 +133,22 @@ class ARROW_FLIGHT_EXPORT FlightStreamReader : public MetadataRecordBatchReader public: /// \brief Try to cancel the call. virtual void Cancel() = 0; - using MetadataRecordBatchReader::ReadAll; + + using MetadataRecordBatchReader::ToRecordBatches; /// \brief Consume entire stream as a vector of record batches - virtual Status ReadAll(std::vector>* batches, - const StopToken& stop_token) = 0; + virtual arrow::Result>> ToRecordBatches( + const StopToken& stop_token) = 0; + + using MetadataRecordBatchReader::ReadAll; + ARROW_DEPRECATED("Deprecated in 8.0.0. Use ToRecordBatches instead.") + Status ReadAll(std::vector>* batches, + const StopToken& stop_token); + + using MetadataRecordBatchReader::ToTable; /// \brief Consume entire stream as a Table + arrow::Result> ToTable(const StopToken& stop_token); + + ARROW_DEPRECATED("Deprecated in 8.0.0. Use ToTable instead.") Status ReadAll(std::shared_ptr
* table, const StopToken& stop_token); }; @@ -253,13 +264,22 @@ class ARROW_FLIGHT_EXPORT FlightClient { /// \param[in] options Per-RPC options /// \param[in] descriptor the dataset request, whether a named dataset or /// command - /// \param[out] schema_result the SchemaResult describing the dataset schema - /// \return Status + /// \return Arrow result with the SchemaResult describing the dataset schema + arrow::Result> GetSchema( + const FlightCallOptions& options, const FlightDescriptor& descriptor); + + ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") Status GetSchema(const FlightCallOptions& options, const FlightDescriptor& descriptor, std::unique_ptr* schema_result); + + arrow::Result> GetSchema( + const FlightDescriptor& descriptor) { + return GetSchema({}, descriptor); + } + ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") Status GetSchema(const FlightDescriptor& descriptor, std::unique_ptr* schema_result) { - return GetSchema({}, descriptor, schema_result); + return GetSchema({}, descriptor).Value(schema_result); } /// \brief List all available flights known to the server diff --git a/cpp/src/arrow/flight/flight_benchmark.cc b/cpp/src/arrow/flight/flight_benchmark.cc index 6649de52cd910..2d011f78730d7 100644 --- a/cpp/src/arrow/flight/flight_benchmark.cc +++ b/cpp/src/arrow/flight/flight_benchmark.cc @@ -154,7 +154,7 @@ arrow::Result RunDoGetTest(FlightClient* client, StopWatch timer; while (true) { timer.Start(); - RETURN_NOT_OK(reader->Next(&batch)); + ARROW_ASSIGN_OR_RAISE(batch, reader->Next()); stats->AddLatency(timer.Stop()); if (!batch.data) { break; @@ -287,9 +287,8 @@ Status DoSinglePerfRun(FlightClient* client, const FlightClientOptions client_op RETURN_NOT_OK(client->GetFlightInfo(call_options, descriptor, &plan)); // Read the streams in parallel - std::shared_ptr schema; ipc::DictionaryMemo dict_memo; - RETURN_NOT_OK(plan->GetSchema(&dict_memo, &schema)); + ARROW_ASSIGN_OR_RAISE(auto schema, plan->GetSchema(&dict_memo)); int64_t start_total_records = stats->total_records; @@ -451,7 +450,8 @@ int main(int argc, char** argv) { std::cout << "Using standalone Unix server" << std::endl; } std::cout << "Server unix socket: " << FLAGS_server_unix << std::endl; - ABORT_NOT_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &location)); + ABORT_NOT_OK( + arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix).Value(&location)); } else { if (FLAGS_server_host == "") { FLAGS_server_host = "localhost"; @@ -482,11 +482,13 @@ int main(int argc, char** argv) { std::cout << "Server host: " << FLAGS_server_host << std::endl << "Server port: " << FLAGS_server_port << std::endl; if (FLAGS_cert_file.empty()) { - ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host, - FLAGS_server_port, &location)); + ABORT_NOT_OK( + arrow::flight::Location::ForGrpcTcp(FLAGS_server_host, FLAGS_server_port) + .Value(&location)); } else { - ABORT_NOT_OK(arrow::flight::Location::ForGrpcTls(FLAGS_server_host, - FLAGS_server_port, &location)); + ABORT_NOT_OK( + arrow::flight::Location::ForGrpcTls(FLAGS_server_host, FLAGS_server_port) + .Value(&location)); options.disable_server_verification = true; } } diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc index e4babc03352e6..f7b731f01cbf5 100644 --- a/cpp/src/arrow/flight/flight_internals_test.cc +++ b/cpp/src/arrow/flight/flight_internals_test.cc @@ -78,9 +78,8 @@ TEST(FlightTypes, FlightDescriptorToFromProto) { // ARROW-6017: we should be able to construct locations for unknown // schemes TEST(FlightTypes, LocationUnknownScheme) { - Location location; - ASSERT_OK(Location::Parse("s3://test", &location)); - ASSERT_OK(Location::Parse("https://example.com/foo", &location)); + ASSERT_OK(Location::Parse("s3://test")); + ASSERT_OK(Location::Parse("https://example.com/foo")); } TEST(FlightTypes, RoundTripTypes) { @@ -105,10 +104,9 @@ TEST(FlightTypes, RoundTripTypes) { std::shared_ptr schema = arrow::schema({field("a", int64()), field("b", int64()), field("c", int64()), field("d", int64())}); - Location location1, location2, location3; - ASSERT_OK(Location::ForGrpcTcp("localhost", 10010, &location1)); - ASSERT_OK(Location::ForGrpcTls("localhost", 10010, &location2)); - ASSERT_OK(Location::ForGrpcUnix("/tmp/test.sock", &location3)); + ASSERT_OK_AND_ASSIGN(auto location1, Location::ForGrpcTcp("localhost", 10010)); + ASSERT_OK_AND_ASSIGN(auto location2, Location::ForGrpcTls("localhost", 10010)); + ASSERT_OK_AND_ASSIGN(auto location3, Location::ForGrpcUnix("/tmp/test.sock")); std::vector endpoints{FlightEndpoint{ticket, {location1, location2}}, FlightEndpoint{ticket, {location3}}}; ASSERT_OK(MakeFlightInfo(*schema, desc, endpoints, -1, -1, &data)); @@ -177,6 +175,43 @@ TEST(FlightTypes, RoundtripStatus) { ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); } +TEST(FlightTypes, LocationConstruction) { + ASSERT_RAISES(Invalid, Location::Parse("This is not an URI").status()); + ASSERT_RAISES(Invalid, Location::ForGrpcTcp("This is not a hostname", 12345).status()); + ASSERT_RAISES(Invalid, Location::ForGrpcTls("This is not a hostname", 12345).status()); + ASSERT_RAISES(Invalid, Location::ForGrpcUnix("This is not a filename").status()); + + ASSERT_OK_AND_ASSIGN(auto location, Location::Parse("s3://test")); + ASSERT_EQ(location.ToString(), "s3://test"); + ASSERT_OK_AND_ASSIGN(location, Location::ForGrpcTcp("localhost", 12345)); + ASSERT_EQ(location.ToString(), "grpc+tcp://localhost:12345"); + ASSERT_OK_AND_ASSIGN(location, Location::ForGrpcTls("localhost", 12345)); + ASSERT_EQ(location.ToString(), "grpc+tls://localhost:12345"); + ASSERT_OK_AND_ASSIGN(location, Location::ForGrpcUnix("/tmp/test.sock")); + ASSERT_EQ(location.ToString(), "grpc+unix:///tmp/test.sock"); +} + +ARROW_SUPPRESS_DEPRECATION_WARNING +TEST(FlightTypes, DeprecatedLocationConstruction) { + Location location; + ASSERT_RAISES(Invalid, Location::Parse("This is not an URI", &location)); + ASSERT_RAISES(Invalid, + Location::ForGrpcTcp("This is not a hostname", 12345, &location)); + ASSERT_RAISES(Invalid, + Location::ForGrpcTls("This is not a hostname", 12345, &location)); + ASSERT_RAISES(Invalid, Location::ForGrpcUnix("This is not a filename", &location)); + + ASSERT_OK(Location::Parse("s3://test", &location)); + ASSERT_EQ(location.ToString(), "s3://test"); + ASSERT_OK(Location::ForGrpcTcp("localhost", 12345, &location)); + ASSERT_EQ(location.ToString(), "grpc+tcp://localhost:12345"); + ASSERT_OK(Location::ForGrpcTls("localhost", 12345, &location)); + ASSERT_EQ(location.ToString(), "grpc+tls://localhost:12345"); + ASSERT_OK(Location::ForGrpcUnix("/tmp/test.sock", &location)); + ASSERT_EQ(location.ToString(), "grpc+unix:///tmp/test.sock"); +} +ARROW_UNSUPPRESS_DEPRECATION_WARNING + // ---------------------------------------------------------------------- // Cookie authentication/middleware diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index f4207a34f154d..9bd0a8f49a4ea 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -122,10 +122,8 @@ TEST(TestFlight, ConnectUri) { std::string uri = ss.str(); std::unique_ptr client; - Location location1; - Location location2; - ASSERT_OK(Location::Parse(uri, &location1)); - ASSERT_OK(Location::Parse(uri, &location2)); + ASSERT_OK_AND_ASSIGN(auto location1, Location::Parse(uri)); + ASSERT_OK_AND_ASSIGN(auto location2, Location::Parse(uri)); ASSERT_OK(FlightClient::Connect(location1, &client)); ASSERT_OK(client->Close()); ASSERT_OK(FlightClient::Connect(location2, &client)); @@ -143,10 +141,8 @@ TEST(TestFlight, ConnectUriUnix) { std::string uri = ss.str(); std::unique_ptr client; - Location location1; - Location location2; - ASSERT_OK(Location::Parse(uri, &location1)); - ASSERT_OK(Location::Parse(uri, &location2)); + ASSERT_OK_AND_ASSIGN(auto location1, Location::Parse(uri)); + ASSERT_OK_AND_ASSIGN(auto location2, Location::Parse(uri)); ASSERT_OK(FlightClient::Connect(location1, &client)); ASSERT_OK(client->Close()); ASSERT_OK(FlightClient::Connect(location2, &client)); @@ -156,15 +152,14 @@ TEST(TestFlight, ConnectUriUnix) { // CI environments don't have an IPv6 interface configured TEST(TestFlight, DISABLED_IpV6Port) { - Location location, location2; std::unique_ptr server = ExampleTestServer(); - ASSERT_OK(Location::ForGrpcTcp("[::1]", 0, &location)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("[::1]", 0)); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); ASSERT_GT(server->port(), 0); - ASSERT_OK(Location::ForGrpcTcp("[::1]", server->port(), &location2)); + ASSERT_OK_AND_ASSIGN(auto location2, Location::ForGrpcTcp("[::1]", server->port())); std::unique_ptr client; ASSERT_OK(FlightClient::Connect(location2, &client)); std::unique_ptr listing; @@ -179,8 +174,7 @@ class TestFlightClient : public ::testing::Test { void SetUp() { server_ = ExampleTestServer(); - Location location; - ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0)); FlightServerOptions options(location); ASSERT_OK(server_->Init(options)); @@ -193,8 +187,8 @@ class TestFlightClient : public ::testing::Test { } Status ConnectClient() { - Location location; - RETURN_NOT_OK(Location::ForGrpcTcp("localhost", server_->port(), &location)); + ARROW_ASSIGN_OR_RAISE(auto location, + Location::ForGrpcTcp("localhost", server_->port())); return FlightClient::Connect(location, &client_); } @@ -208,9 +202,8 @@ class TestFlightClient : public ::testing::Test { ASSERT_OK(client_->GetFlightInfo(descr, &info)); check_endpoints(info->endpoints()); - std::shared_ptr schema; ipc::DictionaryMemo dict_memo; - ASSERT_OK(info->GetSchema(&dict_memo, &schema)); + ASSERT_OK_AND_ASSIGN(auto schema, info->GetSchema(&dict_memo)); AssertSchemaEqual(*expected_schema, *schema); // By convention, fetch the first endpoint @@ -232,7 +225,7 @@ class TestFlightClient : public ::testing::Test { FlightStreamChunk chunk; std::shared_ptr batch; for (int i = 0; i < num_batches; ++i) { - ASSERT_OK(stream->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(chunk, stream->Next()); ASSERT_OK(reader->ReadNext(&batch)); ASSERT_NE(nullptr, chunk.data); ASSERT_NE(nullptr, batch); @@ -255,7 +248,7 @@ class TestFlightClient : public ::testing::Test { } // Stream exhausted - ASSERT_OK(stream->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(chunk, stream->Next()); ASSERT_OK(reader->ReadNext(&batch)); ASSERT_EQ(nullptr, chunk.data); ASSERT_EQ(nullptr, batch); @@ -353,14 +346,13 @@ class TestTls : public ::testing::Test { server_.reset(new TlsTestServer); - Location location; - ASSERT_OK(Location::ForGrpcTls("localhost", 0, &location)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTls("localhost", 0)); FlightServerOptions options(location); ASSERT_RAISES(UnknownError, server_->Init(options)); ASSERT_OK(ExampleTlsCertificates(&options.tls_certificates)); ASSERT_OK(server_->Init(options)); - ASSERT_OK(Location::ForGrpcTls("localhost", server_->port(), &location_)); + ASSERT_OK_AND_ASSIGN(location_, Location::ForGrpcTls("localhost", server_->port())); ASSERT_OK(ConnectClient()); } @@ -877,13 +869,13 @@ TEST_F(TestFlightClient, ListFlights) { std::unique_ptr info; for (const FlightInfo& flight : flights) { - ASSERT_OK(listing->Next(&info)); + ASSERT_OK_AND_ASSIGN(info, listing->Next()); AssertEqual(flight, *info); } - ASSERT_OK(listing->Next(&info)); + ASSERT_OK_AND_ASSIGN(info, listing->Next()); ASSERT_TRUE(info == nullptr); - ASSERT_OK(listing->Next(&info)); + ASSERT_OK_AND_ASSIGN(info, listing->Next()); ASSERT_TRUE(info == nullptr); } @@ -891,7 +883,7 @@ TEST_F(TestFlightClient, ListFlightsWithCriteria) { std::unique_ptr listing; ASSERT_OK(client_->ListFlights(FlightCallOptions(), {"foo"}, &listing)); std::unique_ptr info; - ASSERT_OK(listing->Next(&info)); + ASSERT_OK_AND_ASSIGN(info, listing->Next()); ASSERT_TRUE(info == nullptr); } @@ -908,13 +900,11 @@ TEST_F(TestFlightClient, GetFlightInfo) { TEST_F(TestFlightClient, GetSchema) { auto descr = FlightDescriptor::Path({"examples", "ints"}); - std::unique_ptr schema_result; - std::shared_ptr schema; ipc::DictionaryMemo dict_memo; - ASSERT_OK(client_->GetSchema(descr, &schema_result)); + ASSERT_OK_AND_ASSIGN(auto schema_result, client_->GetSchema(descr)); ASSERT_NE(schema_result, nullptr); - ASSERT_OK(schema_result->GetSchema(&dict_memo, &schema)); + ASSERT_OK(schema_result->GetSchema(&dict_memo)); } TEST_F(TestFlightClient, GetFlightInfoNotFound) { @@ -948,20 +938,20 @@ TEST_F(TestFlightClient, DoAction) { ASSERT_OK(client_->DoAction(action, &stream)); for (int i = 0; i < 3; ++i) { - ASSERT_OK(stream->Next(&result)); + ASSERT_OK_AND_ASSIGN(result, stream->Next()); std::string expected = action1_value + "-part" + std::to_string(i); ASSERT_EQ(expected, result->body->ToString()); } // stream consumed - ASSERT_OK(stream->Next(&result)); + ASSERT_OK_AND_ASSIGN(result, stream->Next()); ASSERT_EQ(nullptr, result); // Run action2, no results action.type = "action2"; ASSERT_OK(client_->DoAction(action, &stream)); - ASSERT_OK(stream->Next(&result)); + ASSERT_OK_AND_ASSIGN(result, stream->Next()); ASSERT_EQ(nullptr, result); } @@ -979,14 +969,12 @@ TEST_F(TestFlightClient, GenericOptions) { auto options = FlightClientOptions::Defaults(); // Set a very low limit at the gRPC layer to fail all calls options.generic_options.emplace_back(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, 4); - Location location; - ASSERT_OK(Location::ForGrpcTcp("localhost", server_->port(), &location)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", server_->port())); ASSERT_OK(FlightClient::Connect(location, options, &client)); auto descr = FlightDescriptor::Path({"examples", "ints"}); - std::unique_ptr schema_result; std::shared_ptr schema; ipc::DictionaryMemo dict_memo; - auto status = client->GetSchema(descr, &schema_result); + auto status = client->GetSchema(descr).status(); ASSERT_RAISES(Invalid, status); ASSERT_THAT(status.message(), ::testing::HasSubstr("resource exhausted")); } @@ -994,8 +982,7 @@ TEST_F(TestFlightClient, GenericOptions) { TEST_F(TestFlightClient, TimeoutFires) { // Server does not exist on this port, so call should fail std::unique_ptr client; - Location location; - ASSERT_OK(Location::ForGrpcTcp("localhost", 30001, &location)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 30001)); ASSERT_OK(FlightClient::Connect(location, &client)); FlightCallOptions options; options.timeout = TimeoutDuration{0.2}; @@ -1136,12 +1123,12 @@ TEST_F(TestAuthHandler, CheckPeerIdentity) { ASSERT_NE(results, nullptr); std::unique_ptr result; - ASSERT_OK(results->Next(&result)); + ASSERT_OK_AND_ASSIGN(result, results->Next()); ASSERT_NE(result, nullptr); // Action returns the peer identity as the result. ASSERT_EQ(result->body->ToString(), "user"); - ASSERT_OK(results->Next(&result)); + ASSERT_OK_AND_ASSIGN(result, results->Next()); ASSERT_NE(result, nullptr); // Action returns the peer address as the result. #ifndef _WIN32 @@ -1245,7 +1232,7 @@ TEST_F(TestBasicAuthHandler, CheckPeerIdentity) { ASSERT_NE(results, nullptr); std::unique_ptr result; - ASSERT_OK(results->Next(&result)); + ASSERT_OK_AND_ASSIGN(result, results->Next()); ASSERT_NE(result, nullptr); // Action returns the peer identity as the result. ASSERT_EQ(result->body->ToString(), "user"); @@ -1262,7 +1249,7 @@ TEST_F(TestTls, DoAction) { ASSERT_NE(results, nullptr); std::unique_ptr result; - ASSERT_OK(results->Next(&result)); + ASSERT_OK_AND_ASSIGN(result, results->Next()); ASSERT_NE(result, nullptr); ASSERT_EQ(result->body->ToString(), "Hello, world!"); } @@ -1287,7 +1274,7 @@ TEST_F(TestTls, DisableServerVerification) { ASSERT_NE(results, nullptr); std::unique_ptr result; - ASSERT_OK(results->Next(&result)); + ASSERT_OK_AND_ASSIGN(result, results->Next()); ASSERT_NE(result, nullptr); ASSERT_EQ(result->body->ToString(), "Hello, world!"); } @@ -1352,8 +1339,7 @@ TEST_F(TestCountingServerMiddleware, Count) { ASSERT_EQ(1, request_counter_->failed_); while (true) { - FlightStreamChunk chunk; - ASSERT_OK(stream->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(FlightStreamChunk chunk, stream->Next()); if (chunk.data == nullptr) { break; } @@ -1375,7 +1361,7 @@ TEST_F(TestPropagatingMiddleware, Propagate) { action.body = Buffer::FromString("action1-content"); ASSERT_OK(client_->DoAction(action, &stream)); - ASSERT_OK(stream->Next(&result)); + ASSERT_OK_AND_ASSIGN(result, stream->Next()); ASSERT_EQ("trace-id", result->body->ToString()); ValidateStatus(Status::OK(), FlightMethod::DoAction); } @@ -1403,8 +1389,7 @@ TEST_F(TestPropagatingMiddleware, GetFlightInfo) { TEST_F(TestPropagatingMiddleware, GetSchema) { client_middleware_->Reset(); auto descr = FlightDescriptor::Path({"examples", "ints"}); - std::unique_ptr result; - const Status status = client_->GetSchema(descr, &result); + const Status status = client_->GetSchema(descr).status(); ASSERT_RAISES(NotImplemented, status); ValidateStatus(status, FlightMethod::GetSchema); } @@ -1445,19 +1430,18 @@ TEST_F(TestBasicHeaderAuthMiddleware, ValidCredentials) { RunValidClientAuth(); TEST_F(TestBasicHeaderAuthMiddleware, InvalidCredentials) { RunInvalidClientAuth(); } class ForeverFlightListing : public FlightListing { - Status Next(std::unique_ptr* info) override { + arrow::Result> Next() override { std::this_thread::sleep_for(std::chrono::milliseconds(100)); - *info = arrow::internal::make_unique(ExampleFlightInfo()[0]); - return Status::OK(); + return arrow::internal::make_unique(ExampleFlightInfo()[0]); } }; class ForeverResultStream : public ResultStream { - Status Next(std::unique_ptr* result) override { + arrow::Result> Next() override { std::this_thread::sleep_for(std::chrono::milliseconds(100)); - *result = arrow::internal::make_unique(); - (*result)->body = Buffer::FromString("foo"); - return Status::OK(); + auto result = arrow::internal::make_unique(); + result->body = Buffer::FromString("foo"); + return result; } }; @@ -1471,10 +1455,12 @@ class ForeverDataStream : public FlightDataStream { &payload->ipc_message); } - Status Next(FlightPayload* payload) override { + arrow::Result Next() override { auto batch = RecordBatch::Make(schema_, 0, ArrayVector{}); - return ipc::GetRecordBatchPayload(*batch, ipc::IpcWriteOptions::Defaults(), - &payload->ipc_message); + FlightPayload payload; + RETURN_NOT_OK(ipc::GetRecordBatchPayload(*batch, ipc::IpcWriteOptions::Defaults(), + &payload.ipc_message)); + return payload; } private: @@ -1561,13 +1547,12 @@ TEST_F(TestCancel, DoGet) { stop_source.RequestStop(Status::Cancelled("StopSource")); std::unique_ptr stream; ASSERT_OK(client_->DoGet(options, {}, &stream)); - std::shared_ptr
table; EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"), - stream->ReadAll(&table)); + stream->ToTable()); ASSERT_OK(client_->DoGet({}, &stream)); EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"), - stream->ReadAll(&table, options.stop_token)); + stream->ToTable(options.stop_token)); } TEST_F(TestCancel, DoExchange) { @@ -1580,14 +1565,13 @@ TEST_F(TestCancel, DoExchange) { std::unique_ptr stream; ASSERT_OK( client_->DoExchange(options, FlightDescriptor::Command(""), &writer, &stream)); - std::shared_ptr
table; EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"), - stream->ReadAll(&table)); + stream->ToTable()); ARROW_UNUSED(writer->Close()); ASSERT_OK(client_->DoExchange(FlightDescriptor::Command(""), &writer, &stream)); EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"), - stream->ReadAll(&table, options.stop_token)); + stream->ToTable(options.stop_token)); ARROW_UNUSED(writer->Close()); } diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index 1e08f47b579bd..48788198f0f3e 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -53,7 +53,7 @@ Status CheckActionResults(FlightClient* client, const Action& action, RETURN_NOT_OK(client->DoAction(action, &stream)); std::unique_ptr result; for (const std::string& expected : results) { - RETURN_NOT_OK(stream->Next(&result)); + ARROW_ASSIGN_OR_RAISE(result, stream->Next()); if (!result) { return Status::Invalid("Action result stream ended early"); } @@ -62,7 +62,7 @@ Status CheckActionResults(FlightClient* client, const Action& action, return Status::Invalid("Got wrong result; expected", expected, "but got", actual); } } - RETURN_NOT_OK(stream->Next(&result)); + ARROW_ASSIGN_OR_RAISE(result, stream->Next()); if (result) { return Status::Invalid("Action result stream had too many entries"); } @@ -196,9 +196,8 @@ class MiddlewareServer : public FlightServerBase { descriptor.cmd == "success") { // Don't fail std::shared_ptr schema = arrow::schema({}); - Location location; // Return a fake location - the test doesn't read it - RETURN_NOT_OK(Location::ForGrpcTcp("localhost", 10010, &location)); + ARROW_ASSIGN_OR_RAISE(auto location, Location::ForGrpcTcp("localhost", 10010)); std::vector endpoints{FlightEndpoint{{"foo"}, {location}}}; ARROW_ASSIGN_OR_RAISE(auto info, FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)); diff --git a/cpp/src/arrow/flight/integration_tests/test_integration_client.cc b/cpp/src/arrow/flight/integration_tests/test_integration_client.cc index 366284389f104..971ea2b8f3596 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration_client.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration_client.cc @@ -102,8 +102,7 @@ Status ConsumeFlightLocation( int counter = 0; const int expected = static_cast(retrieved_data.size()); for (const auto& original_batch : retrieved_data) { - FlightStreamChunk chunk; - RETURN_NOT_OK(stream->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, stream->Next()); if (chunk.data == nullptr) { return Status::Invalid("Got fewer batches than expected, received so far: ", counter, " expected ", expected); @@ -125,8 +124,7 @@ Status ConsumeFlightLocation( counter++; } - FlightStreamChunk chunk; - RETURN_NOT_OK(stream->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, stream->Next()); if (chunk.data != nullptr) { return Status::Invalid("Got more batches than the expected ", expected); } @@ -177,7 +175,7 @@ class IntegrationTestScenario : public Scenario { std::shared_ptr schema; ipc::DictionaryMemo dict_memo; - ABORT_NOT_OK(info->GetSchema(&dict_memo, &schema)); + ABORT_NOT_OK(info->GetSchema(&dict_memo).Value(&schema)); if (info->endpoints().size() == 0) { std::cerr << "No endpoints returned from Flight server." << std::endl; @@ -213,8 +211,8 @@ arrow::Status RunScenario(arrow::flight::integration_tests::Scenario* scenario) std::unique_ptr client; RETURN_NOT_OK(scenario->MakeClient(&options)); - arrow::flight::Location location; - RETURN_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port, &location)); + ARROW_ASSIGN_OR_RAISE(auto location, + arrow::flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port)); RETURN_NOT_OK(arrow::flight::FlightClient::Connect(location, options, &client)); return scenario->RunClient(std::move(client)); } diff --git a/cpp/src/arrow/flight/integration_tests/test_integration_server.cc b/cpp/src/arrow/flight/integration_tests/test_integration_server.cc index dad76c6914d05..9127f55fe1424 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration_server.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration_server.cc @@ -88,8 +88,8 @@ class FlightIntegrationTestServer : public FlightServerBase { } auto flight = data->second; - Location server_location; - RETURN_NOT_OK(Location::ForGrpcTcp("127.0.0.1", port(), &server_location)); + ARROW_ASSIGN_OR_RAISE(auto server_location, + Location::ForGrpcTcp("127.0.0.1", port())); FlightEndpoint endpoint1({{request.path[0]}, {server_location}}); FlightInfo::Data flight_data; @@ -142,7 +142,7 @@ class FlightIntegrationTestServer : public FlightServerBase { ARROW_ASSIGN_OR_RAISE(dataset.schema, reader->GetSchema()); arrow::flight::FlightStreamChunk chunk; while (true) { - RETURN_NOT_OK(reader->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); if (chunk.data == nullptr) break; RETURN_NOT_OK(chunk.data->ValidateFull()); dataset.chunks.push_back(chunk.data); @@ -196,7 +196,8 @@ int main(int argc, char** argv) { std::make_shared(); } arrow::flight::Location location; - ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location)); + ARROW_CHECK_OK( + arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port).Value(&location)); arrow::flight::FlightServerOptions options(location); ARROW_CHECK_OK(scenario->MakeServer(&g_server, &options)); diff --git a/cpp/src/arrow/flight/perf_server.cc b/cpp/src/arrow/flight/perf_server.cc index ae2b2a485cbeb..9b25afbbca11d 100644 --- a/cpp/src/arrow/flight/perf_server.cc +++ b/cpp/src/arrow/flight/perf_server.cc @@ -89,10 +89,11 @@ class PerfDataStream : public FlightDataStream { return ipc::GetSchemaPayload(*schema_, ipc_options_, mapper_, &payload->ipc_message); } - Status Next(FlightPayload* payload) override { + arrow::Result Next() override { + FlightPayload payload; if (records_sent_ >= total_records_) { // Signal that iteration is over - payload->ipc_message.metadata = nullptr; + payload.ipc_message.metadata = nullptr; return Status::OK(); } @@ -114,7 +115,8 @@ class PerfDataStream : public FlightDataStream { } else { records_sent_ += batch_length_; } - return ipc::GetRecordBatchPayload(*batch, ipc_options_, &payload->ipc_message); + RETURN_NOT_OK(ipc::GetRecordBatchPayload(*batch, ipc_options_, &payload.ipc_message)); + return payload; } private: @@ -202,7 +204,7 @@ class FlightPerfServer : public FlightServerBase { std::unique_ptr writer) override { FlightStreamChunk chunk; while (true) { - RETURN_NOT_OK(reader->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); if (!chunk.data) break; if (chunk.app_metadata) { RETURN_NOT_OK(writer->WriteMetadata(*chunk.app_metadata)); @@ -248,25 +250,26 @@ int main(int argc, char** argv) { if (FLAGS_server_unix.empty()) { if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) { if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) { + ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTls("0.0.0.0", FLAGS_port) + .Value(&bind_location)); ARROW_CHECK_OK( - arrow::flight::Location::ForGrpcTls("0.0.0.0", FLAGS_port, &bind_location)); - ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTls( - FLAGS_server_host, FLAGS_port, &connect_location)); + arrow::flight::Location::ForGrpcTls(FLAGS_server_host, FLAGS_port) + .Value(&connect_location)); } else { std::cerr << "If providing TLS cert/key, must provide both" << std::endl; return EXIT_FAILURE; } } else { - ARROW_CHECK_OK( - arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &bind_location)); - ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host, FLAGS_port, - &connect_location)); + ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port) + .Value(&bind_location)); + ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host, FLAGS_port) + .Value(&connect_location)); } } else { ARROW_CHECK_OK( - arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &bind_location)); - ARROW_CHECK_OK( - arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &connect_location)); + arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix).Value(&bind_location)); + ARROW_CHECK_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix) + .Value(&connect_location)); } } else { std::cerr << "Unknown transport: " << FLAGS_transport << std::endl; diff --git a/cpp/src/arrow/flight/serialization_internal.cc b/cpp/src/arrow/flight/serialization_internal.cc index bbffc643466e5..fa21a934bd199 100644 --- a/cpp/src/arrow/flight/serialization_internal.cc +++ b/cpp/src/arrow/flight/serialization_internal.cc @@ -89,7 +89,7 @@ Status ToProto(const Criteria& criteria, pb::Criteria* pb_criteria) { // Location Status FromProto(const pb::Location& pb_location, Location* location) { - return Location::Parse(pb_location.uri(), location); + return Location::Parse(pb_location.uri()).Value(location); } Status ToProto(const Location& location, pb::Location* pb_location) { diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index dff8d075610f4..9907b08950faa 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -409,7 +409,13 @@ Status RecordBatchStream::GetSchemaPayload(FlightPayload* payload) { return impl_->GetSchemaPayload(payload); } -Status RecordBatchStream::Next(FlightPayload* payload) { return impl_->Next(payload); } +arrow::Result RecordBatchStream::Next() { + FlightPayload payload; + RETURN_NOT_OK(impl_->Next(&payload)); + return payload; +} + +Status RecordBatchStream::Next(FlightPayload* payload) { return Next().Value(payload); } } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index df17f2cc197d0..f0d0f726709b0 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -55,7 +55,10 @@ class ARROW_FLIGHT_EXPORT FlightDataStream { // When the stream is completed, the last payload written will have null // metadata - virtual Status Next(FlightPayload* payload) = 0; + virtual arrow::Result Next() = 0; + + ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") + Status Next(FlightPayload* payload) { return Next().Value(payload); } }; /// \brief A basic implementation of FlightDataStream that will provide @@ -71,7 +74,11 @@ class ARROW_FLIGHT_EXPORT RecordBatchStream : public FlightDataStream { std::shared_ptr schema() override; Status GetSchemaPayload(FlightPayload* payload) override; - Status Next(FlightPayload* payload) override; + + arrow::Result Next() override; + + ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") + Status Next(FlightPayload* payload); private: class RecordBatchStreamImpl; diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc index 199eebb66b3bf..93cb57b902885 100644 --- a/cpp/src/arrow/flight/sql/client.cc +++ b/cpp/src/arrow/flight/sql/client.cc @@ -263,8 +263,7 @@ arrow::Result> FlightSqlClient::Prepare( ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); - std::unique_ptr result; - ARROW_RETURN_NOT_OK(results->Next(&result)); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr result, results->Next()); google::protobuf::Any prepared_result; diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.cc b/cpp/src/arrow/flight/sql/example/sqlite_server.cc index dde364f64e3a4..561157a7158f2 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_server.cc +++ b/cpp/src/arrow/flight/sql/example/sqlite_server.cc @@ -88,9 +88,8 @@ std::string PrepareQueryForGetTables(const GetTables& command) { } Status SetParametersOnSQLiteStatement(sqlite3_stmt* stmt, FlightMessageReader* reader) { - FlightStreamChunk chunk; while (true) { - RETURN_NOT_OK(reader->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, reader->Next()); std::shared_ptr& record_batch = chunk.data; if (record_batch == nullptr) break; diff --git a/cpp/src/arrow/flight/sql/server_test.cc b/cpp/src/arrow/flight/sql/server_test.cc index 47a03bd71d010..461adb7d7572f 100644 --- a/cpp/src/arrow/flight/sql/server_test.cc +++ b/cpp/src/arrow/flight/sql/server_test.cc @@ -138,8 +138,7 @@ class TestFlightSqlServer : public ::testing::Test { ARROW_ASSIGN_OR_RAISE(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ARROW_RETURN_NOT_OK(stream->ReadAll(&table)); + ARROW_ASSIGN_OR_RAISE(auto table, stream->ToTable()); const std::shared_ptr& result_array = table->column(0)->chunk(0); ARROW_ASSIGN_OR_RAISE(auto count_scalar, result_array->GetScalar(0)); @@ -160,8 +159,7 @@ class TestFlightSqlServer : public ::testing::Test { std::string uri = ss.str(); std::unique_ptr client; - Location location; - ASSERT_OK(Location::Parse(uri, &location)); + ASSERT_OK_AND_ASSIGN(auto location, Location::Parse(uri)); ASSERT_OK(FlightClient::Connect(location, &client)); sql_client.reset(new FlightSqlClient(std::move(client))); @@ -184,8 +182,7 @@ class TestFlightSqlServer : public ::testing::Test { std::mutex server_ready_m; void RunServer() { - arrow::flight::Location location; - ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("localhost", port, &location)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", port)); arrow::flight::FlightServerOptions options(location); ARROW_CHECK_OK(example::SQLiteFlightSqlServer::Create().Value(&server)); @@ -206,8 +203,7 @@ TEST_F(TestFlightSqlServer, TestCommandStatementQuery) { ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); const std::shared_ptr& expected_schema = arrow::schema({arrow::field("id", int64()), arrow::field("keyName", utf8()), @@ -241,8 +237,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetTables) { ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); ASSERT_OK_AND_ASSIGN(auto catalog_name, MakeArrayOfNull(utf8(), 3)) ASSERT_OK_AND_ASSIGN(auto schema_name, MakeArrayOfNull(utf8(), 3)) @@ -273,8 +268,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetTablesWithTableFilter) { ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); const auto catalog_name = ArrayFromJSON(utf8(), R"([null])"); const auto schema_name = ArrayFromJSON(utf8(), R"([null])"); @@ -303,8 +297,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetTablesWithTableTypesFilter) { ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); AssertSchemaEqual(SqlSchema::GetTablesSchema(), table->schema()); @@ -327,8 +320,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetTablesWithUnexistenceTableTypeFilter) ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); const auto catalog_name = ArrayFromJSON(utf8(), R"([null, null, null])"); const auto schema_name = ArrayFromJSON(utf8(), R"([null, null, null])"); @@ -358,8 +350,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetTablesWithIncludedSchemas) { ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); const auto catalog_name = ArrayFromJSON(utf8(), R"([null])"); const auto schema_name = ArrayFromJSON(utf8(), R"([null])"); @@ -388,8 +379,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetCatalogs) { ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); const std::shared_ptr& expected_schema = SqlSchema::GetCatalogsSchema(); @@ -407,8 +397,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetDbSchemas) { ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); const std::shared_ptr& expected_schema = SqlSchema::GetDbSchemasSchema(); @@ -422,8 +411,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetTableTypes) { ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); const auto table_type = ArrayFromJSON(utf8(), R"(["table"])"); @@ -462,8 +450,7 @@ TEST_F(TestFlightSqlServer, TestCommandPreparedStatementQuery) { ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); const std::shared_ptr& expected_schema = arrow::schema({arrow::field("id", int64()), arrow::field("keyName", utf8()), @@ -516,8 +503,7 @@ TEST_F(TestFlightSqlServer, TestCommandPreparedStatementQueryWithParameterBindin ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); const std::shared_ptr& expected_schema = arrow::schema({arrow::field("id", int64()), arrow::field("keyName", utf8()), @@ -603,8 +589,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetPrimaryKeys) { ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); const auto catalog_name = ArrayFromJSON(utf8(), R"([null])"); const auto schema_name = ArrayFromJSON(utf8(), R"([null])"); @@ -628,8 +613,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetImportedKeys) { ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); const auto pk_catalog_name = ArrayFromJSON(utf8(), R"([null])"); const auto pk_schema_name = ArrayFromJSON(utf8(), R"([null])"); @@ -661,8 +645,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetExportedKeys) { ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); const auto pk_catalog_name = ArrayFromJSON(utf8(), R"([null])"); const auto pk_schema_name = ArrayFromJSON(utf8(), R"([null])"); @@ -696,8 +679,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetCrossReference) { ASSERT_OK_AND_ASSIGN(auto stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); const auto pk_catalog_name = ArrayFromJSON(utf8(), R"([null])"); const auto pk_schema_name = ArrayFromJSON(utf8(), R"([null])"); @@ -734,8 +716,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetSqlInfo) { sql_client->GetSqlInfo(call_options, sql_info_ids)); ASSERT_OK_AND_ASSIGN( auto reader, sql_client->DoGet(call_options, flight_info->endpoints()[0].ticket)); - std::shared_ptr
results; - ASSERT_OK(reader->ReadAll(&results)); + ASSERT_OK_AND_ASSIGN(auto results, reader->ToTable()); ASSERT_EQ(2, results->num_columns()); ASSERT_EQ(sql_info_ids.size(), results->num_rows()); const auto& col_name = results->column(0); diff --git a/cpp/src/arrow/flight/sql/test_app_cli.cc b/cpp/src/arrow/flight/sql/test_app_cli.cc index 43c37bee2fe86..63924cc1c91d4 100644 --- a/cpp/src/arrow/flight/sql/test_app_cli.cc +++ b/cpp/src/arrow/flight/sql/test_app_cli.cc @@ -71,11 +71,10 @@ Status PrintResultsForEndpoint(FlightSqlClient& client, std::cout << "Results:" << std::endl; - FlightStreamChunk chunk; int64_t num_rows = 0; while (true) { - ARROW_RETURN_NOT_OK(stream->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, stream->Next()); if (chunk.data == nullptr) { break; } @@ -103,8 +102,7 @@ Status PrintResults(FlightSqlClient& client, const FlightCallOptions& call_optio Status RunMain() { std::unique_ptr client; - Location location; - ARROW_RETURN_NOT_OK(Location::ForGrpcTcp(FLAGS_host, FLAGS_port, &location)); + ARROW_ASSIGN_OR_RAISE(auto location, Location::ForGrpcTcp(FLAGS_host, FLAGS_port)); ARROW_RETURN_NOT_OK(FlightClient::Connect(location, &client)); FlightCallOptions call_options; diff --git a/cpp/src/arrow/flight/sql/test_server_cli.cc b/cpp/src/arrow/flight/sql/test_server_cli.cc index e0ba5340e8d94..e847b1137aecd 100644 --- a/cpp/src/arrow/flight/sql/test_server_cli.cc +++ b/cpp/src/arrow/flight/sql/test_server_cli.cc @@ -31,8 +31,8 @@ DEFINE_int32(port, 31337, "Server port to listen on"); arrow::Status RunMain() { - arrow::flight::Location location; - ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location)); + ARROW_ASSIGN_OR_RAISE(auto location, + arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port)); arrow::flight::FlightServerOptions options(location); std::shared_ptr server; diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc index 41bd092f230e2..5ead99f94ba16 100644 --- a/cpp/src/arrow/flight/test_definitions.cc +++ b/cpp/src/arrow/flight/test_definitions.cc @@ -142,9 +142,8 @@ void DataTest::CheckDoGet( ASSERT_OK(client_->GetFlightInfo(descr, &info)); check_endpoints(info->endpoints()); - std::shared_ptr schema; ipc::DictionaryMemo dict_memo; - ASSERT_OK(info->GetSchema(&dict_memo, &schema)); + ASSERT_OK_AND_ASSIGN(auto schema, info->GetSchema(&dict_memo)); AssertSchemaEqual(*expected_schema, *schema); // By convention, fetch the first endpoint @@ -163,10 +162,9 @@ void DataTest::CheckDoGet(const Ticket& ticket, ASSERT_OK(client_->DoGet(ticket, &stream2)); ASSERT_OK_AND_ASSIGN(auto reader, MakeRecordBatchReader(std::move(stream2))); - FlightStreamChunk chunk; std::shared_ptr batch; for (int i = 0; i < num_batches; ++i) { - ASSERT_OK(stream->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, stream->Next()); ASSERT_OK(reader->ReadNext(&batch)); ASSERT_NE(nullptr, chunk.data); ASSERT_NE(nullptr, batch); @@ -189,7 +187,7 @@ void DataTest::CheckDoGet(const Ticket& ticket, } // Stream exhausted - ASSERT_OK(stream->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, stream->Next()); ASSERT_OK(reader->ReadNext(&batch)); ASSERT_EQ(nullptr, chunk.data); ASSERT_EQ(nullptr, batch); @@ -253,7 +251,7 @@ void DataTest::TestOverflowServerBatch() { FlightStreamChunk chunk; EXPECT_RAISES_WITH_MESSAGE_THAT( Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"), - stream->Next(&chunk)); + stream->Next()); } { // DoExchange: check for overflow on large batch from server @@ -264,7 +262,7 @@ void DataTest::TestOverflowServerBatch() { RecordBatchVector batches; EXPECT_RAISES_WITH_MESSAGE_THAT( Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"), - reader->ReadAll(&batches)); + reader->ToRecordBatches().Value(&batches)); ARROW_UNUSED(writer->Close()); } } @@ -308,15 +306,14 @@ void DataTest::TestDoExchange() { ASSERT_OK(writer->WriteRecordBatch(*batch)); } ASSERT_OK(writer->DoneWriting()); - FlightStreamChunk chunk; - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next()); ASSERT_NE(nullptr, chunk.app_metadata); ASSERT_EQ(nullptr, chunk.data); ASSERT_EQ("1", chunk.app_metadata->ToString()); ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema()); AssertSchemaEqual(schema, server_schema); for (const auto& batch : batches) { - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(chunk, reader->Next()); ASSERT_BATCHES_EQUAL(*batch, *chunk.data); } ASSERT_OK(writer->Close()); @@ -329,8 +326,7 @@ void DataTest::TestDoExchangeNoData() { std::unique_ptr writer; ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); ASSERT_OK(writer->DoneWriting()); - FlightStreamChunk chunk; - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next()); ASSERT_EQ(nullptr, chunk.data); ASSERT_NE(nullptr, chunk.app_metadata); ASSERT_EQ("0", chunk.app_metadata->ToString()); @@ -347,8 +343,7 @@ void DataTest::TestDoExchangeWriteOnlySchema() { ASSERT_OK(writer->Begin(schema)); ASSERT_OK(writer->WriteMetadata(Buffer::FromString("foo"))); ASSERT_OK(writer->DoneWriting()); - FlightStreamChunk chunk; - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next()); ASSERT_EQ(nullptr, chunk.data); ASSERT_NE(nullptr, chunk.app_metadata); ASSERT_EQ("0", chunk.app_metadata->ToString()); @@ -364,13 +359,12 @@ void DataTest::TestDoExchangeGet() { AssertSchemaEqual(*ExampleIntSchema(), *server_schema); RecordBatchVector batches; ASSERT_OK(ExampleIntBatches(&batches)); - FlightStreamChunk chunk; for (const auto& batch : batches) { - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next()); ASSERT_NE(nullptr, chunk.data); AssertBatchesEqual(*batch, *chunk.data); } - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next()); ASSERT_EQ(nullptr, chunk.data); ASSERT_EQ(nullptr, chunk.app_metadata); ASSERT_OK(writer->Close()); @@ -388,11 +382,10 @@ void DataTest::TestDoExchangePut() { ASSERT_OK(writer->WriteRecordBatch(*batch)); } ASSERT_OK(writer->DoneWriting()); - FlightStreamChunk chunk; - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next()); ASSERT_NE(nullptr, chunk.app_metadata); AssertBufferEqual(*chunk.app_metadata, "done"); - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(chunk, reader->Next()); ASSERT_EQ(nullptr, chunk.data); ASSERT_EQ(nullptr, chunk.app_metadata); ASSERT_OK(writer->Close()); @@ -405,11 +398,10 @@ void DataTest::TestDoExchangeEcho() { ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); ASSERT_OK(writer->Begin(ExampleIntSchema())); RecordBatchVector batches; - FlightStreamChunk chunk; ASSERT_OK(ExampleIntBatches(&batches)); for (const auto& batch : batches) { ASSERT_OK(writer->WriteRecordBatch(*batch)); - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next()); ASSERT_NE(nullptr, chunk.data); ASSERT_EQ(nullptr, chunk.app_metadata); AssertBatchesEqual(*batch, *chunk.data); @@ -417,7 +409,7 @@ void DataTest::TestDoExchangeEcho() { for (int i = 0; i < 10; i++) { const auto buf = Buffer::FromString(std::to_string(i)); ASSERT_OK(writer->WriteMetadata(buf)); - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next()); ASSERT_EQ(nullptr, chunk.data); ASSERT_NE(nullptr, chunk.app_metadata); AssertBufferEqual(*buf, *chunk.app_metadata); @@ -426,7 +418,7 @@ void DataTest::TestDoExchangeEcho() { for (const auto& batch : batches) { const auto buf = Buffer::FromString(std::to_string(index)); ASSERT_OK(writer->WriteWithMetadata(*batch, buf)); - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next()); ASSERT_NE(nullptr, chunk.data); ASSERT_NE(nullptr, chunk.app_metadata); AssertBatchesEqual(*batch, *chunk.data); @@ -434,7 +426,7 @@ void DataTest::TestDoExchangeEcho() { index++; } ASSERT_OK(writer->DoneWriting()); - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next()); ASSERT_EQ(nullptr, chunk.data); ASSERT_EQ(nullptr, chunk.app_metadata); ASSERT_OK(writer->Close()); @@ -467,12 +459,11 @@ void DataTest::TestDoExchangeTotal() { ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); ASSERT_OK(writer->Begin(schema)); auto batch = RecordBatch::Make(schema, /* num_rows */ 4, {a1, a2}); - FlightStreamChunk chunk; ASSERT_OK(writer->WriteRecordBatch(*batch)); ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema()); AssertSchemaEqual(*schema, *server_schema); - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next()); ASSERT_NE(nullptr, chunk.data); auto expected1 = RecordBatch::Make( schema, /* num_rows */ 1, @@ -480,7 +471,7 @@ void DataTest::TestDoExchangeTotal() { AssertBatchesEqual(*expected1, *chunk.data); ASSERT_OK(writer->WriteRecordBatch(*batch)); - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(chunk, reader->Next()); ASSERT_NE(nullptr, chunk.data); auto expected2 = RecordBatch::Make( schema, /* num_rows */ 1, @@ -503,9 +494,8 @@ void DataTest::TestDoExchangeError() { } { ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); - FlightStreamChunk chunk; EXPECT_RAISES_WITH_MESSAGE_THAT( - NotImplemented, ::testing::HasSubstr("Expected error"), reader->Next(&chunk)); + NotImplemented, ::testing::HasSubstr("Expected error"), reader->Next()); ARROW_UNUSED(writer->Close()); } { @@ -532,14 +522,13 @@ void DataTest::TestDoExchangeConcurrency() { ASSERT_OK(writer->Begin(ExampleIntSchema())); std::thread reader_thread([&reader, &batches]() { - FlightStreamChunk chunk; for (size_t i = 0; i < batches.size(); i++) { - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next()); ASSERT_NE(nullptr, chunk.data); ASSERT_EQ(nullptr, chunk.app_metadata); AssertBatchesEqual(*batches[i], *chunk.data); } - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next()); ASSERT_EQ(nullptr, chunk.data); ASSERT_EQ(nullptr, chunk.app_metadata); }); @@ -611,7 +600,7 @@ class DoPutTestServer : public FlightServerBase { int counter = 0; FlightStreamChunk chunk; while (true) { - RETURN_NOT_OK(reader->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); if (!chunk.data) break; if (counter % 2 == 1) { if (!chunk.app_metadata) { @@ -860,7 +849,7 @@ Status AppMetadataTestServer::DoPut(const ServerCallContext& context, FlightStreamChunk chunk; int counter = 0; while (true) { - RETURN_NOT_OK(reader->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); if (chunk.data == nullptr) break; if (chunk.app_metadata == nullptr) { return Status::Invalid("Expected application metadata to be provided"); @@ -895,16 +884,15 @@ void AppMetadataTest::TestDoGet() { RecordBatchVector expected_batches; ASSERT_OK(ExampleIntBatches(&expected_batches)); - FlightStreamChunk chunk; auto num_batches = static_cast(expected_batches.size()); for (int i = 0; i < num_batches; ++i) { - ASSERT_OK(stream->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, stream->Next()); ASSERT_NE(nullptr, chunk.data); ASSERT_NE(nullptr, chunk.app_metadata); ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data); ASSERT_EQ(std::to_string(i), chunk.app_metadata->ToString()); } - ASSERT_OK(stream->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, stream->Next()); ASSERT_EQ(nullptr, chunk.data); } // Test dictionaries. This tests a corner case in the reader: @@ -919,16 +907,15 @@ void AppMetadataTest::TestDoGetDictionaries() { RecordBatchVector expected_batches; ASSERT_OK(ExampleDictBatches(&expected_batches)); - FlightStreamChunk chunk; auto num_batches = static_cast(expected_batches.size()); for (int i = 0; i < num_batches; ++i) { - ASSERT_OK(stream->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, stream->Next()); ASSERT_NE(nullptr, chunk.data); ASSERT_NE(nullptr, chunk.app_metadata); ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data); ASSERT_EQ(std::to_string(i), chunk.app_metadata->ToString()); } - ASSERT_OK(stream->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, stream->Next()); ASSERT_EQ(nullptr, chunk.data); } void AppMetadataTest::TestDoPut() { @@ -1019,10 +1006,9 @@ class IpcOptionsTestServer : public FlightServerBase { Status DoPut(const ServerCallContext& context, std::unique_ptr reader, std::unique_ptr writer) override { - FlightStreamChunk chunk; int counter = 0; while (true) { - RETURN_NOT_OK(reader->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, reader->Next()); if (chunk.data == nullptr) break; counter++; } @@ -1035,12 +1021,11 @@ class IpcOptionsTestServer : public FlightServerBase { Status DoExchange(const ServerCallContext& context, std::unique_ptr reader, std::unique_ptr writer) override { - FlightStreamChunk chunk; auto options = ipc::IpcWriteOptions::Defaults(); options.max_recursion_depth = 1; bool begun = false; while (true) { - RETURN_NOT_OK(reader->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, reader->Next()); if (!chunk.data && !chunk.app_metadata) { break; } @@ -1078,8 +1063,7 @@ void IpcOptionsTest::TestDoGetReadOptions() { options.read_options.max_recursion_depth = 1; std::unique_ptr stream; ASSERT_OK(client_->DoGet(options, ticket, &stream)); - FlightStreamChunk chunk; - ASSERT_RAISES(Invalid, stream->Next(&chunk)); + ASSERT_RAISES(Invalid, stream->Next()); } void IpcOptionsTest::TestDoPutWriteOptions() { // Call DoPut, but with a very low write nesting depth set to fail the call. @@ -1211,7 +1195,7 @@ class CudaTestServer : public FlightServerBase { Status DoPut(const ServerCallContext&, std::unique_ptr reader, std::unique_ptr writer) override { - RETURN_NOT_OK(reader->ReadAll(&batches_)); + RETURN_NOT_OK(reader->ToRecordBatches().Value(&batches_)); return Status::OK(); } @@ -1221,7 +1205,7 @@ class CudaTestServer : public FlightServerBase { FlightStreamChunk chunk; bool begun = false; while (true) { - RETURN_NOT_OK(reader->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); if (!chunk.data) break; if (!begun) { begun = true; @@ -1290,8 +1274,7 @@ void CudaDataTest::TestDoGet() { size_t idx = 0; while (true) { - FlightStreamChunk chunk; - ASSERT_OK(stream->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, stream->Next()); if (!chunk.data) break; ASSERT_OK(CheckBuffersOnDevice(*chunk.data, *impl_->device)); @@ -1366,8 +1349,7 @@ void CudaDataTest::TestDoExchange() { ASSERT_OK(CheckBuffersOnDevice(*cuda_batch, *impl_->device)); ASSERT_OK(writer->WriteRecordBatch(*cuda_batch)); - FlightStreamChunk chunk; - ASSERT_OK(reader->Next(&chunk)); + ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next()); ASSERT_OK(CheckBuffersOnDevice(*chunk.data, *impl_->device)); // Bounce record batch back to host memory diff --git a/cpp/src/arrow/flight/test_server.cc b/cpp/src/arrow/flight/test_server.cc index 2e5b10f840388..18bf2b4135990 100644 --- a/cpp/src/arrow/flight/test_server.cc +++ b/cpp/src/arrow/flight/test_server.cc @@ -42,9 +42,9 @@ int main(int argc, char** argv) { arrow::flight::Location location; if (FLAGS_unix.empty()) { - ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location)); + location = *arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port); } else { - ARROW_CHECK_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_unix, &location)); + location = *arrow::flight::Location::ForGrpcUnix(FLAGS_unix); } arrow::flight::FlightServerOptions options(location); diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index 88bbf5977dafd..f5f1edf5f8fd4 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -234,8 +234,7 @@ class FlightTestServer : public FlightServerBase { Status DoPut(const ServerCallContext&, std::unique_ptr reader, std::unique_ptr writer) override { - RecordBatchVector batches; - return reader->ReadAll(&batches); + return reader->ToRecordBatches().status(); } Status DoExchange(const ServerCallContext& context, @@ -293,7 +292,7 @@ class FlightTestServer : public FlightServerBase { RETURN_NOT_OK(ExampleIntBatches(&batches)); FlightStreamChunk chunk; for (const auto& batch : batches) { - RETURN_NOT_OK(reader->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); if (!chunk.data) { return Status::Invalid("Expected another batch"); } @@ -301,7 +300,7 @@ class FlightTestServer : public FlightServerBase { return Status::Invalid("Batch does not match"); } } - RETURN_NOT_OK(reader->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); if (chunk.data || chunk.app_metadata) { return Status::Invalid("Too many batches"); } @@ -318,7 +317,7 @@ class FlightTestServer : public FlightServerBase { FlightStreamChunk chunk; int chunks = 0; while (true) { - RETURN_NOT_OK(reader->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); if (!chunk.data && !chunk.app_metadata) { break; } @@ -360,7 +359,7 @@ class FlightTestServer : public FlightServerBase { std::vector> columns(schema->num_fields()); RETURN_NOT_OK(writer->Begin(schema)); while (true) { - RETURN_NOT_OK(reader->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); if (!chunk.data && !chunk.app_metadata) { break; } @@ -405,7 +404,7 @@ class FlightTestServer : public FlightServerBase { FlightStreamChunk chunk; bool begun = false; while (true) { - RETURN_NOT_OK(reader->Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); if (!chunk.data && !chunk.app_metadata) { break; } @@ -506,13 +505,13 @@ Status NumberingStream::GetSchemaPayload(FlightPayload* payload) { return stream_->GetSchemaPayload(payload); } -Status NumberingStream::Next(FlightPayload* payload) { - RETURN_NOT_OK(stream_->Next(payload)); - if (payload && payload->ipc_message.type == ipc::MessageType::RECORD_BATCH) { - payload->app_metadata = Buffer::FromString(std::to_string(counter_)); +arrow::Result NumberingStream::Next() { + ARROW_ASSIGN_OR_RAISE(FlightPayload payload, stream_->Next()); + if (payload.ipc_message.type == ipc::MessageType::RECORD_BATCH) { + payload.app_metadata = Buffer::FromString(std::to_string(counter_)); counter_++; } - return Status::OK(); + return payload; } std::shared_ptr ExampleIntSchema() { @@ -556,16 +555,11 @@ std::shared_ptr ExampleLargeSchema() { } std::vector ExampleFlightInfo() { - Location location1; - Location location2; - Location location3; - Location location4; - Location location5; - ARROW_EXPECT_OK(Location::ForGrpcTcp("foo1.bar.com", 12345, &location1)); - ARROW_EXPECT_OK(Location::ForGrpcTcp("foo2.bar.com", 12345, &location2)); - ARROW_EXPECT_OK(Location::ForGrpcTcp("foo3.bar.com", 12345, &location3)); - ARROW_EXPECT_OK(Location::ForGrpcTcp("foo4.bar.com", 12345, &location4)); - ARROW_EXPECT_OK(Location::ForGrpcTcp("foo5.bar.com", 12345, &location5)); + Location location1 = *Location::ForGrpcTcp("foo1.bar.com", 12345); + Location location2 = *Location::ForGrpcTcp("foo2.bar.com", 12345); + Location location3 = *Location::ForGrpcTcp("foo3.bar.com", 12345); + Location location4 = *Location::ForGrpcTcp("foo4.bar.com", 12345); + Location location5 = *Location::ForGrpcTcp("foo5.bar.com", 12345); FlightInfo::Data flight1, flight2, flight3, flight4; diff --git a/cpp/src/arrow/flight/test_util.h b/cpp/src/arrow/flight/test_util.h index 385eb58fa1630..0b2f9c9e30738 100644 --- a/cpp/src/arrow/flight/test_util.h +++ b/cpp/src/arrow/flight/test_util.h @@ -54,11 +54,10 @@ namespace flight { // Helpers to compare values for equality inline void AssertEqual(const FlightInfo& expected, const FlightInfo& actual) { - std::shared_ptr ex_schema, actual_schema; ipc::DictionaryMemo expected_memo; ipc::DictionaryMemo actual_memo; - ASSERT_OK(expected.GetSchema(&expected_memo, &ex_schema)); - ASSERT_OK(actual.GetSchema(&actual_memo, &actual_schema)); + ASSERT_OK_AND_ASSIGN(auto ex_schema, expected.GetSchema(&expected_memo)); + ASSERT_OK_AND_ASSIGN(auto actual_schema, actual.GetSchema(&actual_memo)); AssertSchemaEqual(*ex_schema, *actual_schema); ASSERT_EQ(expected.total_records(), actual.total_records()); @@ -113,10 +112,9 @@ Status MakeServer(const Location& location, std::unique_ptr* s FlightServerOptions server_options(location); RETURN_NOT_OK(make_server_options(&server_options)); RETURN_NOT_OK((*server)->Init(server_options)); - Location real_location; std::string uri = location.scheme() + "://localhost:" + std::to_string((*server)->port()); - RETURN_NOT_OK(Location::Parse(uri, &real_location)); + ARROW_ASSIGN_OR_RAISE(auto real_location, Location::Parse(uri)); FlightClientOptions client_options = FlightClientOptions::Defaults(); RETURN_NOT_OK(make_client_options(&client_options)); return FlightClient::Connect(real_location, client_options, client); @@ -130,8 +128,7 @@ Status MakeServer(std::unique_ptr* server, std::function make_server_options, std::function make_client_options, Args&&... server_args) { - Location location; - RETURN_NOT_OK(Location::ForGrpcTcp("localhost", 0, &location)); + ARROW_ASSIGN_OR_RAISE(auto location, Location::ForGrpcTcp("localhost", 0)); return MakeServer(location, server, client, std::move(make_server_options), std::move(make_client_options), std::forward(server_args)...); @@ -147,7 +144,7 @@ class ARROW_FLIGHT_EXPORT NumberingStream : public FlightDataStream { std::shared_ptr schema() override; Status GetSchemaPayload(FlightPayload* payload) override; - Status Next(FlightPayload* payload) override; + arrow::Result Next() override; private: int counter_; diff --git a/cpp/src/arrow/flight/transport.cc b/cpp/src/arrow/flight/transport.cc index 7a2429d0e39ca..2ccdf82bd7644 100644 --- a/cpp/src/arrow/flight/transport.cc +++ b/cpp/src/arrow/flight/transport.cc @@ -21,6 +21,7 @@ #include "arrow/flight/client_auth.h" #include "arrow/flight/transport_server.h" +#include "arrow/flight/types.h" #include "arrow/ipc/message.h" #include "arrow/result.h" #include "arrow/status.h" @@ -72,9 +73,8 @@ Status ClientTransport::GetFlightInfo(const FlightCallOptions& options, std::unique_ptr* info) { return Status::NotImplemented("GetFlightInfo for this transport"); } -Status ClientTransport::GetSchema(const FlightCallOptions& options, - const FlightDescriptor& descriptor, - std::unique_ptr* schema_result) { +arrow::Result> ClientTransport::GetSchema( + const FlightCallOptions& options, const FlightDescriptor& descriptor) { return Status::NotImplemented("GetSchema for this transport"); } Status ClientTransport::ListFlights(const FlightCallOptions& options, diff --git a/cpp/src/arrow/flight/transport.h b/cpp/src/arrow/flight/transport.h index 085cfa99473f0..f02ab05157ac2 100644 --- a/cpp/src/arrow/flight/transport.h +++ b/cpp/src/arrow/flight/transport.h @@ -180,9 +180,8 @@ class ARROW_FLIGHT_EXPORT ClientTransport { virtual Status GetFlightInfo(const FlightCallOptions& options, const FlightDescriptor& descriptor, std::unique_ptr* info); - virtual Status GetSchema(const FlightCallOptions& options, - const FlightDescriptor& descriptor, - std::unique_ptr* schema_result); + virtual arrow::Result> GetSchema( + const FlightCallOptions& options, const FlightDescriptor& descriptor); virtual Status ListFlights(const FlightCallOptions& options, const Criteria& criteria, std::unique_ptr* listing); virtual Status DoGet(const FlightCallOptions& options, const Ticket& ticket, diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc index d651e805a0c4e..1c0ac2d31faf5 100644 --- a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc +++ b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc @@ -43,6 +43,7 @@ #include "arrow/status.h" #include "arrow/util/base64.h" #include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" #include "arrow/util/uri.h" #include "arrow/flight/client.h" @@ -801,8 +802,8 @@ class GrpcClientImpl : public internal::ClientTransport { return Status::OK(); } - Status GetSchema(const FlightCallOptions& options, const FlightDescriptor& descriptor, - std::unique_ptr* schema_result) override { + arrow::Result> GetSchema( + const FlightCallOptions& options, const FlightDescriptor& descriptor) override { pb::FlightDescriptor pb_descriptor; pb::SchemaResult pb_response; @@ -816,8 +817,7 @@ class GrpcClientImpl : public internal::ClientTransport { std::string str; RETURN_NOT_OK(internal::FromProto(pb_response, &str)); - schema_result->reset(new SchemaResult(str)); - return Status::OK(); + return arrow::internal::make_unique(std::move(str)); } Status DoGet(const FlightCallOptions& options, const Ticket& ticket, diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc index 5a2901c1d549b..14daaa587654a 100644 --- a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc +++ b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc @@ -245,7 +245,7 @@ class GrpcServiceHandler final : public FlightService::Service { while (true) { ProtoType pb_value; std::unique_ptr value; - GRPC_RETURN_NOT_OK(iterator->Next(&value)); + GRPC_RETURN_NOT_OK(iterator->Next().Value(&value)); if (!value) { break; } @@ -495,7 +495,7 @@ class GrpcServiceHandler final : public FlightService::Service { while (true) { std::unique_ptr result; - SERVICE_RETURN_NOT_OK(flight_context, results->Next(&result)); + SERVICE_RETURN_NOT_OK(flight_context, results->Next().Value(&result)); if (!result) { // No more results break; @@ -587,9 +587,9 @@ class GrpcServerTransport : public internal::ServerTransport { } if (scheme == kSchemeGrpcTls) { - RETURN_NOT_OK(Location::ForGrpcTls(uri.host(), port, &location_)); + ARROW_ASSIGN_OR_RAISE(location_, Location::ForGrpcTls(uri.host(), port)); } else if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp) { - RETURN_NOT_OK(Location::ForGrpcTcp(uri.host(), port, &location_)); + ARROW_ASSIGN_OR_RAISE(location_, Location::ForGrpcTcp(uri.host(), port)); } return Status::OK(); } diff --git a/cpp/src/arrow/flight/transport_server.cc b/cpp/src/arrow/flight/transport_server.cc index daeebcc2405fd..ecb06dc54d7ac 100644 --- a/cpp/src/arrow/flight/transport_server.cc +++ b/cpp/src/arrow/flight/transport_server.cc @@ -109,31 +109,32 @@ class TransportMessageReader final : public FlightMessageReader { return batch_reader_->schema(); } - Status Next(FlightStreamChunk* out) override { + arrow::Result Next() override { + FlightStreamChunk out; internal::FlightData* data; peekable_reader_->Peek(&data); if (!data) { - out->app_metadata = nullptr; - out->data = nullptr; - return Status::OK(); + out.app_metadata = nullptr; + out.data = nullptr; + return out; } if (!data->metadata) { // Metadata-only (data->metadata is the IPC header) - out->app_metadata = data->app_metadata; - out->data = nullptr; + out.app_metadata = data->app_metadata; + out.data = nullptr; peekable_reader_->Next(&data); - return Status::OK(); + return out; } if (!batch_reader_) { RETURN_NOT_OK(EnsureDataStarted()); // re-peek here since EnsureDataStarted() advances the stream - return Next(out); + return Next(); } - RETURN_NOT_OK(batch_reader_->ReadNext(&out->data)); - out->app_metadata = std::move(app_metadata_); - return Status::OK(); + RETURN_NOT_OK(batch_reader_->ReadNext(&out.data)); + out.app_metadata = std::move(app_metadata_); + return out; } private: @@ -286,8 +287,7 @@ Status ServerTransport::DoGet(const ServerCallContext& context, const Ticket& ti // Consume data stream and write out payloads while (true) { - FlightPayload payload; - RETURN_NOT_OK(data_stream->Next(&payload)); + ARROW_ASSIGN_OR_RAISE(FlightPayload payload, data_stream->Next()); // End of stream if (payload.ipc_message.metadata == nullptr) break; ARROW_ASSIGN_OR_RAISE(auto success, stream->WriteData(payload)); diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 3dc3c1645effb..4a169e985c179 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -142,10 +142,15 @@ Status FlightPayload::Validate() const { return Status::OK(); } +arrow::Result> SchemaResult::GetSchema( + ipc::DictionaryMemo* dictionary_memo) const { + io::BufferReader schema_reader(raw_schema_); + return ipc::ReadSchema(&schema_reader, dictionary_memo); +} + Status SchemaResult::GetSchema(ipc::DictionaryMemo* dictionary_memo, std::shared_ptr* out) const { - io::BufferReader schema_reader(raw_schema_); - return ipc::ReadSchema(&schema_reader, dictionary_memo).Value(out); + return GetSchema(dictionary_memo).Value(out); } arrow::Result FlightDescriptor::SerializeToString() const { @@ -233,17 +238,20 @@ arrow::Result FlightInfo::Make(const Schema& schema, return FlightInfo(data); } -Status FlightInfo::GetSchema(ipc::DictionaryMemo* dictionary_memo, - std::shared_ptr* out) const { +arrow::Result> FlightInfo::GetSchema( + ipc::DictionaryMemo* dictionary_memo) const { if (reconstructed_schema_) { - *out = schema_; - return Status::OK(); + return schema_; } io::BufferReader schema_reader(data_.schema); RETURN_NOT_OK(ipc::ReadSchema(&schema_reader, dictionary_memo).Value(&schema_)); reconstructed_schema_ = true; - *out = schema_; - return Status::OK(); + return schema_; +} + +Status FlightInfo::GetSchema(ipc::DictionaryMemo* dictionary_memo, + std::shared_ptr* out) const { + return GetSchema(dictionary_memo).Value(out); } arrow::Result FlightInfo::SerializeToString() const { @@ -284,35 +292,55 @@ Status FlightInfo::Deserialize(const std::string& serialized, Location::Location() { uri_ = std::make_shared(); } +Status FlightListing::Next(std::unique_ptr* info) { + return Next().Value(info); +} + +arrow::Result Location::Parse(const std::string& uri_string) { + Location location; + RETURN_NOT_OK(location.uri_->Parse(uri_string)); + return location; +} + Status Location::Parse(const std::string& uri_string, Location* location) { - return location->uri_->Parse(uri_string); + return Parse(uri_string).Value(location); } -Status Location::ForGrpcTcp(const std::string& host, const int port, Location* location) { +arrow::Result Location::ForGrpcTcp(const std::string& host, const int port) { std::stringstream uri_string; uri_string << "grpc+tcp://" << host << ':' << port; - return Location::Parse(uri_string.str(), location); + return Location::Parse(uri_string.str()); } -Status Location::ForGrpcTls(const std::string& host, const int port, Location* location) { +Status Location::ForGrpcTcp(const std::string& host, const int port, Location* location) { + return ForGrpcTcp(host, port).Value(location); +} + +arrow::Result Location::ForGrpcTls(const std::string& host, const int port) { std::stringstream uri_string; uri_string << "grpc+tls://" << host << ':' << port; - return Location::Parse(uri_string.str(), location); + return Location::Parse(uri_string.str()); } -Status Location::ForGrpcUnix(const std::string& path, Location* location) { +Status Location::ForGrpcTls(const std::string& host, const int port, Location* location) { + return ForGrpcTls(host, port).Value(location); +} + +arrow::Result Location::ForGrpcUnix(const std::string& path) { std::stringstream uri_string; uri_string << "grpc+unix://" << path; - return Location::Parse(uri_string.str(), location); + return Location::Parse(uri_string.str()); +} + +Status Location::ForGrpcUnix(const std::string& path, Location* location) { + return ForGrpcUnix(path).Value(location); } arrow::Result Location::ForScheme(const std::string& scheme, const std::string& host, const int port) { - Location location; std::stringstream uri_string; uri_string << scheme << "://" << host << ':' << port; - RETURN_NOT_OK(Location::Parse(uri_string.str(), &location)); - return location; + return Location::Parse(uri_string.str()); } std::string Location::ToString() const { return uri_->ToString(); } @@ -337,23 +365,36 @@ bool ActionType::Equals(const ActionType& other) const { return type == other.type && description == other.description; } -Status MetadataRecordBatchReader::ReadAll( - std::vector>* batches) { - FlightStreamChunk chunk; +Status ResultStream::Next(std::unique_ptr* info) { return Next().Value(info); } + +Status MetadataRecordBatchReader::Next(FlightStreamChunk* next) { + return Next().Value(next); +} +arrow::Result>> +MetadataRecordBatchReader::ToRecordBatches() { + std::vector> batches; while (true) { - RETURN_NOT_OK(Next(&chunk)); + ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, Next()); if (!chunk.data) break; - batches->emplace_back(std::move(chunk.data)); + batches.emplace_back(std::move(chunk.data)); } - return Status::OK(); + return batches; } -Status MetadataRecordBatchReader::ReadAll(std::shared_ptr
* table) { - std::vector> batches; - RETURN_NOT_OK(ReadAll(&batches)); +Status MetadataRecordBatchReader::ReadAll( + std::vector>* batches) { + return ToRecordBatches().Value(batches); +} + +arrow::Result> MetadataRecordBatchReader::ToTable() { + ARROW_ASSIGN_OR_RAISE(auto batches, ToRecordBatches()); ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema()); - return Table::FromRecordBatches(schema, std::move(batches)).Value(table); + return Table::FromRecordBatches(schema, std::move(batches)); +} + +Status MetadataRecordBatchReader::ReadAll(std::shared_ptr
* table) { + return ToTable().Value(table); } Status MetadataRecordBatchWriter::Begin(const std::shared_ptr& schema) { @@ -368,9 +409,8 @@ class MetadataRecordBatchReaderAdapter : public RecordBatchReader { : schema_(std::move(schema)), delegate_(std::move(delegate)) {} std::shared_ptr schema() const override { return schema_; } Status ReadNext(std::shared_ptr* batch) override { - FlightStreamChunk next; while (true) { - RETURN_NOT_OK(delegate_->Next(&next)); + ARROW_ASSIGN_OR_RAISE(FlightStreamChunk next, delegate_->Next()); if (!next.data && !next.app_metadata) { // EOS *batch = nullptr; @@ -402,25 +442,21 @@ SimpleFlightListing::SimpleFlightListing(const std::vector& flights) SimpleFlightListing::SimpleFlightListing(std::vector&& flights) : position_(0), flights_(std::move(flights)) {} -Status SimpleFlightListing::Next(std::unique_ptr* info) { +arrow::Result> SimpleFlightListing::Next() { if (position_ >= static_cast(flights_.size())) { - *info = nullptr; - return Status::OK(); + return nullptr; } - *info = std::unique_ptr(new FlightInfo(std::move(flights_[position_++]))); - return Status::OK(); + return std::unique_ptr(new FlightInfo(std::move(flights_[position_++]))); } SimpleResultStream::SimpleResultStream(std::vector&& results) : results_(std::move(results)), position_(0) {} -Status SimpleResultStream::Next(std::unique_ptr* result) { +arrow::Result> SimpleResultStream::Next() { if (position_ >= results_.size()) { - *result = nullptr; - return Status::OK(); + return nullptr; } - *result = std::unique_ptr(new Result(std::move(results_[position_++]))); - return Status::OK(); + return std::unique_ptr(new Result(std::move(results_[position_++]))); } arrow::Result BasicAuth::Deserialize(arrow::util::string_view serialized) { diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index a609cacb95ee0..8a77b8fc0c12f 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -299,26 +299,38 @@ struct ARROW_FLIGHT_EXPORT Location { Location(); /// \brief Initialize a location by parsing a URI string + static arrow::Result Parse(const std::string& uri_string); + + ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") static Status Parse(const std::string& uri_string, Location* location); /// \brief Initialize a location for a non-TLS, gRPC-based Flight /// service from a host and port /// \param[in] host The hostname to connect to /// \param[in] port The port - /// \param[out] location The resulting location + /// \return Arrow result with the resulting location + static arrow::Result ForGrpcTcp(const std::string& host, const int port); + + ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") static Status ForGrpcTcp(const std::string& host, const int port, Location* location); /// \brief Initialize a location for a TLS-enabled, gRPC-based Flight /// service from a host and port /// \param[in] host The hostname to connect to /// \param[in] port The port - /// \param[out] location The resulting location + /// \return Arrow result with the resulting location + static arrow::Result ForGrpcTls(const std::string& host, const int port); + + ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") static Status ForGrpcTls(const std::string& host, const int port, Location* location); /// \brief Initialize a location for a domain socket-based Flight /// service /// \param[in] path The path to the domain socket - /// \param[out] location The resulting location + /// \return Arrow result with the resulting location + static arrow::Result ForGrpcUnix(const std::string& path); + + ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") static Status ForGrpcUnix(const std::string& path, Location* location); /// \brief Initialize a location based on a URI scheme @@ -387,7 +399,11 @@ struct ARROW_FLIGHT_EXPORT SchemaResult { /// \brief return schema /// \param[in,out] dictionary_memo for dictionary bookkeeping, will /// be modified - /// \param[out] out the reconstructed Schema + /// \return Arrrow result with the reconstructed Schema + arrow::Result> GetSchema( + ipc::DictionaryMemo* dictionary_memo) const; + + ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") Status GetSchema(ipc::DictionaryMemo* dictionary_memo, std::shared_ptr* out) const; @@ -424,7 +440,11 @@ class ARROW_FLIGHT_EXPORT FlightInfo { /// bookkeeping /// \param[in,out] dictionary_memo for dictionary bookkeeping, will /// be modified - /// \param[out] out the reconstructed Schema + /// \return Arrrow result with the reconstructed Schema + arrow::Result> GetSchema( + ipc::DictionaryMemo* dictionary_memo) const; + + ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") Status GetSchema(ipc::DictionaryMemo* dictionary_memo, std::shared_ptr* out) const; @@ -475,10 +495,12 @@ class ARROW_FLIGHT_EXPORT FlightListing { virtual ~FlightListing() = default; /// \brief Retrieve the next FlightInfo from the iterator. - /// \param[out] info A single FlightInfo. Set to \a nullptr if there + /// \return Arrow result with a single FlightInfo. Set to \a nullptr if there /// are none left. - /// \return Status - virtual Status Next(std::unique_ptr* info) = 0; + virtual arrow::Result> Next() = 0; + + ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") + Status Next(std::unique_ptr* info); }; /// \brief An iterator to Result instances returned by DoAction. @@ -487,10 +509,11 @@ class ARROW_FLIGHT_EXPORT ResultStream { virtual ~ResultStream() = default; /// \brief Retrieve the next Result from the iterator. - /// \param[out] info A single result. Set to \a nullptr if there - /// are none left. - /// \return Status - virtual Status Next(std::unique_ptr* info) = 0; + /// \return Arrow result with a single Result. Set to \a nullptr if there are none left. + virtual arrow::Result> Next() = 0; + + ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") + Status Next(std::unique_ptr* info); }; /// \brief A holder for a RecordBatch with associated Flight metadata. @@ -507,14 +530,26 @@ class ARROW_FLIGHT_EXPORT MetadataRecordBatchReader { /// \brief Get the schema for this stream. virtual arrow::Result> GetSchema() = 0; + /// \brief Get the next message from Flight. If the stream is /// finished, then the members of \a FlightStreamChunk will be /// nullptr. - virtual Status Next(FlightStreamChunk* next) = 0; + virtual arrow::Result Next() = 0; + + ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") + Status Next(FlightStreamChunk* next); + /// \brief Consume entire stream as a vector of record batches - virtual Status ReadAll(std::vector>* batches); + virtual arrow::Result>> ToRecordBatches(); + + ARROW_DEPRECATED("Deprecated in 8.0.0. Use ToRecordBatches instead.") + Status ReadAll(std::vector>* batches); + /// \brief Consume entire stream as a Table - virtual Status ReadAll(std::shared_ptr
* table); + virtual arrow::Result> ToTable(); + + ARROW_DEPRECATED("Deprecated in 8.0.0. Use ToTable instead.") + Status ReadAll(std::shared_ptr
* table); }; /// \brief Convert a MetadataRecordBatchReader to a regular RecordBatchReader. @@ -544,7 +579,7 @@ class ARROW_FLIGHT_EXPORT SimpleFlightListing : public FlightListing { explicit SimpleFlightListing(const std::vector& flights); explicit SimpleFlightListing(std::vector&& flights); - Status Next(std::unique_ptr* info) override; + arrow::Result> Next() override; private: int position_; @@ -558,7 +593,7 @@ class ARROW_FLIGHT_EXPORT SimpleFlightListing : public FlightListing { class ARROW_FLIGHT_EXPORT SimpleResultStream : public ResultStream { public: explicit SimpleResultStream(std::vector&& results); - Status Next(std::unique_ptr* result) override; + arrow::Result> Next() override; private: std::vector results_; diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc index 6bc5659c74de5..ba97bc4d44b90 100644 --- a/cpp/src/arrow/python/flight.cc +++ b/cpp/src/arrow/python/flight.cc @@ -206,12 +206,15 @@ PyFlightResultStream::PyFlightResultStream(PyObject* generator, generator_.reset(generator); } -Status PyFlightResultStream::Next(std::unique_ptr* result) { - return SafeCallIntoPython([=] { - const Status status = callback_(generator_.obj(), result); - RETURN_NOT_OK(CheckPyError()); - return status; - }); +arrow::Result> PyFlightResultStream::Next() { + return SafeCallIntoPython( + [=]() -> arrow::Result> { + std::unique_ptr result; + const Status status = callback_(generator_.obj(), &result); + RETURN_NOT_OK(CheckPyError()); + RETURN_NOT_OK(status); + return result; + }); } PyFlightDataStream::PyFlightDataStream( @@ -227,7 +230,7 @@ Status PyFlightDataStream::GetSchemaPayload(FlightPayload* payload) { return stream_->GetSchemaPayload(payload); } -Status PyFlightDataStream::Next(FlightPayload* payload) { return stream_->Next(payload); } +arrow::Result PyFlightDataStream::Next() { return stream_->Next(); } PyGeneratorFlightDataStream::PyGeneratorFlightDataStream( PyObject* generator, std::shared_ptr schema, @@ -243,11 +246,13 @@ Status PyGeneratorFlightDataStream::GetSchemaPayload(FlightPayload* payload) { return ipc::GetSchemaPayload(*schema_, options_, mapper_, &payload->ipc_message); } -Status PyGeneratorFlightDataStream::Next(FlightPayload* payload) { - return SafeCallIntoPython([=] { - const Status status = callback_(generator_.obj(), payload); +arrow::Result PyGeneratorFlightDataStream::Next() { + return SafeCallIntoPython([=]() -> arrow::Result { + FlightPayload payload; + const Status status = callback_(generator_.obj(), &payload); RETURN_NOT_OK(CheckPyError()); - return status; + RETURN_NOT_OK(status); + return payload; }); } diff --git a/cpp/src/arrow/python/flight.h b/cpp/src/arrow/python/flight.h index ce8b00669dffe..61ac4051bea2b 100644 --- a/cpp/src/arrow/python/flight.h +++ b/cpp/src/arrow/python/flight.h @@ -186,7 +186,7 @@ class ARROW_PYFLIGHT_EXPORT PyFlightResultStream : public arrow::flight::ResultS /// Must only be called while holding the GIL. explicit PyFlightResultStream(PyObject* generator, PyFlightResultStreamCallback callback); - Status Next(std::unique_ptr* result) override; + arrow::Result> Next() override; private: OwnedRefNoGIL generator_; @@ -204,7 +204,7 @@ class ARROW_PYFLIGHT_EXPORT PyFlightDataStream : public arrow::flight::FlightDat std::shared_ptr schema() override; Status GetSchemaPayload(arrow::flight::FlightPayload* payload) override; - Status Next(arrow::flight::FlightPayload* payload) override; + arrow::Result Next() override; private: OwnedRefNoGIL data_source_; @@ -323,7 +323,7 @@ class ARROW_PYFLIGHT_EXPORT PyGeneratorFlightDataStream const ipc::IpcWriteOptions& options); std::shared_ptr schema() override; Status GetSchemaPayload(arrow::flight::FlightPayload* payload) override; - Status Next(arrow::flight::FlightPayload* payload) override; + arrow::Result Next() override; private: OwnedRefNoGIL generator_; diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index f2330b0e88eb7..bda22f708c7d0 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -525,7 +525,7 @@ cdef class Location(_Weakrefable): CLocation location def __init__(self, uri): - check_flight_status(CLocation.Parse(tobytes(uri), &self.location)) + check_flight_status(CLocation.Parse(tobytes(uri)).Value(&self.location)) def __repr__(self): return ''.format(self.location.ToString()) @@ -550,7 +550,7 @@ cdef class Location(_Weakrefable): int c_port = port Location result = Location.__new__(Location) check_flight_status( - CLocation.ForGrpcTcp(c_host, c_port, &result.location)) + CLocation.ForGrpcTcp(c_host, c_port).Value(&result.location)) return result @staticmethod @@ -561,7 +561,7 @@ cdef class Location(_Weakrefable): int c_port = port Location result = Location.__new__(Location) check_flight_status( - CLocation.ForGrpcTls(c_host, c_port, &result.location)) + CLocation.ForGrpcTls(c_host, c_port).Value(&result.location)) return result @staticmethod @@ -570,7 +570,7 @@ cdef class Location(_Weakrefable): cdef: c_string c_path = tobytes(path) Location result = Location.__new__(Location) - check_flight_status(CLocation.ForGrpcUnix(c_path, &result.location)) + check_flight_status(CLocation.ForGrpcUnix(c_path).Value(&result.location)) return result @staticmethod @@ -584,7 +584,7 @@ cdef class Location(_Weakrefable): cdef CLocation c_location if isinstance(location, str): check_flight_status( - CLocation.Parse(tobytes(location), &c_location)) + CLocation.Parse(tobytes(location)).Value(&c_location)) return c_location elif not isinstance(location, Location): raise TypeError("Must provide a Location, not '{}'".format( @@ -626,7 +626,7 @@ cdef class FlightEndpoint(_Weakrefable): else: c_location = CLocation() check_flight_status( - CLocation.Parse(tobytes(location), &c_location)) + CLocation.Parse(tobytes(location)).Value(&c_location)) self.endpoint.locations.push_back(c_location) @property @@ -671,7 +671,7 @@ cdef class SchemaResult(_Weakrefable): shared_ptr[CSchema] schema CDictionaryMemo dummy_memo - check_flight_status(self.result.get().GetSchema(&dummy_memo, &schema)) + check_flight_status(self.result.get().GetSchema(&dummy_memo).Value(&schema)) return pyarrow_wrap_schema(schema) @@ -731,7 +731,7 @@ cdef class FlightInfo(_Weakrefable): shared_ptr[CSchema] schema CDictionaryMemo dummy_memo - check_flight_status(self.info.get().GetSchema(&dummy_memo, &schema)) + check_flight_status(self.info.get().GetSchema(&dummy_memo).Value(&schema)) return pyarrow_wrap_schema(schema) @property @@ -831,7 +831,7 @@ cdef class _MetadataRecordBatchReader(_Weakrefable, _ReadPandasMixin): cdef: shared_ptr[CTable] c_table with nogil: - check_flight_status(self.reader.get().ReadAll(&c_table)) + check_flight_status(self.reader.get().ToTable().Value(&c_table)) return pyarrow_wrap_table(c_table) def read_chunk(self): @@ -854,7 +854,7 @@ cdef class _MetadataRecordBatchReader(_Weakrefable, _ReadPandasMixin): FlightStreamChunk chunk = FlightStreamChunk() with nogil: - check_flight_status(self.reader.get().Next(&chunk.chunk)) + check_flight_status(self.reader.get().Next().Value(&chunk.chunk)) if chunk.chunk.data == NULL and chunk.chunk.app_metadata == NULL: raise StopIteration @@ -894,7 +894,7 @@ cdef class FlightStreamReader(MetadataRecordBatchReader): with nogil: check_flight_status( ( self.reader.get()) - .ReadAllWithStopToken(&c_table, stop_token)) + .ToTableWithStopToken(stop_token).Value(&c_table)) return pyarrow_wrap_table(c_table) @@ -1294,7 +1294,7 @@ cdef class FlightClient(_Weakrefable): while True: result = Result.__new__(Result) with nogil: - check_flight_status(results.get().Next(&result.result)) + check_flight_status(results.get().Next().Value(&result.result)) if result.result == NULL: break yield result @@ -1323,7 +1323,7 @@ cdef class FlightClient(_Weakrefable): while True: result = FlightInfo.__new__(FlightInfo) with nogil: - check_flight_status(listing.get().Next(&result.info)) + check_flight_status(listing.get().Next().Value(&result.info)) if result.info == NULL: break yield result @@ -1354,7 +1354,7 @@ cdef class FlightClient(_Weakrefable): with nogil: check_status( self.client.get() - .GetSchema(deref(c_options), c_descriptor, &result.result) + .GetSchema(deref(c_options), c_descriptor).Value(&result.result) ) return result @@ -1724,7 +1724,8 @@ cdef CStatus _data_stream_next(void* self, CFlightPayload* payload) except *: max_attempts = 128 for _ in range(max_attempts): if stream.current_stream != nullptr: - check_flight_status(stream.current_stream.get().Next(payload)) + check_flight_status( + stream.current_stream.get().Next().Value(payload)) # If the stream ended, see if there's another stream from the # generator if payload.ipc_message.metadata != nullptr: diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index ce1da620475da..9d4663e52950b 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -51,7 +51,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CResult[CBasicAuth] Deserialize(const c_string& serialized) cdef cppclass CResultStream" arrow::flight::ResultStream": - CStatus Next(unique_ptr[CFlightResult]* result) + CResult[unique_ptr[CFlightResult]] Next() cdef cppclass CDescriptorType \ " arrow::flight::FlightDescriptor::DescriptorType": @@ -93,16 +93,16 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: c_bool Equals(const CLocation& other) @staticmethod - CStatus Parse(c_string& uri_string, CLocation* location) + CResult[CLocation] Parse(c_string& uri_string) @staticmethod - CStatus ForGrpcTcp(c_string& host, int port, CLocation* location) + CResult[CLocation] ForGrpcTcp(c_string& host, int port) @staticmethod - CStatus ForGrpcTls(c_string& host, int port, CLocation* location) + CResult[CLocation] ForGrpcTls(c_string& host, int port) @staticmethod - CStatus ForGrpcUnix(c_string& path, CLocation* location) + CResult[CLocation] ForGrpcUnix(c_string& path) cdef cppclass CFlightEndpoint" arrow::flight::FlightEndpoint": CFlightEndpoint() @@ -116,7 +116,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CFlightInfo(CFlightInfo info) int64_t total_records() int64_t total_bytes() - CStatus GetSchema(CDictionaryMemo* memo, shared_ptr[CSchema]* out) + CResult[shared_ptr[CSchema]] GetSchema(CDictionaryMemo* memo) CFlightDescriptor& descriptor() const vector[CFlightEndpoint]& endpoints() CResult[c_string] SerializeToString() @@ -127,10 +127,10 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: cdef cppclass CSchemaResult" arrow::flight::SchemaResult": CSchemaResult(CSchemaResult result) - CStatus GetSchema(CDictionaryMemo* memo, shared_ptr[CSchema]* out) + CResult[shared_ptr[CSchema]] GetSchema(CDictionaryMemo* memo) cdef cppclass CFlightListing" arrow::flight::FlightListing": - CStatus Next(unique_ptr[CFlightInfo]* info) + CResult[unique_ptr[CFlightInfo]] Next() cdef cppclass CSimpleFlightListing" arrow::flight::SimpleFlightListing": CSimpleFlightListing(vector[CFlightInfo]&& info) @@ -142,7 +142,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: cdef cppclass CFlightDataStream" arrow::flight::FlightDataStream": shared_ptr[CSchema] schema() - CStatus Next(CFlightPayload*) + CResult[CFlightPayload] Next() cdef cppclass CFlightStreamChunk" arrow::flight::FlightStreamChunk": CFlightStreamChunk() @@ -152,8 +152,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: cdef cppclass CMetadataRecordBatchReader \ " arrow::flight::MetadataRecordBatchReader": CResult[shared_ptr[CSchema]] GetSchema() - CStatus Next(CFlightStreamChunk* out) - CStatus ReadAll(shared_ptr[CTable]* table) + CResult[CFlightStreamChunk] Next() + CResult[shared_ptr[CTable]] ToTable() CResult[shared_ptr[CRecordBatchReader]] MakeRecordBatchReader\ " arrow::flight::MakeRecordBatchReader"( @@ -170,8 +170,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: cdef cppclass CFlightStreamReader \ " arrow::flight::FlightStreamReader"(CMetadataRecordBatchReader): void Cancel() - CStatus ReadAllWithStopToken" ReadAll"\ - (shared_ptr[CTable]* table, const CStopToken& stop_token) + CResult[shared_ptr[CTable]] ToTableWithStopToken" ToTable"\ + (const CStopToken& stop_token) cdef cppclass CFlightMessageReader \ " arrow::flight::FlightMessageReader"(CMetadataRecordBatchReader): @@ -337,9 +337,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CStatus GetFlightInfo(CFlightCallOptions& options, CFlightDescriptor& descriptor, unique_ptr[CFlightInfo]* info) - CStatus GetSchema(CFlightCallOptions& options, - CFlightDescriptor& descriptor, - unique_ptr[CSchemaResult]* result) + CResult[unique_ptr[CSchemaResult]] GetSchema(CFlightCallOptions& options, + CFlightDescriptor& descriptor) CStatus DoGet(CFlightCallOptions& options, CTicket& ticket, unique_ptr[CFlightStreamReader]* stream) CStatus DoPut(CFlightCallOptions& options,