diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index 8c4388fc0f9..70332773302 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -474,6 +474,10 @@ jobs: PIPX_BASE_PYTHON: ${{ steps.python-install.outputs.python-path }} run: | ci/scripts/install_gcs_testbench.sh default + - name: Register Flight SQL ODBC Driver + shell: cmd + run: | + call "cpp\src\arrow\flight\sql\odbc\install\install_amd64.cmd" ${{github.workspace}}\build\cpp\%ARROW_BUILD_TYPE%\libarrow_flight_sql_odbc.dll - name: Test shell: msys2 {0} run: | diff --git a/ci/scripts/cpp_test.sh b/ci/scripts/cpp_test.sh index 05885ce4018..e646ba964a5 100755 --- a/ci/scripts/cpp_test.sh +++ b/ci/scripts/cpp_test.sh @@ -73,8 +73,6 @@ case "$(uname)" in exclude_tests="${exclude_tests}|gandiva-precompiled-test" exclude_tests="${exclude_tests}|gandiva-projector-test" exclude_tests="${exclude_tests}|gandiva-utf8-test" - # TODO: Enable ODBC tests - exclude_tests="${exclude_tests}|arrow-connection-test" ctest_options+=(--exclude-regex "${exclude_tests}") ;; *) diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index 581c30e1aaa..05957d4b275 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -108,7 +108,7 @@ endmacro() macro(resolve_option_dependencies) # Arrow Flight SQL ODBC is available only for Windows for now. - if(NOT MSVC_TOOLCHAIN) + if(NOT WIN32) set(ARROW_FLIGHT_SQL_ODBC OFF) endif() if(MSVC_TOOLCHAIN) diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_auth_method.cc b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_auth_method.cc index 3fcc3a87162..4b66c30dab0 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_auth_method.cc +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/flight_sql_auth_method.cc @@ -45,6 +45,10 @@ class NoOpAuthMethod : public FlightSqlAuthMethod { void Authenticate(FlightSqlConnection& connection, FlightCallOptions& call_options) override { // Do nothing + + // TODO: implement NoOpAuthMethod to validate server address. + // Can use NoOpClientAuthHandler. + // https://github.com/apache/arrow/issues/46733 } }; diff --git a/cpp/src/arrow/flight/sql/odbc/flight_sql/system_dsn.h b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_dsn.h index f5470693eea..f3744d3428a 100644 --- a/cpp/src/arrow/flight/sql/odbc/flight_sql/system_dsn.h +++ b/cpp/src/arrow/flight/sql/odbc/flight_sql/system_dsn.h @@ -40,7 +40,7 @@ bool DisplayConnectionWindow(void* windowParent, Configuration& config); * @param windowParent Parent window handle. * @param config Output configuration, presumed to be empty, it will be using values from * properties. - * @param config Output properties. + * @param properties Output properties. * @return True on success and false on fail. */ bool DisplayConnectionWindow(void* windowParent, Configuration& config, diff --git a/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt index 1d0dce0bec4..41e51182275 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt @@ -21,13 +21,25 @@ include_directories(${ODBC_INCLUDE_DIRS}) add_definitions(-DUNICODE=1) +find_package(SQLite3Alt REQUIRED) + +set(ARROW_FLIGHT_SQL_MOCK_SERVER_SRCS + ../../example/sqlite_sql_info.cc + ../../example/sqlite_type_info.cc + ../../example/sqlite_statement.cc + ../../example/sqlite_statement_batch_reader.cc + ../../example/sqlite_server.cc + ../../example/sqlite_tables_schema_batch_reader.cc) + add_arrow_test(connection_test SOURCES connection_test.cc odbc_test_suite.cc odbc_test_suite.h + ${ARROW_FLIGHT_SQL_MOCK_SERVER_SRCS} EXTRA_LINK_LIBS ${ODBC_LIBRARIES} ${ODBCINST} + ${SQLite3_LIBRARIES} arrow_odbc_spi_impl odbcabstraction) diff --git a/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc index b2866b3e5a7..8461ead6cf9 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc @@ -200,49 +200,51 @@ TEST(SQLSetEnvAttr, TestSQLSetEnvAttrODBCVersionInvalid) { EXPECT_TRUE(return_set == SQL_ERROR); } -TEST_F(FlightSQLODBCTestBase, TestSQLGetEnvAttrOutputNTS) { - connect(); +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetEnvAttrOutputNTS) { + this->connect(); SQLINTEGER output_nts; - SQLRETURN return_get = SQLGetEnvAttr(env, SQL_ATTR_OUTPUT_NTS, &output_nts, 0, 0); + SQLRETURN return_get = SQLGetEnvAttr(this->env, SQL_ATTR_OUTPUT_NTS, &output_nts, 0, 0); EXPECT_TRUE(return_get == SQL_SUCCESS); EXPECT_EQ(output_nts, SQL_TRUE); - disconnect(); + this->disconnect(); } -TEST_F(FlightSQLODBCTestBase, TestSQLGetEnvAttrGetLength) { +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetEnvAttrGetLength) { // Test is disabled because call to SQLGetEnvAttr is handled by the driver manager on - // Windows. This test case can be potentionally used on macOS/Linux + // Windows. This test case can be potentially used on macOS/Linux GTEST_SKIP(); - connect(); + this->connect(); SQLINTEGER length; - SQLRETURN return_get = SQLGetEnvAttr(env, SQL_ATTR_ODBC_VERSION, nullptr, 0, &length); + SQLRETURN return_get = + SQLGetEnvAttr(this->env, SQL_ATTR_ODBC_VERSION, nullptr, 0, &length); EXPECT_TRUE(return_get == SQL_SUCCESS); EXPECT_EQ(length, sizeof(SQLINTEGER)); - disconnect(); + this->disconnect(); } -TEST_F(FlightSQLODBCTestBase, TestSQLGetEnvAttrNullValuePointer) { +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetEnvAttrNullValuePointer) { // Test is disabled because call to SQLGetEnvAttr is handled by the driver manager on - // Windows. This test case can be potentionally used on macOS/Linux + // Windows. This test case can be potentially used on macOS/Linux GTEST_SKIP(); - connect(); + this->connect(); - SQLRETURN return_get = SQLGetEnvAttr(env, SQL_ATTR_ODBC_VERSION, nullptr, 0, nullptr); + SQLRETURN return_get = + SQLGetEnvAttr(this->env, SQL_ATTR_ODBC_VERSION, nullptr, 0, nullptr); EXPECT_TRUE(return_get == SQL_ERROR); - disconnect(); + this->disconnect(); } TEST(SQLSetEnvAttr, TestSQLSetEnvAttrOutputNTSValid) { @@ -292,7 +294,7 @@ TEST(SQLSetEnvAttr, TestSQLSetEnvAttrNullValuePointer) { EXPECT_TRUE(return_set == SQL_ERROR); } -TEST(SQLDriverConnect, TestSQLDriverConnect) { +TYPED_TEST(FlightSQLODBCTestBase, TestSQLDriverConnect) { // ODBC Environment SQLHENV env; SQLHDBC conn; @@ -312,8 +314,7 @@ TEST(SQLDriverConnect, TestSQLDriverConnect) { EXPECT_TRUE(ret == SQL_SUCCESS); // Connect string - ASSERT_OK_AND_ASSIGN(std::string connect_str, - arrow::internal::GetEnvVar(TEST_CONNECT_STR)); + std::string connect_str = this->getConnectionString(); ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, arrow::util::UTF8ToWideString(connect_str)); std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); @@ -361,7 +362,7 @@ TEST(SQLDriverConnect, TestSQLDriverConnect) { EXPECT_TRUE(ret == SQL_SUCCESS); } -TEST(SQLDriverConnect, TestSQLDriverConnectInvalidUid) { +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLDriverConnectInvalidUid) { // ODBC Environment SQLHENV env; SQLHDBC conn; @@ -380,11 +381,8 @@ TEST(SQLDriverConnect, TestSQLDriverConnectInvalidUid) { EXPECT_TRUE(ret == SQL_SUCCESS); - // Connect string - ASSERT_OK_AND_ASSIGN(std::string connect_str, - arrow::internal::GetEnvVar(TEST_CONNECT_STR)); - // Append invalid uid to connection string - connect_str += std::string("uid=non_existent_id;"); + // Invalid connect string + std::string connect_str = getInvalidConnectionString(); ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, arrow::util::UTF8ToWideString(connect_str)); @@ -418,7 +416,7 @@ TEST(SQLDriverConnect, TestSQLDriverConnectInvalidUid) { EXPECT_TRUE(ret == SQL_SUCCESS); } -TEST(SQLConnect, TestSQLConnect) { +TYPED_TEST(FlightSQLODBCTestBase, TestSQLConnect) { // ODBC Environment SQLHENV env; SQLHDBC conn; @@ -438,8 +436,7 @@ TEST(SQLConnect, TestSQLConnect) { EXPECT_TRUE(ret == SQL_SUCCESS); // Connect string - ASSERT_OK_AND_ASSIGN(std::string connect_str, - arrow::internal::GetEnvVar(TEST_CONNECT_STR)); + std::string connect_str = this->getConnectionString(); // Write connection string content into a DSN, // must succeed before continuing @@ -454,7 +451,7 @@ TEST(SQLConnect, TestSQLConnect) { std::vector uid0(wuid.begin(), wuid.end()); std::vector pwd0(wpwd.begin(), wpwd.end()); - // Connecting to ODBC server. + // Connecting to ODBC server. Empty uid and pwd should be ignored. ret = SQLConnect(conn, dsn0.data(), static_cast(dsn0.size()), uid0.data(), static_cast(uid0.size()), pwd0.data(), static_cast(pwd0.size())); @@ -488,7 +485,7 @@ TEST(SQLConnect, TestSQLConnect) { EXPECT_TRUE(ret == SQL_SUCCESS); } -TEST(SQLConnect, TestSQLConnectInputUidPwd) { +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLConnectInputUidPwd) { // ODBC Environment SQLHENV env; SQLHDBC conn; @@ -508,10 +505,9 @@ TEST(SQLConnect, TestSQLConnectInputUidPwd) { EXPECT_TRUE(ret == SQL_SUCCESS); // Connect string - ASSERT_OK_AND_ASSIGN(std::string connect_str, - arrow::internal::GetEnvVar(TEST_CONNECT_STR)); + std::string connect_str = getConnectionString(); - // Retrieve valid uid and pwd + // Retrieve valid uid and pwd, assumes TEST_CONNECT_STR contains uid and pwd Connection::ConnPropertyMap properties; ODBC::ODBCConnection::getPropertiesFromConnString(connect_str, properties); std::string uid_key("uid"); @@ -567,7 +563,7 @@ TEST(SQLConnect, TestSQLConnectInputUidPwd) { EXPECT_TRUE(ret == SQL_SUCCESS); } -TEST(SQLConnect, TestSQLConnectInvalidUid) { +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLConnectInvalidUid) { // ODBC Environment SQLHENV env; SQLHDBC conn; @@ -587,10 +583,9 @@ TEST(SQLConnect, TestSQLConnectInvalidUid) { EXPECT_TRUE(ret == SQL_SUCCESS); // Connect string - ASSERT_OK_AND_ASSIGN(std::string connect_str, - arrow::internal::GetEnvVar(TEST_CONNECT_STR)); + std::string connect_str = getConnectionString(); - // Retrieve valid uid and pwd + // Retrieve valid uid and pwd, assumes TEST_CONNECT_STR contains uid and pwd Connection::ConnPropertyMap properties; ODBC::ODBCConnection::getPropertiesFromConnString(connect_str, properties); std::string uid = properties[std::string("uid")]; @@ -636,7 +631,7 @@ TEST(SQLConnect, TestSQLConnectInvalidUid) { EXPECT_TRUE(ret == SQL_SUCCESS); } -TEST(SQLConnect, TestSQLConnectDSNPrecedence) { +TEST_F(FlightSQLODBCRemoteTestBase, TestSQLConnectDSNPrecedence) { // ODBC Environment SQLHENV env; SQLHDBC conn; @@ -656,13 +651,13 @@ TEST(SQLConnect, TestSQLConnectDSNPrecedence) { EXPECT_TRUE(ret == SQL_SUCCESS); // Connect string - ASSERT_OK_AND_ASSIGN(std::string connect_str, - arrow::internal::GetEnvVar(TEST_CONNECT_STR)); + std::string connect_str = getConnectionString(); // Write connection string content into a DSN, // must succeed before continuing - // Pass incorrect uid and password to SQLConnect, they will be ignored + // Pass incorrect uid and password to SQLConnect, they will be ignored. + // Assumes TEST_CONNECT_STR contains uid and pwd std::string uid("non_existent_id"), pwd("non_existent_password"); ASSERT_TRUE(writeDSN(connect_str)); @@ -746,7 +741,7 @@ TEST(SQLDisconnect, TestSQLDisconnectWithoutConnection) { EXPECT_TRUE(ret == SQL_SUCCESS); } -TEST(SQLGetDiagFieldW, TestSQLGetDiagFieldWForConnectFailure) { +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetDiagFieldWForConnectFailure) { // ODBC Environment SQLHENV env; SQLHDBC conn; @@ -765,11 +760,8 @@ TEST(SQLGetDiagFieldW, TestSQLGetDiagFieldWForConnectFailure) { EXPECT_TRUE(ret == SQL_SUCCESS); - // Connect string - ASSERT_OK_AND_ASSIGN(std::string connect_str, - arrow::internal::GetEnvVar(TEST_CONNECT_STR)); - // Append invalid uid to connection string - connect_str += std::string("uid=non_existent_id;"); + // Invalid connect string + std::string connect_str = this->getInvalidConnectionString(); ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, arrow::util::UTF8ToWideString(connect_str)); @@ -859,9 +851,9 @@ TEST(SQLGetDiagFieldW, TestSQLGetDiagFieldWForConnectFailure) { EXPECT_TRUE(ret == SQL_SUCCESS); } -TEST(SQLGetDiagFieldW, TestSQLGetDiagFieldWForConnectFailureNTS) { +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetDiagFieldWForConnectFailureNTS) { // Test is disabled because driver manager on Windows does not pass through SQL_NTS - // This test case can be potentionally used on macOS/Linux + // This test case can be potentially used on macOS/Linux GTEST_SKIP(); // ODBC Environment SQLHENV env; @@ -881,11 +873,8 @@ TEST(SQLGetDiagFieldW, TestSQLGetDiagFieldWForConnectFailureNTS) { EXPECT_TRUE(ret == SQL_SUCCESS); - // Connect string - ASSERT_OK_AND_ASSIGN(std::string connect_str, - arrow::internal::GetEnvVar(TEST_CONNECT_STR)); - // Append invalid uid to connection string - connect_str += std::string("uid=non_existent_id;"); + // Invalid connect string + std::string connect_str = this->getInvalidConnectionString(); ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, arrow::util::UTF8ToWideString(connect_str)); @@ -902,7 +891,6 @@ TEST(SQLGetDiagFieldW, TestSQLGetDiagFieldWForConnectFailureNTS) { EXPECT_TRUE(ret == SQL_ERROR); // Retrieve all supported header level and record level data - SQLSMALLINT HEADER_LEVEL = 0; SQLSMALLINT RECORD_1 = 1; // SQL_DIAG_MESSAGE_TEXT SQL_NTS @@ -929,7 +917,7 @@ TEST(SQLGetDiagFieldW, TestSQLGetDiagFieldWForConnectFailureNTS) { EXPECT_TRUE(ret == SQL_SUCCESS); } -TEST(SQLGetDiagRec, TestSQLGetDiagRecForConnectFailure) { +TYPED_TEST(FlightSQLODBCTestBase, TestSQLGetDiagRecForConnectFailure) { // ODBC Environment SQLHENV env; SQLHDBC conn; @@ -948,11 +936,8 @@ TEST(SQLGetDiagRec, TestSQLGetDiagRecForConnectFailure) { EXPECT_TRUE(ret == SQL_SUCCESS); - // Connect string - ASSERT_OK_AND_ASSIGN(std::string connect_str, - arrow::internal::GetEnvVar(TEST_CONNECT_STR)); - // Append invalid uid to connection string - connect_str += std::string("uid=non_existent_id;"); + // Invalid connect string + std::string connect_str = this->getInvalidConnectionString(); ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, arrow::util::UTF8ToWideString(connect_str)); @@ -978,7 +963,7 @@ TEST(SQLGetDiagRec, TestSQLGetDiagRecForConnectFailure) { EXPECT_TRUE(ret == SQL_SUCCESS); - EXPECT_GT(message_length, 200); + EXPECT_GT(message_length, 120); EXPECT_EQ(native_error, 200); @@ -1000,6 +985,12 @@ TEST(SQLGetDiagRec, TestSQLGetDiagRecForConnectFailure) { EXPECT_TRUE(ret == SQL_SUCCESS); } +TYPED_TEST(FlightSQLODBCTestBase, TestConnect) { + // Verifies connect and disconnect works on its own + this->connect(); + this->disconnect(); +} + } // namespace integration_tests } // namespace odbc } // namespace flight diff --git a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc index d60bcea19e4..656c221a1d1 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc @@ -29,7 +29,11 @@ namespace arrow { namespace flight { namespace odbc { namespace integration_tests { -void FlightSQLODBCTestBase::connect() { +void FlightSQLODBCRemoteTestBase::connect() { + std::string connect_str = getConnectionString(); + connectWithString(connect_str); +} +void FlightSQLODBCRemoteTestBase::connectWithString(std::string connect_str) { // Allocate an environment handle SQLRETURN ret = SQLAllocEnv(&env); @@ -45,8 +49,6 @@ void FlightSQLODBCTestBase::connect() { EXPECT_TRUE(ret == SQL_SUCCESS); // Connect string - ASSERT_OK_AND_ASSIGN(std::string connect_str, - arrow::internal::GetEnvVar(TEST_CONNECT_STR)); std::vector connect_str0(connect_str.begin(), connect_str.end()); SQLWCHAR outstr[ODBC_BUFFER_SIZE]; @@ -65,7 +67,7 @@ void FlightSQLODBCTestBase::connect() { ASSERT_TRUE(ret == SQL_SUCCESS); } -void FlightSQLODBCTestBase::disconnect() { +void FlightSQLODBCRemoteTestBase::disconnect() { // Disconnect from ODBC SQLRETURN ret = SQLDisconnect(conn); @@ -86,6 +88,93 @@ void FlightSQLODBCTestBase::disconnect() { EXPECT_TRUE(ret == SQL_SUCCESS); } +std::string FlightSQLODBCRemoteTestBase::getConnectionString() { + std::string connect_str = arrow::internal::GetEnvVar(TEST_CONNECT_STR).ValueOrDie(); + return connect_str; +} + +std::string FlightSQLODBCRemoteTestBase::getInvalidConnectionString() { + std::string connect_str = getConnectionString(); + // Append invalid uid to connection string + connect_str += std::string("uid=non_existent_id;"); + return connect_str; +} + +void FlightSQLODBCRemoteTestBase::SetUp() { + if (arrow::internal::GetEnvVar(TEST_CONNECT_STR).ValueOr("").empty()) { + GTEST_SKIP() << "Skipping FlightSQLODBCRemoteTestBase test: TEST_CONNECT_STR not set"; + } +} + +std::string FindTokenInCallHeaders(const CallHeaders& incoming_headers) { + // Lambda function to compare characters without case sensitivity. + auto char_compare = [](const char& char1, const char& char2) { + return (::toupper(char1) == ::toupper(char2)); + }; + + const std::string auth_val(incoming_headers.find(kAuthHeader)->second); + std::string bearer_token(""); + if (auth_val.size() > kBearerPrefix.length()) { + if (std::equal(auth_val.begin(), auth_val.begin() + kBearerPrefix.length(), + kBearerPrefix.begin(), char_compare)) { + bearer_token = auth_val.substr(kBearerPrefix.length()); + } + } + return bearer_token; +} + +void MockServerMiddleware::SendingHeaders(AddCallHeaders* outgoing_headers) { + std::string bearer_token = FindTokenInCallHeaders(incoming_headers_); + *isValid_ = (bearer_token == std::string(test_token)); +} + +Status MockServerMiddlewareFactory::StartCall( + const CallInfo& info, const ServerCallContext& context, + std::shared_ptr* middleware) { + std::string bearer_token = FindTokenInCallHeaders(context.incoming_headers()); + if (bearer_token == std::string(test_token)) { + *middleware = + std::make_shared(context.incoming_headers(), &isValid_); + } else { + return MakeFlightError(FlightStatusCode::Unauthenticated, + "Invalid token for mock server"); + } + + return Status::OK(); +} + +std::string FlightSQLODBCMockTestBase::getConnectionString() { + std::string connect_str( + "driver={Apache Arrow Flight SQL ODBC Driver};HOST=localhost;port=" + + std::to_string(port) + ";token=" + std::string(test_token) + + ";useEncryption=false;"); + return connect_str; +} + +std::string FlightSQLODBCMockTestBase::getInvalidConnectionString() { + std::string connect_str = getConnectionString(); + // Append invalid token to connection string + connect_str += std::string("token=invalid_token;"); + return connect_str; +} + +void FlightSQLODBCMockTestBase::SetUp() { + ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("0.0.0.0", 0)); + arrow::flight::FlightServerOptions options(location); + options.auth_handler = std::make_unique(); + options.middleware.push_back( + {"bearer-auth-server", std::make_shared()}); + ASSERT_OK_AND_ASSIGN(server, + arrow::flight::sql::example::SQLiteFlightSqlServer::Create()); + ASSERT_OK(server->Init(options)); + + port = server->port(); + ASSERT_OK_AND_ASSIGN(location, Location::ForGrpcTcp("localhost", port)); + ASSERT_OK_AND_ASSIGN(auto client, arrow::flight::FlightClient::Connect(location)); +} + +void FlightSQLODBCMockTestBase::TearDown() { ASSERT_OK(server->Shutdown()); } + bool compareConnPropertyMap(Connection::ConnPropertyMap map1, Connection::ConnPropertyMap map2) { if (map1.size() != map2.size()) return false; diff --git a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h index 49ab2e20f44..168204e8d0b 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h +++ b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h @@ -19,6 +19,9 @@ #include "arrow/util/io_util.h" #include "arrow/util/utf8.h" +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/sql/client.h" +#include "arrow/flight/sql/example/sqlite_server.h" #include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/encoding_utils.h" #ifdef _WIN32 @@ -29,6 +32,8 @@ #include #include +#include + #include "arrow/flight/sql/odbc/odbcabstraction/include/odbcabstraction/odbc_impl/odbc_connection.h" // For DSN registration @@ -43,13 +48,20 @@ namespace odbc { namespace integration_tests { using driver::odbcabstraction::Connection; -class FlightSQLODBCTestBase : public ::testing::Test { +class FlightSQLODBCRemoteTestBase : public ::testing::Test { public: /// \brief Connect to Arrow Flight SQL server using connection string defined in /// environment variable "ARROW_FLIGHT_SQL_ODBC_CONN" void connect(); + /// \brief Connect to Arrow Flight SQL server using connection string + void connectWithString(std::string connection_str); /// \brief Disconnect from server void disconnect(); + /// \brief Get connection string from environment variable "ARROW_FLIGHT_SQL_ODBC_CONN" + std::string virtual getConnectionString(); + /// \brief Get invalid connection string based on connection string defined in + /// environment variable "ARROW_FLIGHT_SQL_ODBC_CONN" + std::string virtual getInvalidConnectionString(); /** ODBC Environment. */ SQLHENV env; @@ -59,8 +71,77 @@ class FlightSQLODBCTestBase : public ::testing::Test { /** ODBC Statement. */ SQLHSTMT stmt; + + protected: + void SetUp() override; +}; + +static constexpr std::string_view kAuthHeader = "authorization"; +static constexpr std::string_view kBearerPrefix = "Bearer "; +static constexpr std::string_view test_token = "t0k3n"; + +std::string FindTokenInCallHeaders(const CallHeaders& incoming_headers); + +// A server middleware for validating incoming bearer header authentication. +class MockServerMiddleware : public ServerMiddleware { + public: + explicit MockServerMiddleware(const CallHeaders& incoming_headers, bool* isValid) + : isValid_(isValid) { + incoming_headers_ = incoming_headers; + } + + void SendingHeaders(AddCallHeaders* outgoing_headers) override; + + void CallCompleted(const Status& status) override {} + + std::string name() const override { return "MockServerMiddleware"; } + + private: + CallHeaders incoming_headers_; + bool* isValid_; +}; + +// Factory for base64 header authentication testing. +class MockServerMiddlewareFactory : public ServerMiddlewareFactory { + public: + MockServerMiddlewareFactory() : isValid_(false) {} + + Status StartCall(const CallInfo& info, const ServerCallContext& context, + std::shared_ptr* middleware) override; + + private: + bool isValid_; +}; + +class FlightSQLODBCMockTestBase : public FlightSQLODBCRemoteTestBase { + // Sets up a mock server for each test case + public: + /// \brief Get connection string for mock server + std::string getConnectionString() override; + /// \brief Get invalid connection string for mock server + std::string getInvalidConnectionString() override; + + int port; + + protected: + void SetUp() override; + + void TearDown() override; + + private: + std::shared_ptr server; }; +template +class FlightSQLODBCTestBase : public T { + public: + using List = std::list; +}; + +using TestTypes = + ::testing::Types; +TYPED_TEST_SUITE(FlightSQLODBCTestBase, TestTypes); + /** ODBC read buffer size. */ enum { ODBC_BUFFER_SIZE = 1024 };