From 9b55daa8abc0efb1978bd8baaa230bb781c7d82c Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Thu, 9 Dec 2021 13:45:08 -0300 Subject: [PATCH 1/8] Separate Flight integration tests into new modules This refactor is a prerequisite to adding Flight SQL integration tests, as the current integration tests do not allow Flight SQL usage (because it would cause circular dependencies) --- cpp/CMakeLists.txt | 4 + cpp/src/arrow/CMakeLists.txt | 3 + cpp/src/arrow/flight/CMakeLists.txt | 17 +--- .../flight/integration_tests/CMakeLists.txt | 38 ++++++++ .../test_integration.cc | 6 +- .../test_integration.h | 5 ++ .../test_integration_client.cc | 16 ++-- .../test_integration_server.cc | 12 ++- cpp/src/arrow/flight/sql/test_server_cli.cc | 3 +- .../archery/integration/tester_java.py | 6 +- java/flight/flight-core/pom.xml | 5 -- java/flight/flight-integration-tests/pom.xml | 86 +++++++++++++++++++ .../tests}/AuthBasicProtoScenario.java | 2 +- .../tests}/IntegrationAssertions.java | 2 +- .../tests}/IntegrationTestClient.java | 2 +- .../tests}/IntegrationTestServer.java | 2 +- .../tests}/MiddlewareScenario.java | 2 +- .../flight/integration/tests}/Scenario.java | 2 +- .../flight/integration/tests}/Scenarios.java | 2 +- java/flight/pom.xml | 1 + 20 files changed, 172 insertions(+), 44 deletions(-) create mode 100644 cpp/src/arrow/flight/integration_tests/CMakeLists.txt rename cpp/src/arrow/flight/{ => integration_tests}/test_integration.cc (98%) rename cpp/src/arrow/flight/{ => integration_tests}/test_integration.h (96%) rename cpp/src/arrow/flight/{ => integration_tests}/test_integration_client.cc (94%) rename cpp/src/arrow/flight/{ => integration_tests}/test_integration_server.cc (94%) create mode 100644 java/flight/flight-integration-tests/pom.xml rename java/flight/{flight-core/src/main/java/org/apache/arrow/flight/example/integration => flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests}/AuthBasicProtoScenario.java (98%) rename java/flight/{flight-core/src/main/java/org/apache/arrow/flight/example/integration => flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests}/IntegrationAssertions.java (97%) rename java/flight/{flight-core/src/main/java/org/apache/arrow/flight/example/integration => flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests}/IntegrationTestClient.java (99%) rename java/flight/{flight-core/src/main/java/org/apache/arrow/flight/example/integration => flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests}/IntegrationTestServer.java (98%) rename java/flight/{flight-core/src/main/java/org/apache/arrow/flight/example/integration => flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests}/MiddlewareScenario.java (99%) rename java/flight/{flight-core/src/main/java/org/apache/arrow/flight/example/integration => flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests}/Scenario.java (96%) rename java/flight/{flight-core/src/main/java/org/apache/arrow/flight/example/integration => flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests}/Scenarios.java (98%) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 3de8ff7656926..fd7027c30ebe2 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -334,6 +334,10 @@ if(ARROW_GANDIVA) set(ARROW_WITH_RE2 ON) endif() +if(ARROW_BUILD_INTEGRATION AND ARROW_FLIGHT) + set(ARROW_FLIGHT_SQL ON) +endif() + if(ARROW_FLIGHT_SQL) set(ARROW_FLIGHT ON) endif() diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 502629a92a4eb..82eb7c191006f 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -730,6 +730,9 @@ endif() if(ARROW_FLIGHT) add_subdirectory(flight) + if(ARROW_BUILD_INTEGRATION) + add_subdirectory(flight/integration_tests) + endif() endif() if(ARROW_FLIGHT_SQL) diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 55e89b2eb99e5..1a694a4abb5a0 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -202,7 +202,7 @@ if(ARROW_TESTING) OUTPUTS ARROW_FLIGHT_TESTING_LIBRARIES SOURCES - test_integration.cc + integration_tests/test_integration.cc test_util.cc DEPENDENCIES GTest::gtest @@ -246,21 +246,6 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS) add_dependencies(arrow_flight flight-test-server) endif() -if(ARROW_BUILD_INTEGRATION) - add_executable(flight-test-integration-server test_integration_server.cc) - target_link_libraries(flight-test-integration-server ${ARROW_FLIGHT_TEST_LINK_LIBS} - ${GFLAGS_LIBRARIES} GTest::gtest) - - add_executable(flight-test-integration-client test_integration_client.cc) - target_link_libraries(flight-test-integration-client ${ARROW_FLIGHT_TEST_LINK_LIBS} - ${GFLAGS_LIBRARIES} GTest::gtest) - - add_dependencies(arrow_flight flight-test-integration-client - flight-test-integration-server) - add_dependencies(arrow-integration flight-test-integration-client - flight-test-integration-server) -endif() - if(ARROW_BUILD_BENCHMARKS) # Perf server for benchmarks set(PERF_PROTO_GENERATED_FILES "${CMAKE_CURRENT_BINARY_DIR}/perf.pb.cc" diff --git a/cpp/src/arrow/flight/integration_tests/CMakeLists.txt b/cpp/src/arrow/flight/integration_tests/CMakeLists.txt new file mode 100644 index 0000000000000..7f45dd2adf1ff --- /dev/null +++ b/cpp/src/arrow/flight/integration_tests/CMakeLists.txt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +add_custom_target(arrow_flight_integration_tests) + +if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") + set(ARROW_FLIGHT_TEST_LINK_LIBS + arrow_flight_static arrow_flight_testing_static ${ARROW_FLIGHT_STATIC_LINK_LIBS} + ${ARROW_TEST_LINK_LIBS}) +else() + set(ARROW_FLIGHT_TEST_LINK_LIBS arrow_flight_shared arrow_flight_testing_shared + ${ARROW_TEST_LINK_LIBS}) +endif() + +add_executable(flight-test-integration-server test_integration_server.cc) +target_link_libraries(flight-test-integration-server ${ARROW_FLIGHT_TEST_LINK_LIBS} + ${GFLAGS_LIBRARIES} GTest::gtest) + +add_executable(flight-test-integration-client test_integration_client.cc) +target_link_libraries(flight-test-integration-client ${ARROW_FLIGHT_TEST_LINK_LIBS} + ${GFLAGS_LIBRARIES} GTest::gtest) + +add_dependencies(arrow-integration flight-test-integration-client + flight-test-integration-server) diff --git a/cpp/src/arrow/flight/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc similarity index 98% rename from cpp/src/arrow/flight/test_integration.cc rename to cpp/src/arrow/flight/integration_tests/test_integration.cc index 29ce5601f375b..491aeca305006 100644 --- a/cpp/src/arrow/flight/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/flight/test_integration.h" +#include "arrow/flight/integration_tests/test_integration.h" #include "arrow/flight/client_middleware.h" #include "arrow/flight/server_middleware.h" #include "arrow/flight/test_util.h" @@ -30,6 +30,7 @@ namespace arrow { namespace flight { +namespace integration_tests { /// \brief The server for the basic auth integration test. class AuthBasicProtoServer : public FlightServerBase { @@ -113,9 +114,11 @@ class AuthBasicProtoScenario : public Scenario { class TestServerMiddleware : public ServerMiddleware { public: explicit TestServerMiddleware(std::string received) : received_(received) {} + void SendingHeaders(AddCallHeaders* outgoing_headers) override { outgoing_headers->AddHeader("x-middleware", received_); } + void CallCompleted(const Status& status) override {} std::string name() const override { return "GrpcTrailersMiddleware"; } @@ -266,5 +269,6 @@ Status GetScenario(const std::string& scenario_name, std::shared_ptr* return Status::KeyError("Scenario not found: ", scenario_name); } +} // namespace integration_tests } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/test_integration.h b/cpp/src/arrow/flight/integration_tests/test_integration.h similarity index 96% rename from cpp/src/arrow/flight/test_integration.h rename to cpp/src/arrow/flight/integration_tests/test_integration.h index 5d9bd7fd7bd74..cc598cff3a2ba 100644 --- a/cpp/src/arrow/flight/test_integration.h +++ b/cpp/src/arrow/flight/integration_tests/test_integration.h @@ -28,16 +28,20 @@ namespace arrow { namespace flight { +namespace integration_tests { /// \brief An integration test for Arrow Flight. class ARROW_FLIGHT_EXPORT Scenario { public: virtual ~Scenario() = default; + /// \brief Set up the server. virtual Status MakeServer(std::unique_ptr* server, FlightServerOptions* options) = 0; + /// \brief Set up the client. virtual Status MakeClient(FlightClientOptions* options) = 0; + /// \brief Run the scenario as the client. virtual Status RunClient(std::unique_ptr client) = 0; }; @@ -45,5 +49,6 @@ class ARROW_FLIGHT_EXPORT Scenario { /// \brief Get the implementation of an integration test scenario by name. Status GetScenario(const std::string& scenario_name, std::shared_ptr* out); +} // namespace integration_tests } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/test_integration_client.cc b/cpp/src/arrow/flight/integration_tests/test_integration_client.cc similarity index 94% rename from cpp/src/arrow/flight/test_integration_client.cc rename to cpp/src/arrow/flight/integration_tests/test_integration_client.cc index 6c1d69046037f..366284389f104 100644 --- a/cpp/src/arrow/flight/test_integration_client.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration_client.cc @@ -41,7 +41,7 @@ #include "arrow/util/logging.h" #include "arrow/flight/api.h" -#include "arrow/flight/test_integration.h" +#include "arrow/flight/integration_tests/test_integration.h" #include "arrow/flight/test_util.h" DEFINE_string(host, "localhost", "Server port to connect to"); @@ -51,6 +51,7 @@ DEFINE_string(scenario, "", "Integration test scenario to run"); namespace arrow { namespace flight { +namespace integration_tests { /// \brief Helper to read all batches from a JsonReader Status ReadBatches(std::unique_ptr& reader, @@ -133,7 +134,7 @@ Status ConsumeFlightLocation( return Status::OK(); } -class IntegrationTestScenario : public flight::Scenario { +class IntegrationTestScenario : public Scenario { public: Status MakeServer(std::unique_ptr* server, FlightServerOptions* options) override { @@ -201,12 +202,13 @@ class IntegrationTestScenario : public flight::Scenario { } }; +} // namespace integration_tests } // namespace flight } // namespace arrow constexpr int kRetries = 3; -arrow::Status RunScenario(arrow::flight::Scenario* scenario) { +arrow::Status RunScenario(arrow::flight::integration_tests::Scenario* scenario) { auto options = arrow::flight::FlightClientOptions::Defaults(); std::unique_ptr client; @@ -222,11 +224,13 @@ int main(int argc, char** argv) { gflags::SetUsageMessage("Integration testing client for Flight."); gflags::ParseCommandLineFlags(&argc, &argv, true); - std::shared_ptr scenario; + std::shared_ptr scenario; if (!FLAGS_scenario.empty()) { - ARROW_CHECK_OK(arrow::flight::GetScenario(FLAGS_scenario, &scenario)); + ARROW_CHECK_OK( + arrow::flight::integration_tests::GetScenario(FLAGS_scenario, &scenario)); } else { - scenario = std::make_shared(); + scenario = + std::make_shared(); } // ARROW-11908: retry a few times in case a client is slow to bring up the server diff --git a/cpp/src/arrow/flight/test_integration_server.cc b/cpp/src/arrow/flight/integration_tests/test_integration_server.cc similarity index 94% rename from cpp/src/arrow/flight/test_integration_server.cc rename to cpp/src/arrow/flight/integration_tests/test_integration_server.cc index 4b904b0eba13a..92b2241a872b6 100644 --- a/cpp/src/arrow/flight/test_integration_server.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration_server.cc @@ -34,10 +34,10 @@ #include "arrow/testing/json_integration.h" #include "arrow/util/logging.h" +#include "arrow/flight/integration_tests/test_integration.h" #include "arrow/flight/internal.h" #include "arrow/flight/server.h" #include "arrow/flight/server_auth.h" -#include "arrow/flight/test_integration.h" #include "arrow/flight/test_util.h" DEFINE_int32(port, 31337, "Server port to listen on"); @@ -45,6 +45,7 @@ DEFINE_string(scenario, "", "Integration test senario to run"); namespace arrow { namespace flight { +namespace integration_tests { struct IntegrationDataset { std::shared_ptr schema; @@ -175,6 +176,7 @@ class IntegrationTestScenario : public Scenario { } }; +} // namespace integration_tests } // namespace flight } // namespace arrow @@ -184,12 +186,14 @@ int main(int argc, char** argv) { gflags::SetUsageMessage("Integration testing server for Flight."); gflags::ParseCommandLineFlags(&argc, &argv, true); - std::shared_ptr scenario; + std::shared_ptr scenario; if (!FLAGS_scenario.empty()) { - ARROW_CHECK_OK(arrow::flight::GetScenario(FLAGS_scenario, &scenario)); + ARROW_CHECK_OK( + arrow::flight::integration_tests::GetScenario(FLAGS_scenario, &scenario)); } else { - scenario = std::make_shared(); + scenario = + std::make_shared(); } arrow::flight::Location location; ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location)); diff --git a/cpp/src/arrow/flight/sql/test_server_cli.cc b/cpp/src/arrow/flight/sql/test_server_cli.cc index 8074ab534bd24..b696347359e88 100644 --- a/cpp/src/arrow/flight/sql/test_server_cli.cc +++ b/cpp/src/arrow/flight/sql/test_server_cli.cc @@ -20,11 +20,10 @@ #include #include #include +#include #include "arrow/flight/server.h" #include "arrow/flight/sql/example/sqlite_server.h" -#include "arrow/flight/test_integration.h" -#include "arrow/flight/test_util.h" #include "arrow/io/test_common.h" #include "arrow/testing/json_integration.h" #include "arrow/util/logging.h" diff --git a/dev/archery/archery/integration/tester_java.py b/dev/archery/archery/integration/tester_java.py index 5104a0cc75557..75875ad7185c3 100644 --- a/dev/archery/archery/integration/tester_java.py +++ b/dev/archery/archery/integration/tester_java.py @@ -49,11 +49,11 @@ class JavaTester(Tester): ARROW_FLIGHT_JAR = os.environ.get( 'ARROW_FLIGHT_JAVA_INTEGRATION_JAR', os.path.join(ARROW_ROOT_DEFAULT, - 'java/flight/flight-core/target/flight-core-{}-' + 'java/flight/flight-integration-tests/target/flight-integration-tests-{}-' 'jar-with-dependencies.jar'.format(_arrow_version))) - ARROW_FLIGHT_SERVER = ('org.apache.arrow.flight.example.integration.' + ARROW_FLIGHT_SERVER = ('org.apache.arrow.flight.integration.tests.' 'IntegrationTestServer') - ARROW_FLIGHT_CLIENT = ('org.apache.arrow.flight.example.integration.' + ARROW_FLIGHT_CLIENT = ('org.apache.arrow.flight.integration.tests.' 'IntegrationTestClient') name = 'Java' diff --git a/java/flight/flight-core/pom.xml b/java/flight/flight-core/pom.xml index c8ab5ac1d26d4..d870faf9c50f1 100644 --- a/java/flight/flight-core/pom.xml +++ b/java/flight/flight-core/pom.xml @@ -93,11 +93,6 @@ com.google.guava guava - - commons-cli - commons-cli - 1.4 - io.grpc grpc-stub diff --git a/java/flight/flight-integration-tests/pom.xml b/java/flight/flight-integration-tests/pom.xml new file mode 100644 index 0000000000000..2bd9a9f4e04b1 --- /dev/null +++ b/java/flight/flight-integration-tests/pom.xml @@ -0,0 +1,86 @@ + + + + 4.0.0 + + arrow-flight + org.apache.arrow + 7.0.0-SNAPSHOT + ../pom.xml + + + flight-integration-tests + Arrow Flight Integration Tests + 7.0.0-SNAPSHOT + jar + + + + org.apache.arrow + arrow-vector + ${project.version} + + + org.apache.arrow + arrow-memory-core + ${project.version} + + + org.apache.arrow + flight-core + ${project.version} + + + org.apache.arrow + flight-sql + ${project.version} + + + com.google.protobuf + protobuf-java + ${dep.protobuf.version} + + + commons-cli + commons-cli + 1.4 + + + org.slf4j + slf4j-api + + + + + + + maven-assembly-plugin + 3.0.0 + + + jar-with-dependencies + + + + + make-assembly + package + + single + + + + + + + diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/AuthBasicProtoScenario.java similarity index 98% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/AuthBasicProtoScenario.java index 3955d7d21bfcd..1c95d4d5593c9 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/AuthBasicProtoScenario.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import java.nio.charset.StandardCharsets; import java.util.Arrays; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java similarity index 97% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java index 576d1887f3905..993ce73f7fe04 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import java.util.Objects; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java similarity index 99% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java index 27a545f84fd5b..5ed8e70ea63f1 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestServer.java similarity index 98% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestServer.java index da336c5024aa2..7f5e15fe37669 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestServer.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import org.apache.arrow.flight.FlightServer; import org.apache.arrow.flight.Location; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/MiddlewareScenario.java similarity index 99% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/MiddlewareScenario.java index c710ce98b563e..c284a577c08f2 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/MiddlewareScenario.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import java.nio.charset.StandardCharsets; import java.util.Arrays; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenario.java similarity index 96% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenario.java index b3b962d2e734b..bcc657b765c77 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenario.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightProducer; diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java similarity index 98% rename from java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java rename to java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java index cd9859b4f361b..10882ee0a8603 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.arrow.flight.example.integration; +package org.apache.arrow.flight.integration.tests; import java.util.Map; import java.util.TreeMap; diff --git a/java/flight/pom.xml b/java/flight/pom.xml index 2cb409aaad0dd..7cb0e1d717128 100644 --- a/java/flight/pom.xml +++ b/java/flight/pom.xml @@ -33,6 +33,7 @@ flight-core flight-grpc flight-sql + flight-integration-tests From 01210129e8d1aacab36a31ccdb6390650b919d45 Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Wed, 15 Dec 2021 13:41:03 -0300 Subject: [PATCH 2/8] Create integration tests for Flight SQL Java and C++ --- cpp/src/arrow/CMakeLists.txt | 7 +- cpp/src/arrow/flight/CMakeLists.txt | 1 - .../flight/integration_tests/CMakeLists.txt | 19 +- .../integration_tests/test_integration.cc | 391 ++++++++++++++++++ cpp/src/arrow/flight/sql/server.cc | 21 +- cpp/src/arrow/flight/sql/server_test.cc | 2 +- cpp/src/arrow/flight/sql/test_server_cli.cc | 2 +- dev/archery/archery/integration/runner.py | 3 + .../integration/tests/FlightSqlScenario.java | 128 ++++++ .../tests/FlightSqlScenarioProducer.java | 343 +++++++++++++++ .../tests/IntegrationAssertions.java | 9 + .../tests/IntegrationTestClient.java | 12 +- .../flight/integration/tests/Scenarios.java | 1 + .../arrow/flight/sql/FlightSqlProducer.java | 31 +- .../flight/sql/example/FlightSqlExample.java | 10 +- 15 files changed, 933 insertions(+), 47 deletions(-) create mode 100644 java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java create mode 100644 java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 82eb7c191006f..aeb1e51337eaf 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -730,15 +730,16 @@ endif() if(ARROW_FLIGHT) add_subdirectory(flight) - if(ARROW_BUILD_INTEGRATION) - add_subdirectory(flight/integration_tests) - endif() endif() if(ARROW_FLIGHT_SQL) add_subdirectory(flight/sql) endif() +if(ARROW_FLIGHT AND ARROW_BUILD_INTEGRATION) + add_subdirectory(flight/integration_tests) +endif() + if(ARROW_HIVESERVER2) add_subdirectory(dbi/hiveserver2) endif() diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 1a694a4abb5a0..2cf8c9913e572 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -202,7 +202,6 @@ if(ARROW_TESTING) OUTPUTS ARROW_FLIGHT_TESTING_LIBRARIES SOURCES - integration_tests/test_integration.cc test_util.cc DEPENDENCIES GTest::gtest diff --git a/cpp/src/arrow/flight/integration_tests/CMakeLists.txt b/cpp/src/arrow/flight/integration_tests/CMakeLists.txt index 7f45dd2adf1ff..3a878d7f30587 100644 --- a/cpp/src/arrow/flight/integration_tests/CMakeLists.txt +++ b/cpp/src/arrow/flight/integration_tests/CMakeLists.txt @@ -19,18 +19,27 @@ add_custom_target(arrow_flight_integration_tests) if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") set(ARROW_FLIGHT_TEST_LINK_LIBS - arrow_flight_static arrow_flight_testing_static ${ARROW_FLIGHT_STATIC_LINK_LIBS} + arrow_flight_static + arrow_flight_testing_static + arrow_flight_sql_static + ${ARROW_FLIGHT_STATIC_LINK_LIBS} ${ARROW_TEST_LINK_LIBS}) else() - set(ARROW_FLIGHT_TEST_LINK_LIBS arrow_flight_shared arrow_flight_testing_shared - ${ARROW_TEST_LINK_LIBS}) + set(ARROW_FLIGHT_TEST_LINK_LIBS + arrow_flight_shared + arrow_flight_testing_shared + arrow_flight_sql_shared + ${ARROW_TEST_LINK_LIBS} + ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS}) endif() -add_executable(flight-test-integration-server test_integration_server.cc) +add_executable(flight-test-integration-server test_integration_server.cc + test_integration.cc) target_link_libraries(flight-test-integration-server ${ARROW_FLIGHT_TEST_LINK_LIBS} ${GFLAGS_LIBRARIES} GTest::gtest) -add_executable(flight-test-integration-client test_integration_client.cc) +add_executable(flight-test-integration-client test_integration_client.cc + test_integration.cc) target_link_libraries(flight-test-integration-client ${ARROW_FLIGHT_TEST_LINK_LIBS} ${GFLAGS_LIBRARIES} GTest::gtest) diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index 491aeca305006..4456befbcf738 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -18,10 +18,14 @@ #include "arrow/flight/integration_tests/test_integration.h" #include "arrow/flight/client_middleware.h" #include "arrow/flight/server_middleware.h" +#include "arrow/flight/sql/client.h" +#include "arrow/flight/sql/server.h" #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" #include "arrow/ipc/dictionary.h" +#include +#include #include #include #include @@ -258,6 +262,390 @@ class MiddlewareScenario : public Scenario { std::shared_ptr client_middleware_; }; +std::shared_ptr GetQuerySchema() { + return arrow::schema({arrow::field("id", int64())}); +} + +template +arrow::Status AssertEq(const T& expected, const T& actual) { + if (expected != actual) { + return Status::Invalid("Expected \"", expected, "\", got \'", actual, "\""); + } + return Status::OK(); +} + +/// \brief The server used for testing Flight SQL, this implements a static Flight SQL server which only asserts +/// that commands called during integration tests are being parsed correctly and returns the expected schemas to be +/// validated on client. +class FlightSqlScenarioServer : public sql::FlightSqlServerBase { + public: + arrow::Result> GetFlightInfoStatement( + const ServerCallContext& context, const sql::StatementQuery& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq("SELECT STATEMENT", command.query)); + + ARROW_ASSIGN_OR_RAISE(auto handle, + sql::CreateStatementQueryTicket("SELECT STATEMENT HANDLE")); + + std::vector endpoints{FlightEndpoint{{handle}, {}}}; + ARROW_ASSIGN_OR_RAISE( + auto result, FlightInfo::Make(*GetQuerySchema(), descriptor, endpoints, -1, -1)) + + return std::unique_ptr(new FlightInfo(result)); + } + + arrow::Result> DoGetStatement( + const ServerCallContext& context, + const sql::StatementQueryTicket& command) override { + return DoGetForTestCase(GetQuerySchema()); + } + + arrow::Result> GetFlightInfoPreparedStatement( + const ServerCallContext& context, const sql::PreparedStatementQuery& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq("SELECT PREPARED STATEMENT HANDLE", + command.prepared_statement_handle)); + + return GetFlightInfoForCommand(descriptor, GetQuerySchema()); + } + + arrow::Result> DoGetPreparedStatement( + const ServerCallContext& context, + const sql::PreparedStatementQuery& command) override { + return DoGetForTestCase(GetQuerySchema()); + } + + arrow::Result> GetFlightInfoCatalogs( + const ServerCallContext& context, const FlightDescriptor& descriptor) override { + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetCatalogsSchema()); + } + + arrow::Result> DoGetCatalogs( + const ServerCallContext& context) override { + return DoGetForTestCase(sql::SqlSchema::GetCatalogsSchema()); + } + + arrow::Result> GetFlightInfoSqlInfo( + const ServerCallContext& context, const sql::GetSqlInfo& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq(2, command.info.size())); + ARROW_RETURN_NOT_OK(AssertEq( + sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, command.info[0])); + ARROW_RETURN_NOT_OK(AssertEq( + sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, command.info[1])); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetSqlInfoSchema()); + } + + arrow::Result> DoGetSqlInfo( + const ServerCallContext& context, const sql::GetSqlInfo& command) override { + return DoGetForTestCase(sql::SqlSchema::GetSqlInfoSchema()); + } + + arrow::Result> GetFlightInfoSchemas( + const ServerCallContext& context, const sql::GetDbSchemas& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq("catalog", command.catalog.value())); + ARROW_RETURN_NOT_OK(AssertEq("db_schema_filter_pattern", + command.db_schema_filter_pattern.value())); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetDbSchemasSchema()); + } + + arrow::Result> DoGetDbSchemas( + const ServerCallContext& context, const sql::GetDbSchemas& command) override { + return DoGetForTestCase(sql::SqlSchema::GetDbSchemasSchema()); + } + + arrow::Result> GetFlightInfoTables( + const ServerCallContext& context, const sql::GetTables& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK(AssertEq("catalog", command.catalog.value())); + ARROW_RETURN_NOT_OK(AssertEq("db_schema_filter_pattern", + command.db_schema_filter_pattern.value())); + ARROW_RETURN_NOT_OK(AssertEq("table_filter_pattern", + command.table_name_filter_pattern.value())); + ARROW_RETURN_NOT_OK(AssertEq(2, command.table_types.size())); + ARROW_RETURN_NOT_OK(AssertEq("table", command.table_types[0])); + ARROW_RETURN_NOT_OK(AssertEq("view", command.table_types[1])); + ARROW_RETURN_NOT_OK(AssertEq(true, command.include_schema)); + + return GetFlightInfoForCommand(descriptor, + sql::SqlSchema::GetTablesSchemaWithIncludedSchema()); + } + + arrow::Result> DoGetTables( + const ServerCallContext& context, const sql::GetTables& command) override { + return DoGetForTestCase(sql::SqlSchema::GetTablesSchemaWithIncludedSchema()); + } + + arrow::Result> GetFlightInfoTableTypes( + const ServerCallContext& context, const FlightDescriptor& descriptor) override { + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetTableTypesSchema()); + } + + arrow::Result> DoGetTableTypes( + const ServerCallContext& context) override { + return DoGetForTestCase(sql::SqlSchema::GetTableTypesSchema()); + } + + arrow::Result> GetFlightInfoPrimaryKeys( + const ServerCallContext& context, const sql::GetPrimaryKeys& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK( + AssertEq("catalog", command.table_ref.catalog.value())); + ARROW_RETURN_NOT_OK( + AssertEq("db_schema", command.table_ref.db_schema.value())); + ARROW_RETURN_NOT_OK(AssertEq("table", command.table_ref.table)); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetPrimaryKeysSchema()); + } + + arrow::Result> DoGetPrimaryKeys( + const ServerCallContext& context, const sql::GetPrimaryKeys& command) override { + return DoGetForTestCase(sql::SqlSchema::GetPrimaryKeysSchema()); + } + + arrow::Result> GetFlightInfoExportedKeys( + const ServerCallContext& context, const sql::GetExportedKeys& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK( + AssertEq("catalog", command.table_ref.catalog.value())); + ARROW_RETURN_NOT_OK( + AssertEq("db_schema", command.table_ref.db_schema.value())); + ARROW_RETURN_NOT_OK(AssertEq("table", command.table_ref.table)); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetExportedKeysSchema()); + } + + arrow::Result> DoGetExportedKeys( + const ServerCallContext& context, const sql::GetExportedKeys& command) override { + return DoGetForTestCase(sql::SqlSchema::GetExportedKeysSchema()); + } + + arrow::Result> GetFlightInfoImportedKeys( + const ServerCallContext& context, const sql::GetImportedKeys& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK( + AssertEq("catalog", command.table_ref.catalog.value())); + ARROW_RETURN_NOT_OK( + AssertEq("db_schema", command.table_ref.db_schema.value())); + ARROW_RETURN_NOT_OK(AssertEq("table", command.table_ref.table)); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetImportedKeysSchema()); + } + + arrow::Result> DoGetImportedKeys( + const ServerCallContext& context, const sql::GetImportedKeys& command) override { + return DoGetForTestCase(sql::SqlSchema::GetImportedKeysSchema()); + } + + arrow::Result> GetFlightInfoCrossReference( + const ServerCallContext& context, const sql::GetCrossReference& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK( + AssertEq("pk_catalog", command.pk_table_ref.catalog.value())); + ARROW_RETURN_NOT_OK( + AssertEq("pk_db_schema", command.pk_table_ref.db_schema.value())); + ARROW_RETURN_NOT_OK(AssertEq("pk_table", command.pk_table_ref.table)); + ARROW_RETURN_NOT_OK( + AssertEq("fk_catalog", command.fk_table_ref.catalog.value())); + ARROW_RETURN_NOT_OK( + AssertEq("fk_db_schema", command.fk_table_ref.db_schema.value())); + ARROW_RETURN_NOT_OK(AssertEq("fk_table", command.fk_table_ref.table)); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetTableTypesSchema()); + } + + arrow::Result> DoGetCrossReference( + const ServerCallContext& context, const sql::GetCrossReference& command) override { + return DoGetForTestCase(sql::SqlSchema::GetCrossReferenceSchema()); + } + + arrow::Result DoPutCommandStatementUpdate( + const ServerCallContext& context, const sql::StatementUpdate& command) override { + ARROW_RETURN_NOT_OK(AssertEq("UPDATE STATEMENT", command.query)); + + return 10000; + } + + arrow::Result CreatePreparedStatement( + const ServerCallContext& context, + const sql::ActionCreatePreparedStatementRequest& request) override { + ARROW_RETURN_NOT_OK( + AssertEq(true, request.query == "SELECT PREPARED STATEMENT" || + request.query == "UPDATE PREPARED STATEMENT")); + + sql::ActionCreatePreparedStatementResult result; + result.prepared_statement_handle = request.query + " HANDLE"; + + return result; + } + + Status ClosePreparedStatement( + const ServerCallContext& context, + const sql::ActionClosePreparedStatementRequest& request) override { + return Status::OK(); + } + + Status DoPutPreparedStatementQuery(const ServerCallContext& context, + const sql::PreparedStatementQuery& command, + FlightMessageReader* reader, + FlightMetadataWriter* writer) override { + return Status::NotImplemented("Not implemented"); + } + + arrow::Result DoPutPreparedStatementUpdate( + const ServerCallContext& context, const sql::PreparedStatementUpdate& command, + FlightMessageReader* reader) override { + ARROW_RETURN_NOT_OK(AssertEq("UPDATE PREPARED STATEMENT HANDLE", + command.prepared_statement_handle)); + + return 20000; + } + + private: + arrow::Result> GetFlightInfoForCommand( + const FlightDescriptor& descriptor, const std::shared_ptr& schema) { + std::vector endpoints{FlightEndpoint{{descriptor.cmd}, {}}}; + ARROW_ASSIGN_OR_RAISE(auto result, + FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)) + + return std::unique_ptr(new FlightInfo(result)); + } + + arrow::Result> DoGetForTestCase( + std::shared_ptr schema) { + ARROW_ASSIGN_OR_RAISE(auto reader2, RecordBatchReader::Make({}, schema)); + return std::unique_ptr(new RecordBatchStream(reader2)); + } +}; + +/// \brief Integration test scenario for validating Flight SQL specs across multiple +/// implementations. This should ensure that RPC objects are being built and parsed +/// correctly for multiple languages and that the Arrow schemas are returned as expected. +class FlightSqlScenario : public Scenario { + Status MakeServer(std::unique_ptr* server, + FlightServerOptions* options) override { + server->reset(new FlightSqlScenarioServer()); + return Status::OK(); + } + + Status MakeClient(FlightClientOptions* options) override { return Status::OK(); } + + Status Validate(std::shared_ptr expectedSchema, + arrow::Result> flightInfo, + sql::FlightSqlClient* sql_client) { + FlightCallOptions call_options; + + ARROW_ASSIGN_OR_RAISE(auto flight_info, flightInfo); + ARROW_ASSIGN_OR_RAISE( + auto reader, sql_client->DoGet(call_options, flight_info->endpoints()[0].ticket)); + + ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema()); + + AssertSchemaEqual(expectedSchema, actual_schema); + + return Status::OK(); + } + + Status RunClient(std::unique_ptr client) override { + sql::FlightSqlClient sql_client(std::move(client)); + + ARROW_RETURN_NOT_OK(ValidateMetadataRetrieval(&sql_client)); + + ARROW_RETURN_NOT_OK(ValidateStatementExecution(&sql_client)); + + ARROW_RETURN_NOT_OK(ValidatePreparedStatementExecution(&sql_client)); + + return Status::OK(); + } + + Status ValidateMetadataRetrieval(sql::FlightSqlClient* sql_client) { + FlightCallOptions options; + + std::string catalog = "catalog"; + std::string db_schema_filter_pattern = "db_schema_filter_pattern"; + std::string table_filter_pattern = "table_filter_pattern"; + std::string table = "table"; + std::string db_schema = "db_schema"; + std::vector table_types = {"table", "view"}; + + sql::TableRef table_ref = {catalog, db_schema, table}; + sql::TableRef pk_table_ref = {"pk_catalog", "pk_db_schema", "pk_table"}; + sql::TableRef fk_table_ref = {"fk_catalog", "fk_db_schema", "fk_table"}; + + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetCatalogsSchema(), + sql_client->GetCatalogs(options), sql_client)); + ARROW_RETURN_NOT_OK( + Validate(sql::SqlSchema::GetDbSchemasSchema(), + sql_client->GetDbSchemas(options, &catalog, &db_schema_filter_pattern), + sql_client)); + ARROW_RETURN_NOT_OK( + Validate(sql::SqlSchema::GetTablesSchemaWithIncludedSchema(), + sql_client->GetTables(options, &catalog, &db_schema_filter_pattern, + &table_filter_pattern, true, &table_types), + sql_client)); + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetTableTypesSchema(), + sql_client->GetTableTypes(options), sql_client)); + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetPrimaryKeysSchema(), + sql_client->GetPrimaryKeys(options, table_ref), + sql_client)); + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetExportedKeysSchema(), + sql_client->GetExportedKeys(options, table_ref), + sql_client)); + ARROW_RETURN_NOT_OK(Validate(sql::SqlSchema::GetImportedKeysSchema(), + sql_client->GetImportedKeys(options, table_ref), + sql_client)); + ARROW_RETURN_NOT_OK(Validate( + sql::SqlSchema::GetCrossReferenceSchema(), + sql_client->GetCrossReference(options, pk_table_ref, fk_table_ref), sql_client)); + ARROW_RETURN_NOT_OK(Validate( + sql::SqlSchema::GetSqlInfoSchema(), + sql_client->GetSqlInfo( + options, {sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, + sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY}), + sql_client)); + + return Status::OK(); + } + + Status ValidateStatementExecution(sql::FlightSqlClient* sql_client) { + FlightCallOptions options; + + ARROW_RETURN_NOT_OK(Validate( + GetQuerySchema(), sql_client->Execute(options, "SELECT STATEMENT"), sql_client)); + ARROW_ASSIGN_OR_RAISE(auto update_statement_result, + sql_client->ExecuteUpdate(options, "UPDATE STATEMENT")); + if (update_statement_result != 10000L) { + return Status::Invalid("Expected 'UPDATE STATEMENT' return 10000, got ", + update_statement_result); + } + + return Status::OK(); + } + + Status ValidatePreparedStatementExecution(sql::FlightSqlClient* sql_client) { + FlightCallOptions options; + + ARROW_ASSIGN_OR_RAISE(auto select_prepared_statement, + sql_client->Prepare(options, "SELECT PREPARED STATEMENT")); + ARROW_RETURN_NOT_OK( + Validate(GetQuerySchema(), select_prepared_statement->Execute(), sql_client)); + + ARROW_ASSIGN_OR_RAISE(auto update_prepared_statement, + sql_client->Prepare(options, "UPDATE PREPARED STATEMENT")); + ARROW_ASSIGN_OR_RAISE(auto update_prepared_statement_result, + update_prepared_statement->ExecuteUpdate()); + if (update_prepared_statement_result != 20000L) { + return Status::Invalid("Expected 'UPDATE STATEMENT' return 20000, got ", + update_prepared_statement_result); + } + + return Status::OK(); + } +}; + Status GetScenario(const std::string& scenario_name, std::shared_ptr* out) { if (scenario_name == "auth:basic_proto") { *out = std::make_shared(); @@ -265,6 +653,9 @@ Status GetScenario(const std::string& scenario_name, std::shared_ptr* } else if (scenario_name == "middleware") { *out = std::make_shared(); return Status::OK(); + } else if (scenario_name == "flight_sql") { + *out = std::make_shared(); + return Status::OK(); } return Status::KeyError("Scenario not found: ", scenario_name); } diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc index 6d328c07b0e69..bbbe801ea24d5 100644 --- a/cpp/src/arrow/flight/sql/server.cc +++ b/cpp/src/arrow/flight/sql/server.cc @@ -689,7 +689,7 @@ arrow::Result FlightSqlServerBase::DoPutCommandStatementUpdate( } std::shared_ptr SqlSchema::GetCatalogsSchema() { - return arrow::schema({field("catalog_name", utf8())}); + return arrow::schema({field("catalog_name", utf8(), false)}); } std::shared_ptr SqlSchema::GetDbSchemasSchema() { @@ -699,23 +699,26 @@ std::shared_ptr SqlSchema::GetDbSchemasSchema() { std::shared_ptr SqlSchema::GetTablesSchema() { return arrow::schema({field("catalog_name", utf8()), field("db_schema_name", utf8()), - field("table_name", utf8()), field("table_type", utf8())}); + field("table_name", utf8(), false), + field("table_type", utf8(), false)}); } std::shared_ptr SqlSchema::GetTablesSchemaWithIncludedSchema() { return arrow::schema({field("catalog_name", utf8()), field("db_schema_name", utf8()), - field("table_name", utf8()), field("table_type", utf8()), - field("table_schema", binary())}); + field("table_name", utf8(), false), + field("table_type", utf8(), false), + field("table_schema", binary(), false)}); } std::shared_ptr SqlSchema::GetTableTypesSchema() { - return arrow::schema({field("table_type", utf8())}); + return arrow::schema({field("table_type", utf8(), false)}); } std::shared_ptr SqlSchema::GetPrimaryKeysSchema() { - return arrow::schema({field("catalog_name", utf8()), field("db_schema_name", utf8()), - field("table_name", utf8()), field("column_name", utf8()), - field("key_sequence", int64()), field("key_name", utf8())}); + return arrow::schema( + {field("catalog_name", utf8()), field("db_schema_name", utf8()), + field("table_name", utf8(), false), field("column_name", utf8(), false), + field("key_sequence", int32(), false), field("key_name", utf8())}); } std::shared_ptr GetImportedExportedKeysAndCrossReferenceSchema() { @@ -742,7 +745,7 @@ std::shared_ptr SqlSchema::GetCrossReferenceSchema() { } std::shared_ptr SqlSchema::GetSqlInfoSchema() { - return arrow::schema({field("name", uint32(), false), + return arrow::schema({field("info_name", uint32(), false), field("value", dense_union({field("string_value", utf8(), false), field("bool_value", boolean(), false), diff --git a/cpp/src/arrow/flight/sql/server_test.cc b/cpp/src/arrow/flight/sql/server_test.cc index 8dfea7a013e01..ab781e1645f86 100644 --- a/cpp/src/arrow/flight/sql/server_test.cc +++ b/cpp/src/arrow/flight/sql/server_test.cc @@ -609,7 +609,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetPrimaryKeys) { const auto key_name = ArrayFromJSON(utf8(), R"([null])"); const auto table_name = ArrayFromJSON(utf8(), R"(["intTable"])"); const auto column_name = ArrayFromJSON(utf8(), R"(["id"])"); - const auto key_sequence = ArrayFromJSON(int64(), R"([1])"); + const auto key_sequence = ArrayFromJSON(int32(), R"([1])"); const std::shared_ptr& expected_table = Table::Make( SqlSchema::GetPrimaryKeysSchema(), diff --git a/cpp/src/arrow/flight/sql/test_server_cli.cc b/cpp/src/arrow/flight/sql/test_server_cli.cc index b696347359e88..e0ba5340e8d94 100644 --- a/cpp/src/arrow/flight/sql/test_server_cli.cc +++ b/cpp/src/arrow/flight/sql/test_server_cli.cc @@ -17,10 +17,10 @@ #include +#include #include #include #include -#include #include "arrow/flight/server.h" #include "arrow/flight/sql/example/sqlite_server.h" diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index 96ebd48912b1c..daaed7d5743c7 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -382,6 +382,9 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, description="Ensure headers are propagated via middleware.", skip={"Rust"} # TODO(ARROW-10961): tonic upgrade needed ), + Scenario( + "flight_sql", + description="Ensure that Flight SQL protocol is working as expected."), ] runner = IntegrationRunner(json_files, flight_scenarios, testers, **kwargs) diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java new file mode 100644 index 0000000000000..3ee5e23ab50cf --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.integration.tests; + +import java.util.Arrays; + +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.flight.sql.util.TableRef; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * Integration test scenario for validating Flight SQL specs across multiple implementations. + * This should ensure that RPC objects are being built and parsed correctly for multiple languages + * and that the Arrow schemas are returned as expected. + */ +public class FlightSqlScenario implements Scenario { + + @Override + public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception { + return new FlightSqlScenarioProducer(allocator); + } + + @Override + public void buildServer(FlightServer.Builder builder) throws Exception { + + } + + @Override + public void client(BufferAllocator allocator, Location location, FlightClient client) + throws Exception { + final FlightSqlClient sqlClient = new FlightSqlClient(client); + + validateMetadataRetrieval(sqlClient); + + validateStatementExecution(sqlClient); + + validatePreparedStatementExecution(sqlClient); + } + + private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Exception { + final CallOption[] options = new CallOption[0]; + + validate(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA, sqlClient.getCatalogs(options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA, + sqlClient.getSchemas("catalog", "db_schema_filter_pattern", options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA, + sqlClient.getTables("catalog", "db_schema_filter_pattern", "table_filter_pattern", + Arrays.asList("table", "view"), true, options), sqlClient); + validate(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA, sqlClient.getTableTypes(options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_PRIMARY_KEYS_SCHEMA, + sqlClient.getPrimaryKeys(TableRef.of("catalog", "db_schema", "table"), options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_EXPORTED_KEYS_SCHEMA, + sqlClient.getExportedKeys(TableRef.of("catalog", "db_schema", "table"), options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_IMPORTED_KEYS_SCHEMA, + sqlClient.getImportedKeys(TableRef.of("catalog", "db_schema", "table"), options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_CROSS_REFERENCE_SCHEMA, + sqlClient.getCrossReference(TableRef.of("pk_catalog", "pk_db_schema", "pk_table"), + TableRef.of("fk_catalog", "fk_db_schema", "fk_table"), options), + sqlClient); + validate(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, + sqlClient.getSqlInfo(new FlightSql.SqlInfo[] {FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME, + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY}, options), sqlClient); + } + + private void validateStatementExecution(FlightSqlClient sqlClient) throws Exception { + final CallOption[] options = new CallOption[0]; + + validate(FlightSqlScenarioProducer.getQuerySchema(), + sqlClient.execute("SELECT STATEMENT", options), sqlClient); + + IntegrationAssertions.assertEquals(sqlClient.executeUpdate("UPDATE STATEMENT", options), 10000L); + } + + private void validatePreparedStatementExecution(FlightSqlClient sqlClient) throws Exception { + final CallOption[] options = new CallOption[0]; + try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( + "SELECT PREPARED STATEMENT")) { + validate(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.execute(options), + sqlClient); + } + + try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( + "UPDATE PREPARED STATEMENT")) { + IntegrationAssertions.assertEquals(preparedStatement.executeUpdate(options), 20000L); + } + } + + private void validate(Schema expectedSchema, FlightInfo flightInfo, + FlightSqlClient sqlClient) throws Exception { + Ticket ticket = flightInfo.getEndpoints().get(0).getTicket(); + try (FlightStream stream = sqlClient.getStream(ticket)) { + Schema actualSchema = stream.getSchema(); + IntegrationAssertions.assertEquals(expectedSchema, actualSchema); + } + } +} diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java new file mode 100644 index 0000000000000..1b76f64431ad2 --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.integration.tests; + +import static com.google.protobuf.Any.pack; +import static java.util.Collections.singletonList; + +import java.util.List; + +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Message; + +/** + * Hardcoded Flight SQL producer used for cross-language integration tests. + */ +public class FlightSqlScenarioProducer implements FlightSqlProducer { + private final BufferAllocator allocator; + + public FlightSqlScenarioProducer(BufferAllocator allocator) { + this.allocator = allocator; + } + + static Schema getQuerySchema() { + return new Schema( + singletonList( + new Field("id", FieldType.nullable(new ArrowType.Int(64, true)), null) + ) + ); + } + + @Override + public void createPreparedStatement(FlightSql.ActionCreatePreparedStatementRequest request, + CallContext context, StreamListener listener) { + IntegrationAssertions.assertTrue("Expect to be one of the two queries used on tests", + request.getQuery().equals("SELECT PREPARED STATEMENT") || + request.getQuery().equals("UPDATE PREPARED STATEMENT")); + + final FlightSql.ActionCreatePreparedStatementResult + result = FlightSql.ActionCreatePreparedStatementResult.newBuilder() + .setPreparedStatementHandle(ByteString.copyFromUtf8(request.getQuery() + " HANDLE")) + .build(); + listener.onNext(new Result(pack(result).toByteArray())); + listener.onCompleted(); + } + + @Override + public void closePreparedStatement(FlightSql.ActionClosePreparedStatementRequest request, + CallContext context, StreamListener listener) { + IntegrationAssertions.assertTrue("Expect to be one of the two queries used on tests", + request.getPreparedStatementHandle().toStringUtf8().equals("SELECT PREPARED STATEMENT HANDLE") || + request.getPreparedStatementHandle().toStringUtf8().equals("UPDATE PREPARED STATEMENT HANDLE")); + + listener.onCompleted(); + } + + @Override + public FlightInfo getFlightInfoStatement(FlightSql.CommandStatementQuery command, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(command.getQuery(), "SELECT STATEMENT"); + + ByteString handle = ByteString.copyFromUtf8("SELECT STATEMENT HANDLE"); + + FlightSql.TicketStatementQuery ticket = FlightSql.TicketStatementQuery.newBuilder() + .setStatementHandle(handle) + .build(); + return getFlightInfoForSchema(ticket, descriptor, getQuerySchema()); + } + + @Override + public FlightInfo getFlightInfoPreparedStatement(FlightSql.CommandPreparedStatementQuery command, + CallContext context, + FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), + "SELECT PREPARED STATEMENT HANDLE"); + + return getFlightInfoForSchema(command, descriptor, getQuerySchema()); + } + + @Override + public SchemaResult getSchemaStatement(FlightSql.CommandStatementQuery command, + CallContext context, FlightDescriptor descriptor) { + return new SchemaResult(getQuerySchema()); + } + + @Override + public void getStreamStatement(FlightSql.TicketStatementQuery ticket, CallContext context, + ServerStreamListener listener) { + serveJsonToStreamListener(listener, getQuerySchema()); + } + + @Override + public void getStreamPreparedStatement(FlightSql.CommandPreparedStatementQuery command, + CallContext context, ServerStreamListener listener) { + serveJsonToStreamListener(listener, getQuerySchema()); + } + + private Runnable acceptPutReturnConstant(StreamListener ackStream, int value) { + return () -> { + final FlightSql.DoPutUpdateResult build = + FlightSql.DoPutUpdateResult.newBuilder().setRecordCount(value).build(); + + try (final ArrowBuf buffer = allocator.buffer(build.getSerializedSize())) { + buffer.writeBytes(build.toByteArray()); + ackStream.onNext(PutResult.metadata(buffer)); + ackStream.onCompleted(); + } + }; + } + + @Override + public Runnable acceptPutStatement(FlightSql.CommandStatementUpdate command, CallContext context, + FlightStream flightStream, + StreamListener ackStream) { + IntegrationAssertions.assertEquals(command.getQuery(), "UPDATE STATEMENT"); + + return acceptPutReturnConstant(ackStream, 10000); + } + + @Override + public Runnable acceptPutPreparedStatementUpdate(FlightSql.CommandPreparedStatementUpdate command, + CallContext context, FlightStream flightStream, + StreamListener ackStream) { + IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), + "UPDATE PREPARED STATEMENT HANDLE"); + + return acceptPutReturnConstant(ackStream, 20000); + } + + @Override + public Runnable acceptPutPreparedStatementQuery(FlightSql.CommandPreparedStatementQuery command, + CallContext context, FlightStream flightStream, + StreamListener ackStream) { + IntegrationAssertions.assertEquals(command.getPreparedStatementHandle(), + "SELECT PREPARED STATEMENT HANDLE"); + + return null; + } + + @Override + public FlightInfo getFlightInfoSqlInfo(FlightSql.CommandGetSqlInfo request, CallContext context, + FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getInfoCount(), 2); + IntegrationAssertions.assertEquals(request.getInfo(0), + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE); + IntegrationAssertions.assertEquals(request.getInfo(1), + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_SQL_INFO_SCHEMA); + } + + @Override + public void getStreamSqlInfo(FlightSql.CommandGetSqlInfo command, CallContext context, + ServerStreamListener listener) { + serveJsonToStreamListener(listener, Schemas.GET_SQL_INFO_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoCatalogs(FlightSql.CommandGetCatalogs request, CallContext context, + FlightDescriptor descriptor) { + return getFlightInfoForSchema(request, descriptor, Schemas.GET_CATALOGS_SCHEMA); + } + + private void serveJsonToStreamListener(ServerStreamListener stream, Schema schema) { + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + stream.start(root); + stream.putNext(); + stream.completed(); + } + } + + @Override + public void getStreamCatalogs(CallContext context, ServerStreamListener listener) { + serveJsonToStreamListener(listener, Schemas.GET_CATALOGS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoSchemas(FlightSql.CommandGetDbSchemas request, CallContext context, + FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getCatalog(), "catalog"); + IntegrationAssertions.assertEquals(request.getDbSchemaFilterPattern(), + "db_schema_filter_pattern"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_SCHEMAS_SCHEMA); + } + + @Override + public void getStreamSchemas(FlightSql.CommandGetDbSchemas command, CallContext context, + ServerStreamListener listener) { + serveJsonToStreamListener(listener, Schemas.GET_SCHEMAS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoTables(FlightSql.CommandGetTables request, CallContext context, + FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getCatalog(), "catalog"); + IntegrationAssertions.assertEquals(request.getDbSchemaFilterPattern(), + "db_schema_filter_pattern"); + IntegrationAssertions.assertEquals(request.getTableNameFilterPattern(), "table_filter_pattern"); + IntegrationAssertions.assertEquals(request.getTableTypesCount(), 2); + IntegrationAssertions.assertEquals(request.getTableTypes(0), "table"); + IntegrationAssertions.assertEquals(request.getTableTypes(1), "view"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_TABLES_SCHEMA); + } + + @Override + public void getStreamTables(FlightSql.CommandGetTables command, CallContext context, + ServerStreamListener listener) { + serveJsonToStreamListener(listener, Schemas.GET_TABLES_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoTableTypes(FlightSql.CommandGetTableTypes request, + CallContext context, FlightDescriptor descriptor) { + return getFlightInfoForSchema(request, descriptor, Schemas.GET_TABLE_TYPES_SCHEMA); + } + + @Override + public void getStreamTableTypes(CallContext context, ServerStreamListener listener) { + serveJsonToStreamListener(listener, Schemas.GET_TABLE_TYPES_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoPrimaryKeys(FlightSql.CommandGetPrimaryKeys request, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getCatalog(), "catalog"); + IntegrationAssertions.assertEquals(request.getDbSchema(), "db_schema"); + IntegrationAssertions.assertEquals(request.getTable(), "table"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_PRIMARY_KEYS_SCHEMA); + } + + @Override + public void getStreamPrimaryKeys(FlightSql.CommandGetPrimaryKeys command, CallContext context, + ServerStreamListener listener) { + serveJsonToStreamListener(listener, Schemas.GET_PRIMARY_KEYS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoExportedKeys(FlightSql.CommandGetExportedKeys request, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getCatalog(), "catalog"); + IntegrationAssertions.assertEquals(request.getDbSchema(), "db_schema"); + IntegrationAssertions.assertEquals(request.getTable(), "table"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_EXPORTED_KEYS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoImportedKeys(FlightSql.CommandGetImportedKeys request, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getCatalog(), "catalog"); + IntegrationAssertions.assertEquals(request.getDbSchema(), "db_schema"); + IntegrationAssertions.assertEquals(request.getTable(), "table"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_IMPORTED_KEYS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoCrossReference(FlightSql.CommandGetCrossReference request, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(request.getPkCatalog(), "pk_catalog"); + IntegrationAssertions.assertEquals(request.getPkDbSchema(), "pk_db_schema"); + IntegrationAssertions.assertEquals(request.getPkTable(), "pk_table"); + IntegrationAssertions.assertEquals(request.getFkCatalog(), "fk_catalog"); + IntegrationAssertions.assertEquals(request.getFkDbSchema(), "fk_db_schema"); + IntegrationAssertions.assertEquals(request.getFkTable(), "fk_table"); + + return getFlightInfoForSchema(request, descriptor, Schemas.GET_CROSS_REFERENCE_SCHEMA); + } + + @Override + public void getStreamExportedKeys(FlightSql.CommandGetExportedKeys command, CallContext context, + ServerStreamListener listener) { + serveJsonToStreamListener(listener, Schemas.GET_EXPORTED_KEYS_SCHEMA); + } + + @Override + public void getStreamImportedKeys(FlightSql.CommandGetImportedKeys command, CallContext context, + ServerStreamListener listener) { + serveJsonToStreamListener(listener, Schemas.GET_IMPORTED_KEYS_SCHEMA); + } + + @Override + public void getStreamCrossReference(FlightSql.CommandGetCrossReference command, + CallContext context, ServerStreamListener listener) { + serveJsonToStreamListener(listener, Schemas.GET_CROSS_REFERENCE_SCHEMA); + } + + @Override + public void close() throws Exception { + + } + + @Override + public void listFlights(CallContext context, Criteria criteria, + StreamListener listener) { + + } + + private FlightInfo getFlightInfoForSchema(final T request, + final FlightDescriptor descriptor, + final Schema schema) { + final Ticket ticket = new Ticket(pack(request).toByteArray()); + final List endpoints = singletonList(new FlightEndpoint(ticket)); + + return new FlightInfo(schema, descriptor, endpoints, -1, -1); + } +} diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java index 993ce73f7fe04..e124ed0ea74c7 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java @@ -63,6 +63,15 @@ static void assertFalse(String message, boolean value) { } } + /** + * Assert that the value is true, using the given message as an error otherwise. + */ + static void assertTrue(String message, boolean value) { + if (!value) { + throw new AssertionError("Expected true: " + message); + } + } + /** * An interface used with {@link #assertThrows(Class, AssertThrows)}. */ diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java index 5ed8e70ea63f1..2a36747b61880 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationTestClient.java @@ -91,7 +91,7 @@ private void run(String[] args) throws Exception { final Location defaultLocation = Location.forGrpcInsecure(host, port); try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); - final FlightClient client = FlightClient.builder(allocator, defaultLocation).build()) { + final FlightClient client = FlightClient.builder(allocator, defaultLocation).build()) { if (cmd.hasOption("scenario")) { Scenarios.getScenario(cmd.getOptionValue("scenario")).client(allocator, defaultLocation, client); @@ -109,7 +109,7 @@ private static void testStream(BufferAllocator allocator, Location server, Fligh // 1. Read data from JSON and upload to server. FlightDescriptor descriptor = FlightDescriptor.path(inputPath); try (JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator); - VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) { + VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) { FlightClient.ClientStreamListener stream = client.startPut(descriptor, root, reader, new AsyncPutListener() { int counter = 0; @@ -157,10 +157,10 @@ public void onNext(PutResult val) { for (Location location : locations) { System.out.println("Verifying location " + location.getUri()); try (FlightClient readClient = FlightClient.builder(allocator, location).build(); - FlightStream stream = readClient.getStream(endpoint.getTicket()); - VectorSchemaRoot root = stream.getRoot(); - VectorSchemaRoot downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator); - JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator)) { + FlightStream stream = readClient.getStream(endpoint.getTicket()); + VectorSchemaRoot root = stream.getRoot(); + VectorSchemaRoot downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator); + JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator)) { VectorLoader loader = new VectorLoader(downloadedRoot); VectorUnloader unloader = new VectorUnloader(root); diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java index 10882ee0a8603..16cc856daf567 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java @@ -41,6 +41,7 @@ private Scenarios() { scenarios = new TreeMap<>(); scenarios.put("auth:basic_proto", AuthBasicProtoScenario::new); scenarios.put("middleware", MiddlewareScenario::new); + scenarios.put("flight_sql", FlightSqlScenario::new); } private static Scenarios getInstance() { diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java index 87c8b3e092dba..f1eaf2f8988de 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java @@ -62,7 +62,6 @@ import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; import org.apache.arrow.flight.sql.impl.FlightSql.DoPutUpdateResult; import org.apache.arrow.flight.sql.impl.FlightSql.TicketStatementQuery; -import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.UnionMode; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -595,13 +594,13 @@ void getStreamCrossReference(CommandGetCrossReference command, CallContext conte final class Schemas { public static final Schema GET_TABLES_SCHEMA = new Schema(asList( Field.nullable("catalog_name", VARCHAR.getType()), - Field.nullable("schema_name", VARCHAR.getType()), + Field.nullable("db_schema_name", VARCHAR.getType()), Field.notNullable("table_name", VARCHAR.getType()), Field.notNullable("table_type", VARCHAR.getType()), Field.notNullable("table_schema", MinorType.VARBINARY.getType()))); public static final Schema GET_TABLES_SCHEMA_NO_SCHEMA = new Schema(asList( Field.nullable("catalog_name", VARCHAR.getType()), - Field.nullable("schema_name", VARCHAR.getType()), + Field.nullable("db_schema_name", VARCHAR.getType()), Field.notNullable("table_name", VARCHAR.getType()), Field.notNullable("table_type", VARCHAR.getType()))); public static final Schema GET_CATALOGS_SCHEMA = new Schema( @@ -611,15 +610,15 @@ final class Schemas { public static final Schema GET_SCHEMAS_SCHEMA = new Schema(asList( Field.nullable("catalog_name", VARCHAR.getType()), - Field.notNullable("schema_name", VARCHAR.getType()))); + Field.notNullable("db_schema_name", VARCHAR.getType()))); private static final Schema GET_IMPORTED_EXPORTED_AND_CROSS_REFERENCE_KEYS_SCHEMA = new Schema(asList( Field.nullable("pk_catalog_name", VARCHAR.getType()), - Field.nullable("pk_schema_name", VARCHAR.getType()), + Field.nullable("pk_db_schema_name", VARCHAR.getType()), Field.notNullable("pk_table_name", VARCHAR.getType()), Field.notNullable("pk_column_name", VARCHAR.getType()), Field.nullable("fk_catalog_name", VARCHAR.getType()), - Field.nullable("fk_schema_name", VARCHAR.getType()), + Field.nullable("fk_db_schema_name", VARCHAR.getType()), Field.notNullable("fk_table_name", VARCHAR.getType()), Field.notNullable("fk_column_name", VARCHAR.getType()), Field.notNullable("key_sequence", INT.getType()), @@ -631,32 +630,32 @@ final class Schemas { public static final Schema GET_EXPORTED_KEYS_SCHEMA = GET_IMPORTED_EXPORTED_AND_CROSS_REFERENCE_KEYS_SCHEMA; public static final Schema GET_CROSS_REFERENCE_SCHEMA = GET_IMPORTED_EXPORTED_AND_CROSS_REFERENCE_KEYS_SCHEMA; private static final List GET_SQL_INFO_DENSE_UNION_SCHEMA_FIELDS = asList( - Field.nullable("string_value", VARCHAR.getType()), - Field.nullable("bool_value", BIT.getType()), - Field.nullable("bigint_value", BIGINT.getType()), - Field.nullable("int32_bitmask", INT.getType()), + Field.notNullable("string_value", VARCHAR.getType()), + Field.notNullable("bool_value", BIT.getType()), + Field.notNullable("bigint_value", BIGINT.getType()), + Field.notNullable("int32_bitmask", INT.getType()), new Field( - "string_list", FieldType.nullable(LIST.getType()), - singletonList(Field.nullable(ListVector.DATA_VECTOR_NAME, VARCHAR.getType()))), + "string_list", FieldType.notNullable(LIST.getType()), + singletonList(Field.nullable("item", VARCHAR.getType()))), new Field( - "int32_to_int32_list_map", FieldType.nullable(new ArrowType.Map(false)), + "int32_to_int32_list_map", FieldType.notNullable(new ArrowType.Map(false)), singletonList(new Field(DATA_VECTOR_NAME, new FieldType(false, STRUCT.getType(), null), ImmutableList.of( Field.notNullable(KEY_NAME, INT.getType()), new Field( VALUE_NAME, FieldType.nullable(LIST.getType()), - singletonList(Field.nullable(ListVector.DATA_VECTOR_NAME, INT.getType())))))))); + singletonList(Field.nullable("item", INT.getType())))))))); public static final Schema GET_SQL_INFO_SCHEMA = new Schema(asList( Field.notNullable("info_name", UINT4.getType()), new Field("value", - FieldType.nullable( + FieldType.notNullable( new Union(UnionMode.Dense, range(0, GET_SQL_INFO_DENSE_UNION_SCHEMA_FIELDS.size()).toArray())), GET_SQL_INFO_DENSE_UNION_SCHEMA_FIELDS))); public static final Schema GET_PRIMARY_KEYS_SCHEMA = new Schema(asList( Field.nullable("catalog_name", VARCHAR.getType()), - Field.nullable("schema_name", VARCHAR.getType()), + Field.nullable("db_schema_name", VARCHAR.getType()), Field.notNullable("table_name", VARCHAR.getType()), Field.notNullable("column_name", VARCHAR.getType()), Field.notNullable("key_sequence", INT.getType()), diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java index 687840386e960..634343c236c53 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java @@ -356,7 +356,7 @@ private static VectorSchemaRoot getSchemasRoot(final ResultSet data, final Buffe throws SQLException { final VarCharVector catalogs = new VarCharVector("catalog_name", allocator); final VarCharVector schemas = - new VarCharVector("schema_name", FieldType.notNullable(MinorType.VARCHAR.getType()), allocator); + new VarCharVector("db_schema_name", FieldType.notNullable(MinorType.VARCHAR.getType()), allocator); final List vectors = ImmutableList.of(catalogs, schemas); vectors.forEach(FieldVector::allocateNew); final Map vectorToColumnName = ImmutableMap.of( @@ -449,7 +449,7 @@ private static VectorSchemaRoot getTablesRoot(final DatabaseMetaData databaseMet */ Objects.requireNonNull(allocator, "BufferAllocator cannot be null."); final VarCharVector catalogNameVector = new VarCharVector("catalog_name", allocator); - final VarCharVector schemaNameVector = new VarCharVector("schema_name", allocator); + final VarCharVector schemaNameVector = new VarCharVector("db_schema_name", allocator); final VarCharVector tableNameVector = new VarCharVector("table_name", FieldType.notNullable(MinorType.VARCHAR.getType()), allocator); final VarCharVector tableTypeVector = @@ -1409,7 +1409,7 @@ public void getStreamPrimaryKeys(final CommandGetPrimaryKeys command, final Call final ResultSet primaryKeys = connection.getMetaData().getPrimaryKeys(catalog, schema, table); final VarCharVector catalogNameVector = new VarCharVector("catalog_name", rootAllocator); - final VarCharVector schemaNameVector = new VarCharVector("schema_name", rootAllocator); + final VarCharVector schemaNameVector = new VarCharVector("db_schema_name", rootAllocator); final VarCharVector tableNameVector = new VarCharVector("table_name", rootAllocator); final VarCharVector columnNameVector = new VarCharVector("column_name", rootAllocator); final IntVector keySequenceVector = new IntVector("key_sequence", rootAllocator); @@ -1527,11 +1527,11 @@ public void getStreamCrossReference(CommandGetCrossReference command, CallContex private VectorSchemaRoot createVectors(ResultSet keys) throws SQLException { final VarCharVector pkCatalogNameVector = new VarCharVector("pk_catalog_name", rootAllocator); - final VarCharVector pkSchemaNameVector = new VarCharVector("pk_schema_name", rootAllocator); + final VarCharVector pkSchemaNameVector = new VarCharVector("pk_db_schema_name", rootAllocator); final VarCharVector pkTableNameVector = new VarCharVector("pk_table_name", rootAllocator); final VarCharVector pkColumnNameVector = new VarCharVector("pk_column_name", rootAllocator); final VarCharVector fkCatalogNameVector = new VarCharVector("fk_catalog_name", rootAllocator); - final VarCharVector fkSchemaNameVector = new VarCharVector("fk_schema_name", rootAllocator); + final VarCharVector fkSchemaNameVector = new VarCharVector("fk_db_schema_name", rootAllocator); final VarCharVector fkTableNameVector = new VarCharVector("fk_table_name", rootAllocator); final VarCharVector fkColumnNameVector = new VarCharVector("fk_column_name", rootAllocator); final IntVector keySequenceVector = new IntVector("key_sequence", rootAllocator); From 163fa77cbecb65042bc0a311ec1d8427175be74a Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Fri, 17 Dec 2021 22:25:47 -0300 Subject: [PATCH 3/8] Implement validation for prepared statement's parameter binding on integration tests --- cpp/src/arrow/CMakeLists.txt | 2 +- .../integration_tests/test_integration.cc | 22 +++++++++--- .../integration_tests/test_integration.h | 2 ++ .../integration/tests/FlightSqlScenario.java | 16 ++++++--- .../tests/FlightSqlScenarioProducer.java | 34 +++++++++++-------- 5 files changed, 53 insertions(+), 23 deletions(-) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index aeb1e51337eaf..857d043ba9cab 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -736,7 +736,7 @@ if(ARROW_FLIGHT_SQL) add_subdirectory(flight/sql) endif() -if(ARROW_FLIGHT AND ARROW_BUILD_INTEGRATION) +if(ARROW_FLIGHT AND ARROW_FLIGHT_SQL AND ARROW_BUILD_INTEGRATION) add_subdirectory(flight/integration_tests) endif() diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index 4456befbcf738..c060c322b61dd 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -262,6 +262,8 @@ class MiddlewareScenario : public Scenario { std::shared_ptr client_middleware_; }; +/// \brief Schema to be returned for mocking the statement/prepared statement results. +/// Must be the same across all languages. std::shared_ptr GetQuerySchema() { return arrow::schema({arrow::field("id", int64())}); } @@ -492,7 +494,13 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { const sql::PreparedStatementQuery& command, FlightMessageReader* reader, FlightMetadataWriter* writer) override { - return Status::NotImplemented("Not implemented"); + ARROW_RETURN_NOT_OK(AssertEq("SELECT PREPARED STATEMENT HANDLE", + command.prepared_statement_handle)); + + ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema()); + ARROW_RETURN_NOT_OK(AssertEq(*GetQuerySchema(), *actual_schema)); + + return Status::OK(); } arrow::Result DoPutPreparedStatementUpdate( @@ -515,9 +523,9 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { } arrow::Result> DoGetForTestCase( - std::shared_ptr schema) { - ARROW_ASSIGN_OR_RAISE(auto reader2, RecordBatchReader::Make({}, schema)); - return std::unique_ptr(new RecordBatchStream(reader2)); + const std::shared_ptr& schema) { + ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({}, schema)); + return std::unique_ptr(new RecordBatchStream(reader)); } }; @@ -630,8 +638,13 @@ class FlightSqlScenario : public Scenario { ARROW_ASSIGN_OR_RAISE(auto select_prepared_statement, sql_client->Prepare(options, "SELECT PREPARED STATEMENT")); + + auto parameters = RecordBatch::Make(GetQuerySchema(), 1, {ArrayFromJSON(int64(), "[1]")}); + ARROW_RETURN_NOT_OK(select_prepared_statement->SetParameters(parameters)); + ARROW_RETURN_NOT_OK( Validate(GetQuerySchema(), select_prepared_statement->Execute(), sql_client)); + ARROW_RETURN_NOT_OK(select_prepared_statement->Close()); ARROW_ASSIGN_OR_RAISE(auto update_prepared_statement, sql_client->Prepare(options, "UPDATE PREPARED STATEMENT")); @@ -641,6 +654,7 @@ class FlightSqlScenario : public Scenario { return Status::Invalid("Expected 'UPDATE STATEMENT' return 20000, got ", update_prepared_statement_result); } + ARROW_RETURN_NOT_OK(update_prepared_statement->Close()); return Status::OK(); } diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.h b/cpp/src/arrow/flight/integration_tests/test_integration.h index cc598cff3a2ba..74093f8cd2351 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.h +++ b/cpp/src/arrow/flight/integration_tests/test_integration.h @@ -17,6 +17,8 @@ // Integration test scenarios for Arrow Flight. +#pragma once + #include "arrow/flight/visibility.h" #include diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java index 3ee5e23ab50cf..dc235990fb29e 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java @@ -32,6 +32,7 @@ import org.apache.arrow.flight.sql.impl.FlightSql; import org.apache.arrow.flight.sql.util.TableRef; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.types.pojo.Schema; /** @@ -60,7 +61,7 @@ public void client(BufferAllocator allocator, Location location, FlightClient cl validateStatementExecution(sqlClient); - validatePreparedStatementExecution(sqlClient); + validatePreparedStatementExecution(sqlClient, allocator); } private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Exception { @@ -100,13 +101,20 @@ private void validateStatementExecution(FlightSqlClient sqlClient) throws Except validate(FlightSqlScenarioProducer.getQuerySchema(), sqlClient.execute("SELECT STATEMENT", options), sqlClient); - IntegrationAssertions.assertEquals(sqlClient.executeUpdate("UPDATE STATEMENT", options), 10000L); + IntegrationAssertions.assertEquals(sqlClient.executeUpdate("UPDATE STATEMENT", options), + 10000L); } - private void validatePreparedStatementExecution(FlightSqlClient sqlClient) throws Exception { + private void validatePreparedStatementExecution(FlightSqlClient sqlClient, + BufferAllocator allocator) throws Exception { final CallOption[] options = new CallOption[0]; try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( - "SELECT PREPARED STATEMENT")) { + "SELECT PREPARED STATEMENT"); + VectorSchemaRoot parameters = VectorSchemaRoot.create( + FlightSqlScenarioProducer.getQuerySchema(), allocator)) { + parameters.setRowCount(1); + preparedStatement.setParameters(parameters); + validate(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.execute(options), sqlClient); } diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java index 1b76f64431ad2..cf8a287e9840e 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java @@ -54,6 +54,10 @@ public FlightSqlScenarioProducer(BufferAllocator allocator) { this.allocator = allocator; } + /** + * Schema to be returned for mocking the statement/prepared statement results. + * Must be the same across all languages. + */ static Schema getQuerySchema() { return new Schema( singletonList( @@ -119,13 +123,13 @@ public SchemaResult getSchemaStatement(FlightSql.CommandStatementQuery command, @Override public void getStreamStatement(FlightSql.TicketStatementQuery ticket, CallContext context, ServerStreamListener listener) { - serveJsonToStreamListener(listener, getQuerySchema()); + putEmptyBatchToStreamListener(listener, getQuerySchema()); } @Override public void getStreamPreparedStatement(FlightSql.CommandPreparedStatementQuery command, CallContext context, ServerStreamListener listener) { - serveJsonToStreamListener(listener, getQuerySchema()); + putEmptyBatchToStreamListener(listener, getQuerySchema()); } private Runnable acceptPutReturnConstant(StreamListener ackStream, int value) { @@ -164,10 +168,12 @@ public Runnable acceptPutPreparedStatementUpdate(FlightSql.CommandPreparedStatem public Runnable acceptPutPreparedStatementQuery(FlightSql.CommandPreparedStatementQuery command, CallContext context, FlightStream flightStream, StreamListener ackStream) { - IntegrationAssertions.assertEquals(command.getPreparedStatementHandle(), + IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), "SELECT PREPARED STATEMENT HANDLE"); - return null; + IntegrationAssertions.assertEquals(getQuerySchema(), flightStream.getSchema()); + + return ackStream::onCompleted; } @Override @@ -185,7 +191,7 @@ public FlightInfo getFlightInfoSqlInfo(FlightSql.CommandGetSqlInfo request, Call @Override public void getStreamSqlInfo(FlightSql.CommandGetSqlInfo command, CallContext context, ServerStreamListener listener) { - serveJsonToStreamListener(listener, Schemas.GET_SQL_INFO_SCHEMA); + putEmptyBatchToStreamListener(listener, Schemas.GET_SQL_INFO_SCHEMA); } @Override @@ -194,7 +200,7 @@ public FlightInfo getFlightInfoCatalogs(FlightSql.CommandGetCatalogs request, Ca return getFlightInfoForSchema(request, descriptor, Schemas.GET_CATALOGS_SCHEMA); } - private void serveJsonToStreamListener(ServerStreamListener stream, Schema schema) { + private void putEmptyBatchToStreamListener(ServerStreamListener stream, Schema schema) { try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { stream.start(root); stream.putNext(); @@ -204,7 +210,7 @@ private void serveJsonToStreamListener(ServerStreamListener stream, Schema schem @Override public void getStreamCatalogs(CallContext context, ServerStreamListener listener) { - serveJsonToStreamListener(listener, Schemas.GET_CATALOGS_SCHEMA); + putEmptyBatchToStreamListener(listener, Schemas.GET_CATALOGS_SCHEMA); } @Override @@ -220,7 +226,7 @@ public FlightInfo getFlightInfoSchemas(FlightSql.CommandGetDbSchemas request, Ca @Override public void getStreamSchemas(FlightSql.CommandGetDbSchemas command, CallContext context, ServerStreamListener listener) { - serveJsonToStreamListener(listener, Schemas.GET_SCHEMAS_SCHEMA); + putEmptyBatchToStreamListener(listener, Schemas.GET_SCHEMAS_SCHEMA); } @Override @@ -240,7 +246,7 @@ public FlightInfo getFlightInfoTables(FlightSql.CommandGetTables request, CallCo @Override public void getStreamTables(FlightSql.CommandGetTables command, CallContext context, ServerStreamListener listener) { - serveJsonToStreamListener(listener, Schemas.GET_TABLES_SCHEMA); + putEmptyBatchToStreamListener(listener, Schemas.GET_TABLES_SCHEMA); } @Override @@ -251,7 +257,7 @@ public FlightInfo getFlightInfoTableTypes(FlightSql.CommandGetTableTypes request @Override public void getStreamTableTypes(CallContext context, ServerStreamListener listener) { - serveJsonToStreamListener(listener, Schemas.GET_TABLE_TYPES_SCHEMA); + putEmptyBatchToStreamListener(listener, Schemas.GET_TABLE_TYPES_SCHEMA); } @Override @@ -267,7 +273,7 @@ public FlightInfo getFlightInfoPrimaryKeys(FlightSql.CommandGetPrimaryKeys reque @Override public void getStreamPrimaryKeys(FlightSql.CommandGetPrimaryKeys command, CallContext context, ServerStreamListener listener) { - serveJsonToStreamListener(listener, Schemas.GET_PRIMARY_KEYS_SCHEMA); + putEmptyBatchToStreamListener(listener, Schemas.GET_PRIMARY_KEYS_SCHEMA); } @Override @@ -306,19 +312,19 @@ public FlightInfo getFlightInfoCrossReference(FlightSql.CommandGetCrossReference @Override public void getStreamExportedKeys(FlightSql.CommandGetExportedKeys command, CallContext context, ServerStreamListener listener) { - serveJsonToStreamListener(listener, Schemas.GET_EXPORTED_KEYS_SCHEMA); + putEmptyBatchToStreamListener(listener, Schemas.GET_EXPORTED_KEYS_SCHEMA); } @Override public void getStreamImportedKeys(FlightSql.CommandGetImportedKeys command, CallContext context, ServerStreamListener listener) { - serveJsonToStreamListener(listener, Schemas.GET_IMPORTED_KEYS_SCHEMA); + putEmptyBatchToStreamListener(listener, Schemas.GET_IMPORTED_KEYS_SCHEMA); } @Override public void getStreamCrossReference(FlightSql.CommandGetCrossReference command, CallContext context, ServerStreamListener listener) { - serveJsonToStreamListener(listener, Schemas.GET_CROSS_REFERENCE_SCHEMA); + putEmptyBatchToStreamListener(listener, Schemas.GET_CROSS_REFERENCE_SCHEMA); } @Override From 4d384abefe187b1d0070734a4fdcbe4f57be5bd8 Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Mon, 20 Dec 2021 14:18:00 -0300 Subject: [PATCH 4/8] Skip Rust and Go from Flight SQL integration test scenario --- dev/archery/archery/integration/runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index daaed7d5743c7..686183bda72c2 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -384,7 +384,9 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, ), Scenario( "flight_sql", - description="Ensure that Flight SQL protocol is working as expected."), + description="Ensure that Flight SQL protocol is working as expected.", + skip={"Rust", "Go"} + ), ] runner = IntegrationRunner(json_files, flight_scenarios, testers, **kwargs) From f6bbf6cf5722182fb280cbddc25e53dfca54e087 Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Mon, 20 Dec 2021 14:20:44 -0300 Subject: [PATCH 5/8] Fix formatting issues on C++ integration tests --- .../flight/integration_tests/test_integration.cc | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index c060c322b61dd..44d5fcffa242a 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -276,9 +276,9 @@ arrow::Status AssertEq(const T& expected, const T& actual) { return Status::OK(); } -/// \brief The server used for testing Flight SQL, this implements a static Flight SQL server which only asserts -/// that commands called during integration tests are being parsed correctly and returns the expected schemas to be -/// validated on client. +/// \brief The server used for testing Flight SQL, this implements a static Flight SQL +/// server which only asserts that commands called during integration tests are being +/// parsed correctly and returns the expected schemas to be validated on client. class FlightSqlScenarioServer : public sql::FlightSqlServerBase { public: arrow::Result> GetFlightInfoStatement( @@ -622,7 +622,7 @@ class FlightSqlScenario : public Scenario { FlightCallOptions options; ARROW_RETURN_NOT_OK(Validate( - GetQuerySchema(), sql_client->Execute(options, "SELECT STATEMENT"), sql_client)); + GetQuerySchema(), sql_client->Execute(options, "SELECT STATEMENT"), sql_client)); ARROW_ASSIGN_OR_RAISE(auto update_statement_result, sql_client->ExecuteUpdate(options, "UPDATE STATEMENT")); if (update_statement_result != 10000L) { @@ -639,11 +639,12 @@ class FlightSqlScenario : public Scenario { ARROW_ASSIGN_OR_RAISE(auto select_prepared_statement, sql_client->Prepare(options, "SELECT PREPARED STATEMENT")); - auto parameters = RecordBatch::Make(GetQuerySchema(), 1, {ArrayFromJSON(int64(), "[1]")}); + auto parameters = + RecordBatch::Make(GetQuerySchema(), 1, {ArrayFromJSON(int64(), "[1]")}); ARROW_RETURN_NOT_OK(select_prepared_statement->SetParameters(parameters)); ARROW_RETURN_NOT_OK( - Validate(GetQuerySchema(), select_prepared_statement->Execute(), sql_client)); + Validate(GetQuerySchema(), select_prepared_statement->Execute(), sql_client)); ARROW_RETURN_NOT_OK(select_prepared_statement->Close()); ARROW_ASSIGN_OR_RAISE(auto update_prepared_statement, From 97fe3ce48166918cb837ba7586e91c4f1cfbd038 Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Mon, 20 Dec 2021 15:13:38 -0300 Subject: [PATCH 6/8] Fix TestFlightSqlServer#TestCommandGetSqlInfoNoInfo failure on CI --- cpp/src/arrow/CMakeLists.txt | 4 +++- cpp/src/arrow/flight/sql/server_test.cc | 2 +- dev/archery/archery/integration/runner.py | 2 +- dev/archery/archery/integration/tester_java.py | 5 +++-- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 857d043ba9cab..cc979a22e09fa 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -736,7 +736,9 @@ if(ARROW_FLIGHT_SQL) add_subdirectory(flight/sql) endif() -if(ARROW_FLIGHT AND ARROW_FLIGHT_SQL AND ARROW_BUILD_INTEGRATION) +if(ARROW_FLIGHT + AND ARROW_FLIGHT_SQL + AND ARROW_BUILD_INTEGRATION) add_subdirectory(flight/integration_tests) endif() diff --git a/cpp/src/arrow/flight/sql/server_test.cc b/cpp/src/arrow/flight/sql/server_test.cc index ab781e1645f86..d74b6d4013748 100644 --- a/cpp/src/arrow/flight/sql/server_test.cc +++ b/cpp/src/arrow/flight/sql/server_test.cc @@ -758,7 +758,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetSqlInfoNoInfo) { ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetSqlInfo(call_options, {999999})); EXPECT_RAISES_WITH_MESSAGE_THAT( - KeyError, ::testing::HasSubstr("No information for SQL info number 999999."), + KeyError, ::testing::HasSubstr("No information for SQL info number 999999"), sql_client->DoGet(call_options, flight_info->endpoints()[0].ticket)); } diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index 686183bda72c2..74bbed1fc4fa9 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -384,7 +384,7 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, ), Scenario( "flight_sql", - description="Ensure that Flight SQL protocol is working as expected.", + description="Ensure Flight SQL protocol is working as expected.", skip={"Rust", "Go"} ), ] diff --git a/dev/archery/archery/integration/tester_java.py b/dev/archery/archery/integration/tester_java.py index 75875ad7185c3..69c6e54e0562d 100644 --- a/dev/archery/archery/integration/tester_java.py +++ b/dev/archery/archery/integration/tester_java.py @@ -49,8 +49,9 @@ class JavaTester(Tester): ARROW_FLIGHT_JAR = os.environ.get( 'ARROW_FLIGHT_JAVA_INTEGRATION_JAR', os.path.join(ARROW_ROOT_DEFAULT, - 'java/flight/flight-integration-tests/target/flight-integration-tests-{}-' - 'jar-with-dependencies.jar'.format(_arrow_version))) + 'java/flight/flight-integration-tests/target/' + 'flight-integration-tests-{}-jar-with-dependencies.jar' + .format(_arrow_version))) ARROW_FLIGHT_SERVER = ('org.apache.arrow.flight.integration.tests.' 'IntegrationTestServer') ARROW_FLIGHT_CLIENT = ('org.apache.arrow.flight.integration.tests.' From f2731569be1816816a1670ee5423f6da8e53a883 Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Tue, 21 Dec 2021 11:45:09 -0300 Subject: [PATCH 7/8] nit: Extract expected results for update statements as constants (C++ and Java) --- .../integration_tests/test_integration.cc | 17 +++++++++++------ .../integration/tests/FlightSqlScenario.java | 8 ++++++-- .../tests/FlightSqlScenarioProducer.java | 6 +++--- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index 44d5fcffa242a..50de6b534ceec 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -268,6 +268,9 @@ std::shared_ptr GetQuerySchema() { return arrow::schema({arrow::field("id", int64())}); } +constexpr int64_t kUpdateStatementExpectedRows = 10000L; +constexpr int64_t kUpdatePreparedStatementExpectedRows = 20000L; + template arrow::Status AssertEq(const T& expected, const T& actual) { if (expected != actual) { @@ -468,7 +471,7 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { const ServerCallContext& context, const sql::StatementUpdate& command) override { ARROW_RETURN_NOT_OK(AssertEq("UPDATE STATEMENT", command.query)); - return 10000; + return kUpdateStatementExpectedRows; } arrow::Result CreatePreparedStatement( @@ -509,7 +512,7 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { ARROW_RETURN_NOT_OK(AssertEq("UPDATE PREPARED STATEMENT HANDLE", command.prepared_statement_handle)); - return 20000; + return kUpdatePreparedStatementExpectedRows; } private: @@ -625,8 +628,9 @@ class FlightSqlScenario : public Scenario { GetQuerySchema(), sql_client->Execute(options, "SELECT STATEMENT"), sql_client)); ARROW_ASSIGN_OR_RAISE(auto update_statement_result, sql_client->ExecuteUpdate(options, "UPDATE STATEMENT")); - if (update_statement_result != 10000L) { - return Status::Invalid("Expected 'UPDATE STATEMENT' return 10000, got ", + if (update_statement_result != kUpdateStatementExpectedRows) { + return Status::Invalid("Expected 'UPDATE STATEMENT' return ", + kUpdateStatementExpectedRows, ", got ", update_statement_result); } @@ -651,8 +655,9 @@ class FlightSqlScenario : public Scenario { sql_client->Prepare(options, "UPDATE PREPARED STATEMENT")); ARROW_ASSIGN_OR_RAISE(auto update_prepared_statement_result, update_prepared_statement->ExecuteUpdate()); - if (update_prepared_statement_result != 20000L) { - return Status::Invalid("Expected 'UPDATE STATEMENT' return 20000, got ", + if (update_prepared_statement_result != kUpdatePreparedStatementExpectedRows) { + return Status::Invalid("Expected 'UPDATE STATEMENT' return ", + kUpdatePreparedStatementExpectedRows, ", got ", update_prepared_statement_result); } ARROW_RETURN_NOT_OK(update_prepared_statement->Close()); diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java index dc235990fb29e..374e634e8a345 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java @@ -42,6 +42,9 @@ */ public class FlightSqlScenario implements Scenario { + public static final long UPDATE_STATEMENT_EXPECTED_ROWS = 10000L; + public static final long UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS = 20000L; + @Override public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception { return new FlightSqlScenarioProducer(allocator); @@ -102,7 +105,7 @@ private void validateStatementExecution(FlightSqlClient sqlClient) throws Except sqlClient.execute("SELECT STATEMENT", options), sqlClient); IntegrationAssertions.assertEquals(sqlClient.executeUpdate("UPDATE STATEMENT", options), - 10000L); + UPDATE_STATEMENT_EXPECTED_ROWS); } private void validatePreparedStatementExecution(FlightSqlClient sqlClient, @@ -121,7 +124,8 @@ private void validatePreparedStatementExecution(FlightSqlClient sqlClient, try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( "UPDATE PREPARED STATEMENT")) { - IntegrationAssertions.assertEquals(preparedStatement.executeUpdate(options), 20000L); + IntegrationAssertions.assertEquals(preparedStatement.executeUpdate(options), + UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS); } } diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java index cf8a287e9840e..f3554e1d3d83c 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java @@ -132,7 +132,7 @@ public void getStreamPreparedStatement(FlightSql.CommandPreparedStatementQuery c putEmptyBatchToStreamListener(listener, getQuerySchema()); } - private Runnable acceptPutReturnConstant(StreamListener ackStream, int value) { + private Runnable acceptPutReturnConstant(StreamListener ackStream, long value) { return () -> { final FlightSql.DoPutUpdateResult build = FlightSql.DoPutUpdateResult.newBuilder().setRecordCount(value).build(); @@ -151,7 +151,7 @@ public Runnable acceptPutStatement(FlightSql.CommandStatementUpdate command, Cal StreamListener ackStream) { IntegrationAssertions.assertEquals(command.getQuery(), "UPDATE STATEMENT"); - return acceptPutReturnConstant(ackStream, 10000); + return acceptPutReturnConstant(ackStream, FlightSqlScenario.UPDATE_STATEMENT_EXPECTED_ROWS); } @Override @@ -161,7 +161,7 @@ public Runnable acceptPutPreparedStatementUpdate(FlightSql.CommandPreparedStatem IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), "UPDATE PREPARED STATEMENT HANDLE"); - return acceptPutReturnConstant(ackStream, 20000); + return acceptPutReturnConstant(ackStream, FlightSqlScenario.UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS); } @Override From 4b21e94539058795838c9a90365c58d18e8e1e2f Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Tue, 21 Dec 2021 16:08:03 -0300 Subject: [PATCH 8/8] nit: Use quotes on arrow includes on test_integration.cc --- .../flight/integration_tests/test_integration.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index 50de6b534ceec..1e08f47b579bd 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -23,9 +23,8 @@ #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" #include "arrow/ipc/dictionary.h" +#include "arrow/testing/gtest_util.h" -#include -#include #include #include #include @@ -544,18 +543,18 @@ class FlightSqlScenario : public Scenario { Status MakeClient(FlightClientOptions* options) override { return Status::OK(); } - Status Validate(std::shared_ptr expectedSchema, - arrow::Result> flightInfo, + Status Validate(std::shared_ptr expected_schema, + arrow::Result> flight_info_result, sql::FlightSqlClient* sql_client) { FlightCallOptions call_options; - ARROW_ASSIGN_OR_RAISE(auto flight_info, flightInfo); + ARROW_ASSIGN_OR_RAISE(auto flight_info, flight_info_result); ARROW_ASSIGN_OR_RAISE( auto reader, sql_client->DoGet(call_options, flight_info->endpoints()[0].ticket)); ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema()); - AssertSchemaEqual(expectedSchema, actual_schema); + AssertSchemaEqual(expected_schema, actual_schema); return Status::OK(); }