From f8dd969765ca167c90697aa896a8f76debe8b951 Mon Sep 17 00:00:00 2001 From: Paul Nienaber Date: Tue, 20 Feb 2024 05:49:26 -0800 Subject: [PATCH] GH-34865: [C++][Java][Flight RPC] Add Session management messages (#34817) ### Rationale for this change Flight presently contains no formal mechanism for managing connection/query configuration options; instead, request headers and/or non-query SQL statements are often used in lieu, with unnecessary overhead and poor failure handling. A stateless (from Flight's perspective) Flight format extension is desirable to close this gap for server implementations that use/want connection state/context. ### What changes are included in this PR? "Session" set/get/close Actions and server-side helper middleware. ### Are these changes tested? Integration tests (C++ currently broken due to middleware-related framework issue) and some complex-case unit testing are included. ### Are there any user-facing changes? Non-breaking extensions to wire format and corresponding client/server Flight RPC API extensions. * Closes: #34865 Lead-authored-by: Paul Nienaber Co-authored-by: Paul Nienaber Co-authored-by: James Duong Co-authored-by: Sutou Kouhei Signed-off-by: David Li # Conflicts: # java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java # java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java # testing --- cpp/src/arrow/flight/client.cc | 41 ++ cpp/src/arrow/flight/client.h | 21 + .../flight_integration_test.cc | 2 + .../integration_tests/test_integration.cc | 154 ++++++++ .../arrow/flight/serialization_internal.cc | 159 ++++++++ cpp/src/arrow/flight/serialization_internal.h | 20 + cpp/src/arrow/flight/sql/CMakeLists.txt | 7 +- cpp/src/arrow/flight/sql/client.h | 27 ++ cpp/src/arrow/flight/sql/server.cc | 76 +++- cpp/src/arrow/flight/sql/server.h | 19 + .../flight/sql/server_session_middleware.cc | 235 ++++++++++++ .../flight/sql/server_session_middleware.h | 89 +++++ .../sql/server_session_middleware_factory.h | 61 +++ ...erver_session_middleware_internals_test.cc | 45 +++ .../flight/transport/grpc/grpc_server.cc | 26 +- cpp/src/arrow/flight/types.cc | 363 ++++++++++++++++++ cpp/src/arrow/flight/types.h | 197 ++++++++++ dev/archery/archery/integration/runner.py | 5 + docs/source/format/FlightSql.rst | 41 ++ format/Flight.proto | 114 ++++++ .../arrow/flight/CloseSessionRequest.java | 58 +++ .../arrow/flight/CloseSessionResult.java | 106 +++++ .../org/apache/arrow/flight/FlightClient.java | 96 +++++ .../apache/arrow/flight/FlightConstants.java | 14 + .../flight/GetSessionOptionsRequest.java | 60 +++ .../arrow/flight/GetSessionOptionsResult.java | 80 ++++ .../flight/NoOpSessionOptionValueVisitor.java | 72 ++++ .../arrow/flight/ServerSessionMiddleware.java | 227 +++++++++++ .../arrow/flight/SessionOptionValue.java | 94 +++++ .../flight/SessionOptionValueFactory.java | 284 ++++++++++++++ .../flight/SessionOptionValueVisitor.java | 58 +++ .../flight/SetSessionOptionsRequest.java | 81 ++++ .../arrow/flight/SetSessionOptionsResult.java | 152 ++++++++ java/flight/flight-integration-tests/pom.xml | 4 + .../flight/integration/tests/Scenarios.java | 2 + .../tests/SessionOptionsProducer.java | 110 ++++++ .../tests/SessionOptionsScenario.java | 107 ++++++ .../integration/tests/IntegrationTest.java | 10 + .../sql/CloseSessionResultListener.java | 46 +++ .../arrow/flight/sql/FlightSqlClient.java | 18 + .../arrow/flight/sql/FlightSqlProducer.java | 79 ++++ .../sql/GetSessionOptionsResultListener.java | 46 +++ .../sql/SetSessionOptionsResultListener.java | 46 +++ testing | 2 +- 44 files changed, 3533 insertions(+), 21 deletions(-) create mode 100644 cpp/src/arrow/flight/sql/server_session_middleware.cc create mode 100644 cpp/src/arrow/flight/sql/server_session_middleware.h create mode 100644 cpp/src/arrow/flight/sql/server_session_middleware_factory.h create mode 100644 cpp/src/arrow/flight/sql/server_session_middleware_internals_test.cc create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/CloseSessionRequest.java create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/CloseSessionResult.java create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/GetSessionOptionsRequest.java create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/GetSessionOptionsResult.java create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/NoOpSessionOptionValueVisitor.java create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/ServerSessionMiddleware.java create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/SessionOptionValue.java create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/SessionOptionValueFactory.java create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/SessionOptionValueVisitor.java create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/SetSessionOptionsRequest.java create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/SetSessionOptionsResult.java create mode 100644 java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/SessionOptionsProducer.java create mode 100644 java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/SessionOptionsScenario.java create mode 100644 java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CloseSessionResultListener.java create mode 100644 java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/GetSessionOptionsResultListener.java create mode 100644 java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SetSessionOptionsResultListener.java diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 25da5e8007660..4d4f13a09fb26 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -713,6 +713,47 @@ arrow::Result FlightClient::DoExchange( return result; } +::arrow::Result FlightClient::SetSessionOptions( + const FlightCallOptions& options, const SetSessionOptionsRequest& request) { + RETURN_NOT_OK(CheckOpen()); + ARROW_ASSIGN_OR_RAISE(auto body, request.SerializeToString()); + Action action{ActionType::kSetSessionOptions.type, Buffer::FromString(body)}; + ARROW_ASSIGN_OR_RAISE(auto stream, DoAction(options, action)); + ARROW_ASSIGN_OR_RAISE(auto result, stream->Next()); + ARROW_ASSIGN_OR_RAISE( + auto set_session_options_result, + SetSessionOptionsResult::Deserialize(std::string_view(*result->body))); + ARROW_RETURN_NOT_OK(stream->Drain()); + return set_session_options_result; +} + +::arrow::Result FlightClient::GetSessionOptions( + const FlightCallOptions& options, const GetSessionOptionsRequest& request) { + RETURN_NOT_OK(CheckOpen()); + ARROW_ASSIGN_OR_RAISE(auto body, request.SerializeToString()); + Action action{ActionType::kGetSessionOptions.type, Buffer::FromString(body)}; + ARROW_ASSIGN_OR_RAISE(auto stream, DoAction(options, action)); + ARROW_ASSIGN_OR_RAISE(auto result, stream->Next()); + ARROW_ASSIGN_OR_RAISE( + auto get_session_options_result, + GetSessionOptionsResult::Deserialize(std::string_view(*result->body))); + ARROW_RETURN_NOT_OK(stream->Drain()); + return get_session_options_result; +} + +::arrow::Result FlightClient::CloseSession( + const FlightCallOptions& options, const CloseSessionRequest& request) { + RETURN_NOT_OK(CheckOpen()); + ARROW_ASSIGN_OR_RAISE(auto body, request.SerializeToString()); + Action action{ActionType::kCloseSession.type, Buffer::FromString(body)}; + ARROW_ASSIGN_OR_RAISE(auto stream, DoAction(options, action)); + ARROW_ASSIGN_OR_RAISE(auto result, stream->Next()); + ARROW_ASSIGN_OR_RAISE(auto close_session_result, + CloseSessionResult::Deserialize(std::string_view(*result->body))); + ARROW_RETURN_NOT_OK(stream->Drain()); + return close_session_result; +} + Status FlightClient::Close() { if (!closed_) { closed_ = true; diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index e26a821359781..d739bd20b7d52 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -383,6 +383,27 @@ class ARROW_FLIGHT_EXPORT FlightClient { return DoExchange({}, descriptor); } + /// \brief Set server session option(s) by name/value. Sessions are generally + /// persisted via HTTP cookies. + /// \param[in] options Per-RPC options + /// \param[in] request The server session options to set + ::arrow::Result SetSessionOptions( + const FlightCallOptions& options, const SetSessionOptionsRequest& request); + + /// \brief Get the current server session options. The session is generally + /// accessed via an HTTP cookie. + /// \param[in] options Per-RPC options + /// \param[in] request The (empty) GetSessionOptions request object. + ::arrow::Result GetSessionOptions( + const FlightCallOptions& options, const GetSessionOptionsRequest& request); + + /// \brief Close/invalidate the current server session. The session is generally + /// accessed via an HTTP cookie. + /// \param[in] options Per-RPC options + /// \param[in] request The (empty) CloseSession request object. + ::arrow::Result CloseSession(const FlightCallOptions& options, + const CloseSessionRequest& request); + /// \brief Explicitly shut down and clean up the client. /// /// For backwards compatibility, this will be implicitly called by diff --git a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc index 67c7ee85f59d3..6f3115cc5ab8a 100644 --- a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc +++ b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc @@ -71,6 +71,8 @@ TEST(FlightIntegration, ExpirationTimeRenewFlightEndpoint) { ASSERT_OK(RunScenario("expiration_time:renew_flight_endpoint")); } +TEST(FlightIntegration, SessionOptions) { ASSERT_OK(RunScenario("session_options")); } + TEST(FlightIntegration, PollFlightInfo) { ASSERT_OK(RunScenario("poll_flight_info")); } TEST(FlightIntegration, AppMetadataFlightInfoEndpoint) { diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index 31bffd7704474..9d82c2a67d2d2 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -28,11 +28,13 @@ #include "arrow/array/array_nested.h" #include "arrow/array/array_primitive.h" #include "arrow/array/builder_primitive.h" +#include "arrow/flight/client_cookie_middleware.h" #include "arrow/flight/client_middleware.h" #include "arrow/flight/server_middleware.h" #include "arrow/flight/sql/client.h" #include "arrow/flight/sql/column_metadata.h" #include "arrow/flight/sql/server.h" +#include "arrow/flight/sql/server_session_middleware.h" #include "arrow/flight/sql/types.h" #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" @@ -744,6 +746,155 @@ class ExpirationTimeRenewFlightEndpointScenario : public Scenario { } }; +/// \brief The server used for testing Session Options. +/// +/// SetSessionOptions has a blacklisted option name and string option value, +/// both "lol_invalid", which will result in errors attempting to set either. +class SessionOptionsServer : public sql::FlightSqlServerBase { + static inline const std::string invalid_option_name = "lol_invalid"; + static inline const SessionOptionValue invalid_option_value = "lol_invalid"; + + const std::string session_middleware_key; + // These will never be threaded so using a plain map and no lock + std::map session_store_; + + public: + explicit SessionOptionsServer(std::string session_middleware_key) + : FlightSqlServerBase(), + session_middleware_key(std::move(session_middleware_key)) {} + + arrow::Result SetSessionOptions( + const ServerCallContext& context, + const SetSessionOptionsRequest& request) override { + SetSessionOptionsResult res; + + auto* middleware = static_cast( + context.GetMiddleware(session_middleware_key)); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr session, + middleware->GetSession()); + + for (const auto& [name, value] : request.session_options) { + // Blacklisted value name + if (name == invalid_option_name) { + res.errors.emplace(name, SetSessionOptionsResult::Error{ + SetSessionOptionErrorValue::kInvalidName}); + continue; + } + // Blacklisted option value + if (value == invalid_option_value) { + res.errors.emplace(name, SetSessionOptionsResult::Error{ + SetSessionOptionErrorValue::kInvalidValue}); + continue; + } + if (std::holds_alternative(value)) { + session->EraseSessionOption(name); + continue; + } + session->SetSessionOption(name, value); + } + + return res; + } + + arrow::Result GetSessionOptions( + const ServerCallContext& context, + const GetSessionOptionsRequest& request) override { + auto* middleware = static_cast( + context.GetMiddleware(session_middleware_key)); + if (!middleware->HasSession()) { + return Status::Invalid("No existing session to get options from."); + } + ARROW_ASSIGN_OR_RAISE(std::shared_ptr session, + middleware->GetSession()); + + return GetSessionOptionsResult{session->GetSessionOptions()}; + } + + arrow::Result CloseSession( + const ServerCallContext& context, const CloseSessionRequest& request) override { + // Broken (does not expire cookie) until C++ middleware handling (GH-39791) fixed: + auto* middleware = static_cast( + context.GetMiddleware(session_middleware_key)); + ARROW_RETURN_NOT_OK(middleware->CloseSession()); + return CloseSessionResult{CloseSessionStatus::kClosed}; + } +}; + +/// \brief The Session Options scenario. +/// +/// This tests Session Options functionality as well as ServerSessionMiddleware. +class SessionOptionsScenario : public Scenario { + static inline const std::string server_middleware_key = "sessionmiddleware"; + + Status MakeServer(std::unique_ptr* server, + FlightServerOptions* options) override { + *server = std::make_unique(server_middleware_key); + + auto id_gen_int = std::make_shared(1000); + options->middleware.emplace_back( + server_middleware_key, + sql::MakeServerSessionMiddlewareFactory( + [=]() -> std::string { return std::to_string((*id_gen_int)++); })); + + return Status::OK(); + } + + Status MakeClient(FlightClientOptions* options) override { + options->middleware.emplace_back(GetCookieFactory()); + return Status::OK(); + } + + Status RunClient(std::unique_ptr flight_client) override { + sql::FlightSqlClient client{std::move(flight_client)}; + + // Set + auto req1 = SetSessionOptionsRequest{ + {{"foolong", 123L}, + {"bardouble", 456.0}, + {"lol_invalid", "this won't get set"}, + {"key_with_invalid_value", "lol_invalid"}, + {"big_ol_string_list", std::vector{"a", "b", "sea", "dee", " ", + " ", "geee", "(づ。◕‿‿◕。)づ"}}}}; + ARROW_ASSIGN_OR_RAISE(auto res1, client.SetSessionOptions({}, req1)); + // Some errors + if (res1.errors != + std::map{ + {"lol_invalid", + SetSessionOptionsResult::Error{SetSessionOptionErrorValue::kInvalidName}}, + {"key_with_invalid_value", SetSessionOptionsResult::Error{ + SetSessionOptionErrorValue::kInvalidValue}}}) { + return Status::Invalid("res1 incorrect: " + res1.ToString()); + } + // Some set, some omitted due to above errors + ARROW_ASSIGN_OR_RAISE(auto res2, client.GetSessionOptions({}, {})); + if (res2.session_options != + std::map{ + {"foolong", 123L}, + {"bardouble", 456.0}, + {"big_ol_string_list", + std::vector{"a", "b", "sea", "dee", " ", " ", "geee", + "(づ。◕‿‿◕。)づ"}}}) { + return Status::Invalid("res2 incorrect: " + res2.ToString()); + } + // Update + ARROW_ASSIGN_OR_RAISE( + auto res3, + client.SetSessionOptions( + {}, SetSessionOptionsRequest{ + {{"foolong", std::monostate{}}, + {"big_ol_string_list", "a,b,sea,dee, , ,geee,(づ。◕‿‿◕。)づ"}}})); + ARROW_ASSIGN_OR_RAISE(auto res4, client.GetSessionOptions({}, {})); + if (res4.session_options != + std::map{ + {"bardouble", 456.0}, + {"big_ol_string_list", "a,b,sea,dee, , ,geee,(づ。◕‿‿◕。)づ"}}) { + return Status::Invalid("res4 incorrect: " + res4.ToString()); + } + + return Status::OK(); + } +}; + /// \brief The server used for testing PollFlightInfo(). class PollFlightInfoServer : public FlightServerBase { public: @@ -1952,6 +2103,9 @@ Status GetScenario(const std::string& scenario_name, std::shared_ptr* } else if (scenario_name == "expiration_time:renew_flight_endpoint") { *out = std::make_shared(); return Status::OK(); + } else if (scenario_name == "session_options") { + *out = std::make_shared(); + return Status::OK(); } else if (scenario_name == "poll_flight_info") { *out = std::make_shared(); return Status::OK(); diff --git a/cpp/src/arrow/flight/serialization_internal.cc b/cpp/src/arrow/flight/serialization_internal.cc index 64a40564afd72..238a2028f5d1a 100644 --- a/cpp/src/arrow/flight/serialization_internal.cc +++ b/cpp/src/arrow/flight/serialization_internal.cc @@ -27,6 +27,14 @@ #include "arrow/result.h" #include "arrow/status.h" +// Lambda helper & CTAD +template +struct overloaded : Ts... { + using Ts::operator()...; +}; +template // CTAD will not be needed for >=C++20 +overloaded(Ts...)->overloaded; + namespace arrow { namespace flight { namespace internal { @@ -376,6 +384,157 @@ Status ToPayload(const FlightDescriptor& descr, std::shared_ptr* out) { return Status::OK(); } +// SessionOptionValue + +Status FromProto(const pb::SessionOptionValue& pb_val, SessionOptionValue* val) { + switch (pb_val.option_value_case()) { + case pb::SessionOptionValue::OPTION_VALUE_NOT_SET: + *val = std::monostate{}; + break; + case pb::SessionOptionValue::kStringValue: + *val = pb_val.string_value(); + break; + case pb::SessionOptionValue::kBoolValue: + *val = pb_val.bool_value(); + break; + case pb::SessionOptionValue::kInt64Value: + *val = pb_val.int64_value(); + break; + case pb::SessionOptionValue::kDoubleValue: + *val = pb_val.double_value(); + break; + case pb::SessionOptionValue::kStringListValue: { + std::vector vec; + vec.reserve(pb_val.string_list_value().values_size()); + for (const std::string& s : pb_val.string_list_value().values()) { + vec.push_back(s); + } + (*val).emplace>(std::move(vec)); + break; + } + } + return Status::OK(); +} + +Status ToProto(const SessionOptionValue& val, pb::SessionOptionValue* pb_val) { + std::visit(overloaded{[&](std::monostate v) { pb_val->clear_option_value(); }, + [&](std::string v) { pb_val->set_string_value(v); }, + [&](bool v) { pb_val->set_bool_value(v); }, + [&](int64_t v) { pb_val->set_int64_value(v); }, + [&](double v) { pb_val->set_double_value(v); }, + [&](std::vector v) { + auto* string_list_value = pb_val->mutable_string_list_value(); + for (const std::string& s : v) string_list_value->add_values(s); + }}, + val); + return Status::OK(); +} + +// map + +Status FromProto(const google::protobuf::Map& pb_map, + std::map* map) { + if (pb_map.empty()) { + return Status::OK(); + } + for (const auto& [name, pb_val] : pb_map) { + RETURN_NOT_OK(FromProto(pb_val, &(*map)[name])); + } + return Status::OK(); +} + +Status ToProto(const std::map& map, + google::protobuf::Map* pb_map) { + for (const auto& [name, val] : map) { + RETURN_NOT_OK(ToProto(val, &(*pb_map)[name])); + } + return Status::OK(); +} + +// SetSessionOptionsRequest + +Status FromProto(const pb::SetSessionOptionsRequest& pb_request, + SetSessionOptionsRequest* request) { + RETURN_NOT_OK(FromProto(pb_request.session_options(), &request->session_options)); + return Status::OK(); +} + +Status ToProto(const SetSessionOptionsRequest& request, + pb::SetSessionOptionsRequest* pb_request) { + RETURN_NOT_OK(ToProto(request.session_options, pb_request->mutable_session_options())); + return Status::OK(); +} + +// SetSessionOptionsResult + +Status FromProto(const pb::SetSessionOptionsResult& pb_result, + SetSessionOptionsResult* result) { + for (const auto& [k, pb_v] : pb_result.errors()) { + result->errors.insert({k, {static_cast(pb_v.value())}}); + } + return Status::OK(); +} + +Status ToProto(const SetSessionOptionsResult& result, + pb::SetSessionOptionsResult* pb_result) { + auto* pb_errors = pb_result->mutable_errors(); + for (const auto& [k, v] : result.errors) { + pb::SetSessionOptionsResult::Error e; + e.set_value(static_cast(v.value)); + (*pb_errors)[k] = std::move(e); + } + return Status::OK(); +} + +// GetSessionOptionsRequest + +Status FromProto(const pb::GetSessionOptionsRequest& pb_request, + GetSessionOptionsRequest* request) { + return Status::OK(); +} + +Status ToProto(const GetSessionOptionsRequest& request, + pb::GetSessionOptionsRequest* pb_request) { + return Status::OK(); +} + +// GetSessionOptionsResult + +Status FromProto(const pb::GetSessionOptionsResult& pb_result, + GetSessionOptionsResult* result) { + RETURN_NOT_OK(FromProto(pb_result.session_options(), &result->session_options)); + return Status::OK(); +} + +Status ToProto(const GetSessionOptionsResult& result, + pb::GetSessionOptionsResult* pb_result) { + RETURN_NOT_OK(ToProto(result.session_options, pb_result->mutable_session_options())); + return Status::OK(); +} + +// CloseSessionRequest + +Status FromProto(const pb::CloseSessionRequest& pb_request, + CloseSessionRequest* request) { + return Status::OK(); +} + +Status ToProto(const CloseSessionRequest& request, pb::CloseSessionRequest* pb_request) { + return Status::OK(); +} + +// CloseSessionResult + +Status FromProto(const pb::CloseSessionResult& pb_result, CloseSessionResult* result) { + result->status = static_cast(pb_result.status()); + return Status::OK(); +} + +Status ToProto(const CloseSessionResult& result, pb::CloseSessionResult* pb_result) { + pb_result->set_status(static_cast(result.status)); + return Status::OK(); +} + } // namespace internal } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/serialization_internal.h b/cpp/src/arrow/flight/serialization_internal.h index 1ac7de83d1308..90dde87d3a5eb 100644 --- a/cpp/src/arrow/flight/serialization_internal.h +++ b/cpp/src/arrow/flight/serialization_internal.h @@ -66,6 +66,16 @@ Status FromProto(const pb::CancelFlightInfoRequest& pb_request, CancelFlightInfoRequest* request); Status FromProto(const pb::SchemaResult& pb_result, std::string* result); Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* info); +Status FromProto(const pb::SetSessionOptionsRequest& pb_request, + SetSessionOptionsRequest* request); +Status FromProto(const pb::SetSessionOptionsResult& pb_result, + SetSessionOptionsResult* result); +Status FromProto(const pb::GetSessionOptionsRequest& pb_request, + GetSessionOptionsRequest* request); +Status FromProto(const pb::GetSessionOptionsResult& pb_result, + GetSessionOptionsResult* result); +Status FromProto(const pb::CloseSessionRequest& pb_request, CloseSessionRequest* request); +Status FromProto(const pb::CloseSessionResult& pb_result, CloseSessionResult* result); Status ToProto(const Timestamp& timestamp, google::protobuf::Timestamp* pb_timestamp); Status ToProto(const FlightDescriptor& descr, pb::FlightDescriptor* pb_descr); @@ -85,6 +95,16 @@ Status ToProto(const Criteria& criteria, pb::Criteria* pb_criteria); Status ToProto(const SchemaResult& result, pb::SchemaResult* pb_result); Status ToProto(const Ticket& ticket, pb::Ticket* pb_ticket); Status ToProto(const BasicAuth& basic_auth, pb::BasicAuth* pb_basic_auth); +Status ToProto(const SetSessionOptionsRequest& request, + pb::SetSessionOptionsRequest* pb_request); +Status ToProto(const SetSessionOptionsResult& result, + pb::SetSessionOptionsResult* pb_result); +Status ToProto(const GetSessionOptionsRequest& request, + pb::GetSessionOptionsRequest* pb_request); +Status ToProto(const GetSessionOptionsResult& result, + pb::GetSessionOptionsResult* pb_result); +Status ToProto(const CloseSessionRequest& request, pb::CloseSessionRequest* pb_request); +Status ToProto(const CloseSessionResult& result, pb::CloseSessionResult* pb_result); Status ToPayload(const FlightDescriptor& descr, std::shared_ptr* out); diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt index b0a551a2bca77..b32f731496749 100644 --- a/cpp/src/arrow/flight/sql/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/CMakeLists.txt @@ -47,7 +47,8 @@ set(ARROW_FLIGHT_SQL_SRCS sql_info_internal.cc column_metadata.cc client.cc - protocol_internal.cc) + protocol_internal.cc + server_session_middleware.cc) add_arrow_lib(arrow_flight_sql CMAKE_PACKAGE_NAME @@ -104,7 +105,9 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) example/sqlite_server.cc example/sqlite_tables_schema_batch_reader.cc) - set(ARROW_FLIGHT_SQL_TEST_SRCS server_test.cc) + set(ARROW_FLIGHT_SQL_TEST_SRCS server_test.cc + server_session_middleware_internals_test.cc) + set(ARROW_FLIGHT_SQL_TEST_LIBS ${SQLite3_LIBRARIES}) set(ARROW_FLIGHT_SQL_ACERO_SRCS example/acero_server.cc) diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h index 5f3fc7d8574a9..9782611dbadcd 100644 --- a/cpp/src/arrow/flight/sql/client.h +++ b/cpp/src/arrow/flight/sql/client.h @@ -350,6 +350,33 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { ::arrow::Result CancelQuery(const FlightCallOptions& options, const FlightInfo& info); + /// \brief Sets session options. + /// + /// \param[in] options RPC-layer hints for this call. + /// \param[in] request The session options to set. + ::arrow::Result SetSessionOptions( + const FlightCallOptions& options, const SetSessionOptionsRequest& request) { + return impl_->SetSessionOptions(options, request); + } + + /// \brief Gets current session options. + /// + /// \param[in] options RPC-layer hints for this call. + /// \param[in] request The (empty) GetSessionOptions request object. + ::arrow::Result GetSessionOptions( + const FlightCallOptions& options, const GetSessionOptionsRequest& request) { + return impl_->GetSessionOptions(options, request); + } + + /// \brief Explicitly closes the session if applicable. + /// + /// \param[in] options RPC-layer hints for this call. + /// \param[in] request The (empty) CloseSession request object. + ::arrow::Result CloseSession(const FlightCallOptions& options, + const CloseSessionRequest& request) { + return impl_->CloseSession(options, request); + } + /// \brief Extends the expiration of a FlightEndpoint. /// /// \param[in] options RPC-layer hints for this call. diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc index a6d197d15b2c0..a5cb842de8f49 100644 --- a/cpp/src/arrow/flight/sql/server.cc +++ b/cpp/src/arrow/flight/sql/server.cc @@ -442,6 +442,21 @@ arrow::Result PackActionResult(ActionCreatePreparedStatementResult resul return PackActionResult(pb_result); } +arrow::Result PackActionResult(SetSessionOptionsResult result) { + ARROW_ASSIGN_OR_RAISE(auto serialized, result.SerializeToString()); + return Result{Buffer::FromString(std::move(serialized))}; +} + +arrow::Result PackActionResult(GetSessionOptionsResult result) { + ARROW_ASSIGN_OR_RAISE(auto serialized, result.SerializeToString()); + return Result{Buffer::FromString(std::move(serialized))}; +} + +arrow::Result PackActionResult(CloseSessionResult result) { + ARROW_ASSIGN_OR_RAISE(auto serialized, result.SerializeToString()); + return Result{Buffer::FromString(std::move(serialized))}; +} + } // namespace arrow::Result StatementQueryTicket::Deserialize( @@ -759,18 +774,19 @@ Status FlightSqlServerBase::DoPut(const ServerCallContext& context, Status FlightSqlServerBase::ListActions(const ServerCallContext& context, std::vector* actions) { - *actions = { - ActionType::kCancelFlightInfo, - ActionType::kRenewFlightEndpoint, - FlightSqlServerBase::kBeginSavepointActionType, - FlightSqlServerBase::kBeginTransactionActionType, - FlightSqlServerBase::kCancelQueryActionType, - FlightSqlServerBase::kCreatePreparedStatementActionType, - FlightSqlServerBase::kCreatePreparedSubstraitPlanActionType, - FlightSqlServerBase::kClosePreparedStatementActionType, - FlightSqlServerBase::kEndSavepointActionType, - FlightSqlServerBase::kEndTransactionActionType, - }; + *actions = {ActionType::kCancelFlightInfo, + ActionType::kRenewFlightEndpoint, + FlightSqlServerBase::kBeginSavepointActionType, + FlightSqlServerBase::kBeginTransactionActionType, + FlightSqlServerBase::kCancelQueryActionType, + FlightSqlServerBase::kCreatePreparedStatementActionType, + FlightSqlServerBase::kCreatePreparedSubstraitPlanActionType, + FlightSqlServerBase::kClosePreparedStatementActionType, + FlightSqlServerBase::kEndSavepointActionType, + FlightSqlServerBase::kEndTransactionActionType, + ActionType::kSetSessionOptions, + ActionType::kGetSessionOptions, + ActionType::kCloseSession}; return Status::OK(); } @@ -791,6 +807,27 @@ Status FlightSqlServerBase::DoAction(const ServerCallContext& context, ARROW_ASSIGN_OR_RAISE(auto renewed_endpoint, RenewFlightEndpoint(context, request)); ARROW_ASSIGN_OR_RAISE(auto packed_result, PackActionResult(renewed_endpoint)); + results.push_back(std::move(packed_result)); + } else if (action.type == ActionType::kSetSessionOptions.type) { + std::string_view body(*action.body); + ARROW_ASSIGN_OR_RAISE(auto request, SetSessionOptionsRequest::Deserialize(body)); + ARROW_ASSIGN_OR_RAISE(auto result, SetSessionOptions(context, request)); + ARROW_ASSIGN_OR_RAISE(auto packed_result, PackActionResult(std::move(result))); + + results.push_back(std::move(packed_result)); + } else if (action.type == ActionType::kGetSessionOptions.type) { + std::string_view body(*action.body); + ARROW_ASSIGN_OR_RAISE(auto request, GetSessionOptionsRequest::Deserialize(body)); + ARROW_ASSIGN_OR_RAISE(auto result, GetSessionOptions(context, request)); + ARROW_ASSIGN_OR_RAISE(auto packed_result, PackActionResult(std::move(result))); + + results.push_back(std::move(packed_result)); + } else if (action.type == ActionType::kCloseSession.type) { + std::string_view body(*action.body); + ARROW_ASSIGN_OR_RAISE(auto request, CloseSessionRequest::Deserialize(body)); + ARROW_ASSIGN_OR_RAISE(auto result, CloseSession(context, request)); + ARROW_ASSIGN_OR_RAISE(auto packed_result, PackActionResult(std::move(result))); + results.push_back(std::move(packed_result)); } else { google::protobuf::Any any; @@ -1098,6 +1135,11 @@ arrow::Result FlightSqlServerBase::RenewFlightEndpoint( return Status::NotImplemented("RenewFlightEndpoint not implemented"); } +arrow::Result FlightSqlServerBase::CloseSession( + const ServerCallContext& context, const CloseSessionRequest& request) { + return Status::NotImplemented("CloseSession not implemented"); +} + arrow::Result FlightSqlServerBase::CreatePreparedStatement( const ServerCallContext& context, @@ -1128,6 +1170,16 @@ Status FlightSqlServerBase::EndTransaction(const ServerCallContext& context, return Status::NotImplemented("EndTransaction not implemented"); } +arrow::Result FlightSqlServerBase::SetSessionOptions( + const ServerCallContext& context, const SetSessionOptionsRequest& request) { + return Status::NotImplemented("SetSessionOptions not implemented"); +} + +arrow::Result FlightSqlServerBase::GetSessionOptions( + const ServerCallContext& context, const GetSessionOptionsRequest& request) { + return Status::NotImplemented("GetSessionOptions not implemented"); +} + Status FlightSqlServerBase::DoPutPreparedStatementQuery( const ServerCallContext& context, const PreparedStatementQuery& command, FlightMessageReader* reader, FlightMetadataWriter* writer) { diff --git a/cpp/src/arrow/flight/sql/server.h b/cpp/src/arrow/flight/sql/server.h index 360677c078c81..6c44f46508c7f 100644 --- a/cpp/src/arrow/flight/sql/server.h +++ b/cpp/src/arrow/flight/sql/server.h @@ -20,6 +20,7 @@ #pragma once +#include #include #include #include @@ -601,6 +602,24 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { virtual arrow::Result CancelFlightInfo( const ServerCallContext& context, const CancelFlightInfoRequest& request); + /// \brief Set server session option(s). + /// \param[in] context The call context. + /// \param[in] request The session options to set. + virtual arrow::Result SetSessionOptions( + const ServerCallContext& context, const SetSessionOptionsRequest& request); + + /// \brief Get server session option(s). + /// \param[in] context The call context. + /// \param[in] request Request object. + virtual arrow::Result GetSessionOptions( + const ServerCallContext& context, const GetSessionOptionsRequest& request); + + /// \brief Close/invalidate the session. + /// \param[in] context The call context. + /// \param[in] request Request object. + virtual arrow::Result CloseSession( + const ServerCallContext& context, const CloseSessionRequest& request); + /// \brief Attempt to explicitly cancel a query. /// /// \param[in] context The call context. diff --git a/cpp/src/arrow/flight/sql/server_session_middleware.cc b/cpp/src/arrow/flight/sql/server_session_middleware.cc new file mode 100644 index 0000000000000..f3e02de232444 --- /dev/null +++ b/cpp/src/arrow/flight/sql/server_session_middleware.cc @@ -0,0 +1,235 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/flight/sql/server_session_middleware.h" +#include "arrow/flight/sql/server_session_middleware_factory.h" + +namespace arrow { +namespace flight { +namespace sql { + +class ServerSessionMiddlewareImpl : public ServerSessionMiddleware { + protected: + std::shared_mutex mutex_; + ServerSessionMiddlewareFactory* factory_; + const CallHeaders& headers_; + std::shared_ptr session_; + std::string session_id_; + std::string closed_session_id_; + bool existing_session_; + + public: + ServerSessionMiddlewareImpl(ServerSessionMiddlewareFactory* factory, + const CallHeaders& headers) + : factory_(factory), headers_(headers), existing_session_(false) {} + + ServerSessionMiddlewareImpl(ServerSessionMiddlewareFactory* factory, + const CallHeaders& headers, + std::shared_ptr session, + std::string session_id, bool existing_session = true) + : factory_(factory), + headers_(headers), + session_(std::move(session)), + session_id_(std::move(session_id)), + existing_session_(existing_session) {} + + void SendingHeaders(AddCallHeaders* add_call_headers) override { + if (!existing_session_ && session_) { + add_call_headers->AddHeader( + "set-cookie", static_cast(kSessionCookieName) + "=" + session_id_); + } + if (!closed_session_id_.empty()) { + add_call_headers->AddHeader( + "set-cookie", static_cast(kSessionCookieName) + "=" + session_id_ + + "; Max-Age=0"); + } + } + + void CallCompleted(const Status&) override {} + + bool HasSession() const override { return static_cast(session_); } + + arrow::Result> GetSession() override { + const std::lock_guard l(mutex_); + if (!session_) { + auto [id, s] = factory_->CreateNewSession(); + session_ = std::move(s); + session_id_ = std::move(id); + } + if (!static_cast(session_)) { + return Status::UnknownError("Error creating session."); + } + return session_; + } + + Status CloseSession() override { + const std::lock_guard l(mutex_); + if (static_cast(session_)) { + return Status::Invalid("Nonexistent session cannot be closed."); + } + ARROW_RETURN_NOT_OK(factory_->CloseSession(session_id_)); + closed_session_id_ = std::move(session_id_); + session_id_.clear(); + session_.reset(); + existing_session_ = false; + + return Status::OK(); + } + + const CallHeaders& GetCallHeaders() const override { return headers_; } +}; + +std::vector> +ServerSessionMiddlewareFactory::ParseCookieString(const std::string_view& s) { + const std::string list_sep = "; "; + const std::string pair_sep = "="; + + std::vector> result; + + size_t cur = 0; + while (cur < s.length()) { + const size_t end = s.find(list_sep, cur); + const bool further_pairs = end != std::string::npos; + const size_t len = further_pairs ? end - cur : std::string::npos; + const std::string_view tok = s.substr(cur, len); + cur = further_pairs ? end + list_sep.length() : s.length(); + + const size_t val_pos = tok.find(pair_sep); + if (val_pos == std::string::npos) { + // The cookie header is somewhat malformed; ignore the key and continue parsing + continue; + } + const std::string_view cookie_name = tok.substr(0, val_pos); + std::string_view cookie_value = + tok.substr(val_pos + pair_sep.length(), std::string::npos); + if (cookie_name.empty()) { + continue; + } + // Strip doublequotes + if (cookie_value.length() >= 2 && cookie_value.front() == '"' && + cookie_value.back() == '"') { + cookie_value.remove_prefix(1); + cookie_value.remove_suffix(1); + } + result.emplace_back(cookie_name, cookie_value); + } + + return result; +} + +Status ServerSessionMiddlewareFactory::StartCall( + const CallInfo&, const CallHeaders& incoming_headers, + std::shared_ptr* middleware) { + std::string session_id; + + const std::pair& + headers_it_pr = incoming_headers.equal_range("cookie"); + for (auto itr = headers_it_pr.first; itr != headers_it_pr.second; ++itr) { + const std::string_view& cookie_header = itr->second; + const std::vector> cookies = + ParseCookieString(cookie_header); + for (const std::pair& cookie : cookies) { + if (cookie.first == kSessionCookieName) { + if (cookie.second.empty()) + return Status::Invalid("Empty ", kSessionCookieName, " cookie value."); + session_id = std::move(cookie.second); + } + } + if (!session_id.empty()) break; + } + + if (session_id.empty()) { + // No cookie was found + // Temporary workaround until middleware handling fixed + auto [id, s] = CreateNewSession(); + *middleware = std::make_shared(this, incoming_headers, + std::move(s), id, false); + } else { + const std::shared_lock l(session_store_lock_); + if (auto it = session_store_.find(session_id); it == session_store_.end()) { + return Status::Invalid("Invalid or expired ", kSessionCookieName, " cookie."); + } else { + auto session = it->second; + *middleware = std::make_shared( + this, incoming_headers, std::move(session), session_id); + } + } + + return Status::OK(); +} + +/// \brief Get a new, empty session option map & its id key; {"",NULLPTR} on collision. +std::pair> +ServerSessionMiddlewareFactory::CreateNewSession() { + auto new_id = id_generator_(); + auto session = std::make_shared(); + + const std::lock_guard l(session_store_lock_); + if (session_store_.count(new_id)) { + // Collision + return {"", NULLPTR}; + } + session_store_[new_id] = session; + + return {new_id, session}; +} + +Status ServerSessionMiddlewareFactory::CloseSession(std::string id) { + const std::lock_guard l(session_store_lock_); + if (!session_store_.erase(id)) { + return Status::KeyError("Invalid or nonexistent session cannot be closed."); + } + return Status::OK(); +} + +std::shared_ptr MakeServerSessionMiddlewareFactory( + std::function id_gen) { + return std::make_shared(std::move(id_gen)); +} + +std::optional FlightSession::GetSessionOption( + const std::string& name) { + const std::shared_lock l(map_lock_); + auto it = map_.find(name); + if (it != map_.end()) { + return it->second; + } else { + return std::nullopt; + } +} + +std::map FlightSession::GetSessionOptions() { + const std::shared_lock l(map_lock_); + return map_; +} + +void FlightSession::SetSessionOption(const std::string& name, + const SessionOptionValue value) { + const std::lock_guard l(map_lock_); + map_[name] = std::move(value); +} + +void FlightSession::EraseSessionOption(const std::string& name) { + const std::lock_guard l(map_lock_); + map_.erase(name); +} + +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/server_session_middleware.h b/cpp/src/arrow/flight/sql/server_session_middleware.h new file mode 100644 index 0000000000000..021793de3de32 --- /dev/null +++ b/cpp/src/arrow/flight/sql/server_session_middleware.h @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Middleware for handling Flight SQL Sessions including session cookie handling. +// Currently experimental. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/sql/types.h" +#include "arrow/status.h" + +namespace arrow { +namespace flight { +namespace sql { + +static constexpr char const kSessionCookieName[] = "arrow_flight_session_id"; + +class ARROW_FLIGHT_SQL_EXPORT FlightSession { + protected: + std::map map_; + std::shared_mutex map_lock_; + + public: + /// \brief Get session option by name + std::optional GetSessionOption(const std::string& name); + /// \brief Get a copy of the session options map. + /// + /// The returned options map may be modified by further calls to this FlightSession + std::map GetSessionOptions(); + /// \brief Set session option by name to given value + void SetSessionOption(const std::string& name, const SessionOptionValue value); + /// \brief Idempotently remove name from this session + void EraseSessionOption(const std::string& name); +}; + +/// \brief A middleware to handle session option persistence and related cookie headers. +/// +/// WARNING that client cookie invalidation does not currently work due to a gRPC +/// transport bug. +class ARROW_FLIGHT_SQL_EXPORT ServerSessionMiddleware : public ServerMiddleware { + public: + static constexpr char const kMiddlewareName[] = + "arrow::flight::sql::ServerSessionMiddleware"; + + std::string name() const override { return kMiddlewareName; } + + /// \brief Is there an existing session (either existing or new) + virtual bool HasSession() const = 0; + /// \brief Get existing or new call-associated session + /// + /// May return NULLPTR if there is an id generation collision. + virtual arrow::Result> GetSession() = 0; + /// Close the current session. + /// + /// This is presently unsupported in C++ until middleware handling can be fixed. + virtual Status CloseSession() = 0; + /// \brief Get request headers, in lieu of a provided or created session. + virtual const CallHeaders& GetCallHeaders() const = 0; +}; + +/// \brief Returns a ServerMiddlewareFactory that handles session option storage. +/// \param[in] id_gen A thread-safe, collision-free generator for session id strings. +ARROW_FLIGHT_SQL_EXPORT std::shared_ptr +MakeServerSessionMiddlewareFactory(std::function id_gen); + +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/server_session_middleware_factory.h b/cpp/src/arrow/flight/sql/server_session_middleware_factory.h new file mode 100644 index 0000000000000..2613c572eefc2 --- /dev/null +++ b/cpp/src/arrow/flight/sql/server_session_middleware_factory.h @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// ServerSessionMiddlewareFactory, factored into a separate header for testability + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace arrow { +namespace flight { +namespace sql { + +/// \brief A factory for ServerSessionMiddleware, itself storing session data. +class ARROW_FLIGHT_SQL_EXPORT ServerSessionMiddlewareFactory + : public ServerMiddlewareFactory { + protected: + std::map> session_store_; + std::shared_mutex session_store_lock_; + std::function id_generator_; + + static std::vector> ParseCookieString( + const std::string_view& s); + + public: + explicit ServerSessionMiddlewareFactory(std::function id_gen) + : id_generator_(id_gen) {} + Status StartCall(const CallInfo&, const CallHeaders& incoming_headers, + std::shared_ptr* middleware) override; + + /// \brief Get a new, empty session option map and its id key. + std::pair> CreateNewSession(); + /// \brief Close the session identified by 'id'. + /// \param id The string id of the session to close. + Status CloseSession(std::string id); +}; + +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/server_session_middleware_internals_test.cc b/cpp/src/arrow/flight/sql/server_session_middleware_internals_test.cc new file mode 100644 index 0000000000000..74e4d7845c699 --- /dev/null +++ b/cpp/src/arrow/flight/sql/server_session_middleware_internals_test.cc @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// ---------------------------------------------------------------------- +// ServerSessionMiddleware{,Factory} tests not involing a client/server instance + +#include + +#include + +namespace arrow { +namespace flight { +namespace sql { + +class ServerSessionMiddlewareFactoryPrivate : public ServerSessionMiddlewareFactory { + public: + using ServerSessionMiddlewareFactory::ParseCookieString; +}; + +TEST(ServerSessionMiddleware, ParseCookieString) { + std::vector> r1 = + ServerSessionMiddlewareFactoryPrivate::ParseCookieString( + "k1=v1; k2=\"v2\"; kempty=; k3=v3"); + std::vector> e1 = { + {"k1", "v1"}, {"k2", "v2"}, {"kempty", ""}, {"k3", "v3"}}; + ASSERT_EQ(e1, r1); +} + +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc index bbd01155fe4a4..a9780b5eeb77e 100644 --- a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc +++ b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc @@ -290,7 +290,8 @@ class GrpcServiceHandler final : public FlightService::Service { // Authenticate the client (if applicable) and construct the call context ::grpc::Status CheckAuth(const FlightMethod& method, ServerContext* context, - GrpcServerCallContext& flight_context) { + GrpcServerCallContext& flight_context, + bool skip_headers = false) { if (!auth_handler_) { const auto auth_context = context->auth_context(); if (auth_context && auth_context->IsPeerAuthenticated()) { @@ -320,11 +321,11 @@ class GrpcServiceHandler final : public FlightService::Service { // Authenticate the client (if applicable) and construct the call context ::grpc::Status MakeCallContext(const FlightMethod& method, ServerContext* context, - GrpcServerCallContext& flight_context) { + GrpcServerCallContext& flight_context, + bool skip_headers = false) { // Run server middleware const CallInfo info{method}; - GrpcAddServerHeaders outgoing_headers(context); for (const auto& factory : middleware_) { std::shared_ptr instance; Status result = factory.second->StartCall(info, flight_context, &instance); @@ -336,13 +337,25 @@ class GrpcServiceHandler final : public FlightService::Service { if (instance != nullptr) { flight_context.middleware_.push_back(instance); flight_context.middleware_map_.insert({factory.first, instance}); - instance->SendingHeaders(&outgoing_headers); } } + // TODO factor this out after fixing all streaming and non-streaming handlers + if (!skip_headers) { + addMiddlewareHeaders(context, flight_context); + } + return ::grpc::Status::OK; } + void addMiddlewareHeaders(ServerContext* context, + GrpcServerCallContext& flight_context) { + GrpcAddServerHeaders outgoing_headers(context); + for (const std::shared_ptr& instance : flight_context.middleware_) { + instance->SendingHeaders(&outgoing_headers); + } + } + ::grpc::Status Handshake( ServerContext* context, ::grpc::ServerReaderWriter* stream) { @@ -399,8 +412,9 @@ class GrpcServiceHandler final : public FlightService::Service { SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &descr)); std::unique_ptr info; - SERVICE_RETURN_NOT_OK(flight_context, - impl_->base()->GetFlightInfo(flight_context, descr, &info)); + auto res = impl_->base()->GetFlightInfo(flight_context, descr, &info); + addMiddlewareHeaders(context, flight_context); + SERVICE_RETURN_NOT_OK(flight_context, res); if (!info) { // Treat null listing as no flights available diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 9da83fa8a11f2..fe673f9c3544d 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -17,9 +17,11 @@ #include "arrow/flight/types.h" +#include #include #include #include +#include #include #include "arrow/buffer.h" @@ -468,6 +470,352 @@ arrow::Result CancelFlightInfoRequest::Deserialize( return out; } +static const char* const SetSessionOptionStatusNames[] = {"Unspecified", "InvalidName", + "InvalidValue", "Error"}; +static const char* const CloseSessionStatusNames[] = {"Unspecified", "Closed", "Closing", + "NotClosable"}; + +// Helpers for stringifying maps containing various types +std::string ToString(const SetSessionOptionErrorValue& error_value) { + return SetSessionOptionStatusNames[static_cast(error_value)]; +} + +std::ostream& operator<<(std::ostream& os, + const SetSessionOptionErrorValue& error_value) { + os << ToString(error_value); + return os; +} + +std::string ToString(const CloseSessionStatus& status) { + return CloseSessionStatusNames[static_cast(status)]; +} + +std::ostream& operator<<(std::ostream& os, const CloseSessionStatus& status) { + os << ToString(status); + return os; +} + +std::ostream& operator<<(std::ostream& os, std::vector values) { + os << '['; + std::string sep = ""; + for (const auto& v : values) { + os << sep << std::quoted(v); + sep = ", "; + } + os << ']'; + + return os; +} + +std::ostream& operator<<(std::ostream& os, const SessionOptionValue& v) { + if (std::holds_alternative(v)) { + os << ""; + } else { + std::visit( + [&](const auto& x) { + if constexpr (std::is_convertible_v, + std::string_view>) { + os << std::quoted(x); + } else { + os << x; + } + }, + v); + } + return os; +} + +std::ostream& operator<<(std::ostream& os, const SetSessionOptionsResult::Error& e) { + os << '{' << e.value << '}'; + return os; +} + +template +std::ostream& operator<<(std::ostream& os, std::map m) { + os << '{'; + std::string sep = ""; + if constexpr (std::is_convertible_v) { + // std::string, char*, std::string_view + for (const auto& [k, v] : m) { + os << sep << '[' << k << "]: " << std::quoted(v) << '"'; + sep = ", "; + } + } else { + for (const auto& [k, v] : m) { + os << sep << '[' << k << "]: " << v; + sep = ", "; + } + } + os << '}'; + + return os; +} + +namespace { +static bool CompareSessionOptionMaps(const std::map& a, + const std::map& b) { + if (a.size() != b.size()) { + return false; + } + for (const auto& [k, v] : a) { + if (const auto it = b.find(k); it == b.end()) { + return false; + } else { + const auto& b_v = it->second; + if (v.index() != b_v.index()) { + return false; + } + if (v != b_v) { + return false; + } + } + } + return true; +} +} // namespace + +// SetSessionOptionsRequest + +std::string SetSessionOptionsRequest::ToString() const { + std::stringstream ss; + + ss << " SetSessionOptionsRequest::Deserialize( + std::string_view serialized) { + // TODO these & SerializeToString should all be factored out to a superclass + pb::SetSessionOptionsRequest pb_request; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid( + "Serialized SetSessionOptionsRequest size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_request.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid SetSessionOptionsRequest"); + } + SetSessionOptionsRequest out; + RETURN_NOT_OK(internal::FromProto(pb_request, &out)); + return out; +} + +// SetSessionOptionsResult + +std::string SetSessionOptionsResult::ToString() const { + std::stringstream ss; + + ss << " SetSessionOptionsResult::Deserialize( + std::string_view serialized) { + pb::SetSessionOptionsResult pb_result; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid( + "Serialized SetSessionOptionsResult size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_result.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid SetSessionOptionsResult"); + } + SetSessionOptionsResult out; + RETURN_NOT_OK(internal::FromProto(pb_result, &out)); + return out; +} + +// GetSessionOptionsRequest + +std::string GetSessionOptionsRequest::ToString() const { + return ""; +} + +bool GetSessionOptionsRequest::Equals(const GetSessionOptionsRequest& other) const { + return true; +} + +arrow::Result GetSessionOptionsRequest::SerializeToString() const { + pb::GetSessionOptionsRequest pb_request; + RETURN_NOT_OK(internal::ToProto(*this, &pb_request)); + + std::string out; + if (!pb_request.SerializeToString(&out)) { + return Status::IOError("Serialized GetSessionOptionsRequest exceeded 2GiB limit"); + } + return out; +} + +arrow::Result GetSessionOptionsRequest::Deserialize( + std::string_view serialized) { + pb::GetSessionOptionsRequest pb_request; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid( + "Serialized GetSessionOptionsRequest size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_request.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid GetSessionOptionsRequest"); + } + GetSessionOptionsRequest out; + RETURN_NOT_OK(internal::FromProto(pb_request, &out)); + return out; +} + +// GetSessionOptionsResult + +std::string GetSessionOptionsResult::ToString() const { + std::stringstream ss; + + ss << " GetSessionOptionsResult::Deserialize( + std::string_view serialized) { + pb::GetSessionOptionsResult pb_result; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid( + "Serialized GetSessionOptionsResult size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_result.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid GetSessionOptionsResult"); + } + GetSessionOptionsResult out; + RETURN_NOT_OK(internal::FromProto(pb_result, &out)); + return out; +} + +// CloseSessionRequest + +std::string CloseSessionRequest::ToString() const { return ""; } + +bool CloseSessionRequest::Equals(const CloseSessionRequest& other) const { return true; } + +arrow::Result CloseSessionRequest::SerializeToString() const { + pb::CloseSessionRequest pb_request; + RETURN_NOT_OK(internal::ToProto(*this, &pb_request)); + + std::string out; + if (!pb_request.SerializeToString(&out)) { + return Status::IOError("Serialized CloseSessionRequest exceeded 2GiB limit"); + } + return out; +} + +arrow::Result CloseSessionRequest::Deserialize( + std::string_view serialized) { + pb::CloseSessionRequest pb_request; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Serialized CloseSessionRequest size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_request.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid CloseSessionRequest"); + } + CloseSessionRequest out; + RETURN_NOT_OK(internal::FromProto(pb_request, &out)); + return out; +} + +// CloseSessionResult + +std::string CloseSessionResult::ToString() const { + std::stringstream ss; + + ss << " CloseSessionResult::Deserialize( + std::string_view serialized) { + pb::CloseSessionResult pb_result; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Serialized CloseSessionResult size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_result.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid CloseSessionResult"); + } + CloseSessionResult out; + RETURN_NOT_OK(internal::FromProto(pb_result, &out)); + return out; +} + Location::Location() { uri_ = std::make_shared(); } arrow::Result Location::Parse(const std::string& uri_string) { @@ -643,6 +991,21 @@ const ActionType ActionType::kRenewFlightEndpoint = "Extend expiration time of the given FlightEndpoint.\n" "Request Message: RenewFlightEndpointRequest\n" "Response Message: Renewed FlightEndpoint"}; +const ActionType ActionType::kSetSessionOptions = + ActionType{"SetSessionOptions", + "Set client session options by name/value pairs.\n" + "Request Message: SetSessionOptionsRequest\n" + "Response Message: SetSessionOptionsResult"}; +const ActionType ActionType::kGetSessionOptions = + ActionType{"GetSessionOptions", + "Get current client session options\n" + "Request Message: GetSessionOptionsRequest\n" + "Response Message: GetSessionOptionsResult"}; +const ActionType ActionType::kCloseSession = + ActionType{"CloseSession", + "Explicitly close/invalidate the cookie-specified client session.\n" + "Request Message: CloseSessionRequest\n" + "Response Message: CloseSessionResult"}; bool ActionType::Equals(const ActionType& other) const { return type == other.type && description == other.description; diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 40a0787d14a7a..0528c5deee3b5 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include "arrow/flight/type_fwd.h" @@ -184,6 +185,9 @@ struct ARROW_FLIGHT_EXPORT ActionType { static const ActionType kCancelFlightInfo; static const ActionType kRenewFlightEndpoint; + static const ActionType kSetSessionOptions; + static const ActionType kGetSessionOptions; + static const ActionType kCloseSession; }; /// \brief Opaque selection criteria for ListFlights RPC @@ -750,6 +754,199 @@ struct ARROW_FLIGHT_EXPORT CancelFlightInfoRequest { static arrow::Result Deserialize(std::string_view serialized); }; +/// \brief Variant supporting all possible value types for {Set,Get}SessionOptions +/// +/// By convention, an attempt to set a valueless (std::monostate) SessionOptionValue +/// should attempt to unset or clear the named option value on the server. +using SessionOptionValue = std::variant>; + +/// \brief The result of setting a session option. +enum class SetSessionOptionErrorValue : int8_t { + /// \brief The status of setting the option is unknown. + /// + /// Servers should avoid using this value (send a NOT_FOUND error if the requested + /// session is not known). Clients can retry the request. + kUnspecified, + /// \brief The given session option name is invalid. + kInvalidName, + /// \brief The session option value or type is invalid. + kInvalidValue, + /// \brief The session option cannot be set. + kError +}; +std::string ToString(const SetSessionOptionErrorValue& error_value); +std::ostream& operator<<(std::ostream& os, const SetSessionOptionErrorValue& error_value); + +/// \brief The result of closing a session. +enum class CloseSessionStatus : int8_t { + // \brief The session close status is unknown. + // + // Servers should avoid using this value (send a NOT_FOUND error if the requested + // session is not known). Clients can retry the request. + kUnspecified, + // \brief The session close request is complete. + // + // Subsequent requests with the same session produce a NOT_FOUND error. + kClosed, + // \brief The session close request is in progress. + // + // The client may retry the request. + kClosing, + // \brief The session is not closeable. + // + // The client should not retry the request. + kNotClosable +}; +std::string ToString(const CloseSessionStatus& status); +std::ostream& operator<<(std::ostream& os, const CloseSessionStatus& status); + +/// \brief A request to set a set of session options by name/value. +struct ARROW_FLIGHT_EXPORT SetSessionOptionsRequest { + std::map session_options; + + std::string ToString() const; + bool Equals(const SetSessionOptionsRequest& other) const; + + friend bool operator==(const SetSessionOptionsRequest& left, + const SetSessionOptionsRequest& right) { + return left.Equals(right); + } + friend bool operator!=(const SetSessionOptionsRequest& left, + const SetSessionOptionsRequest& right) { + return !(left == right); + } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(std::string_view serialized); +}; + +/// \brief The result(s) of setting session option(s). +struct ARROW_FLIGHT_EXPORT SetSessionOptionsResult { + struct Error { + SetSessionOptionErrorValue value; + + bool Equals(const Error& other) const { return value == other.value; } + friend bool operator==(const Error& left, const Error& right) { + return left.Equals(right); + } + friend bool operator!=(const Error& left, const Error& right) { + return !(left == right); + } + }; + + std::map errors; + + std::string ToString() const; + bool Equals(const SetSessionOptionsResult& other) const; + + friend bool operator==(const SetSessionOptionsResult& left, + const SetSessionOptionsResult& right) { + return left.Equals(right); + } + friend bool operator!=(const SetSessionOptionsResult& left, + const SetSessionOptionsResult& right) { + return !(left == right); + } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(std::string_view serialized); +}; + +/// \brief A request to get current session options. +struct ARROW_FLIGHT_EXPORT GetSessionOptionsRequest { + std::string ToString() const; + bool Equals(const GetSessionOptionsRequest& other) const; + + friend bool operator==(const GetSessionOptionsRequest& left, + const GetSessionOptionsRequest& right) { + return left.Equals(right); + } + friend bool operator!=(const GetSessionOptionsRequest& left, + const GetSessionOptionsRequest& right) { + return !(left == right); + } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(std::string_view serialized); +}; + +/// \brief The current session options. +struct ARROW_FLIGHT_EXPORT GetSessionOptionsResult { + std::map session_options; + + std::string ToString() const; + bool Equals(const GetSessionOptionsResult& other) const; + + friend bool operator==(const GetSessionOptionsResult& left, + const GetSessionOptionsResult& right) { + return left.Equals(right); + } + friend bool operator!=(const GetSessionOptionsResult& left, + const GetSessionOptionsResult& right) { + return !(left == right); + } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(std::string_view serialized); +}; + +/// \brief A request to close the open client session. +struct ARROW_FLIGHT_EXPORT CloseSessionRequest { + std::string ToString() const; + bool Equals(const CloseSessionRequest& other) const; + + friend bool operator==(const CloseSessionRequest& left, + const CloseSessionRequest& right) { + return left.Equals(right); + } + friend bool operator!=(const CloseSessionRequest& left, + const CloseSessionRequest& right) { + return !(left == right); + } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(std::string_view serialized); +}; + +/// \brief The result of attempting to close the client session. +struct ARROW_FLIGHT_EXPORT CloseSessionResult { + CloseSessionStatus status; + + std::string ToString() const; + bool Equals(const CloseSessionResult& other) const; + + friend bool operator==(const CloseSessionResult& left, + const CloseSessionResult& right) { + return left.Equals(right); + } + friend bool operator!=(const CloseSessionResult& left, + const CloseSessionResult& right) { + return !(left == right); + } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(std::string_view serialized); +}; + /// \brief An iterator to FlightInfo instances returned by ListFlights. class ARROW_FLIGHT_EXPORT FlightListing { public: diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index eb2e26951cd88..f3ef54744806d 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -606,6 +606,11 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, "RenewFlightEndpoint are working as expected."), skip_testers={"JS", "C#", "Rust"}, ), + Scenario( + "session_options", + description="Ensure Flight SQL Sessions work as expected.", + skip_testers={"JS", "C#", "Rust", "Go"} + ), Scenario( "poll_flight_info", description="Ensure PollFlightInfo is supported.", diff --git a/docs/source/format/FlightSql.rst b/docs/source/format/FlightSql.rst index f7521c3876493..fd8f54ca7f130 100644 --- a/docs/source/format/FlightSql.rst +++ b/docs/source/format/FlightSql.rst @@ -149,6 +149,47 @@ the ``type`` should be ``ClosePreparedStatement``). When used with DoPut: execute the query and return the number of affected rows. +Flight Server Session Management +-------------------------------- + +Flight SQL provides commands to set and update server session variables +which affect the server behaviour in various ways. Common options may +include (depending on the server implementation) ``catalog`` and +``schema``, indicating the currently-selected catalog and schema for +queries to be run against. + +Clients should prefer, where possible, setting options prior to issuing +queries and other commands, as some server implementations may require +these options be set exactly once and prior to any other activity which +may trigger their implicit setting. + +For compatibility with Database Connectivity drivers (JDBC, ODBC, and +others), it is strongly recommended that server implementations accept +string representations of all option values which may be provided to the +driver as part of a server connection string and passed through to the +server without further conversion. For ease of use it is also recommended +to accept and convert other numeric types to the preferred type for an +option value, however this is not required. + +Sessions are persisted between the client and server using an +implementation-defined mechanism, which is typically RFC 6265 cookies. +Servers may also combine other connection state opaquely with the +session token: Consider that the lifespan and semantics of a session +should make sense for any additional uses, e.g. CloseSession would also +invalidate any authentication context persisted via the session context. +A session may be initiated upon a nonempty (or empty) SetSessionOptions +call, or at any other time of the server's choosing. + +``SetSessionOptions`` +Set server session option(s) by name/value. + +``GetSessionOptions`` +Get the current server session options, including those set by the client +and any defaulted or implicitly set by the server. + +``CloseSession`` +Close and invalidate the current session context. + Sequence Diagrams ================= diff --git a/format/Flight.proto b/format/Flight.proto index de3794f05ba83..59714108e1cbc 100644 --- a/format/Flight.proto +++ b/format/Flight.proto @@ -525,3 +525,117 @@ message FlightData { message PutResult { bytes app_metadata = 1; } + +/* + * EXPERIMENTAL: Union of possible value types for a Session Option to be set to. + * + * By convention, an attempt to set a valueless SessionOptionValue should + * attempt to unset or clear the named option value on the server. + */ +message SessionOptionValue { + message StringListValue { + repeated string values = 1; + } + + oneof option_value { + string string_value = 1; + bool bool_value = 2; + sfixed64 int64_value = 3; + double double_value = 4; + StringListValue string_list_value = 5; + } +} + +/* + * EXPERIMENTAL: A request to set session options for an existing or new (implicit) + * server session. + * + * Sessions are persisted and referenced via a transport-level state management, typically + * RFC 6265 HTTP cookies when using an HTTP transport. The suggested cookie name or state + * context key is 'arrow_flight_session_id', although implementations may freely choose their + * own name. + * + * Session creation (if one does not already exist) is implied by this RPC request, however + * server implementations may choose to initiate a session that also contains client-provided + * session options at any other time, e.g. on authentication, or when any other call is made + * and the server wishes to use a session to persist any state (or lack thereof). + */ +message SetSessionOptionsRequest { + map session_options = 1; +} + +/* + * EXPERIMENTAL: The results (individually) of setting a set of session options. + * + * Option names should only be present in the response if they were not successfully + * set on the server; that is, a response without an Error for a name provided in the + * SetSessionOptionsRequest implies that the named option value was set successfully. + */ +message SetSessionOptionsResult { + enum ErrorValue { + // Protobuf deserialization fallback value: The status is unknown or unrecognized. + // Servers should avoid using this value. The request may be retried by the client. + UNSPECIFIED = 0; + // The given session option name is invalid. + INVALID_NAME = 1; + // The session option value or type is invalid. + INVALID_VALUE = 2; + // The session option cannot be set. + ERROR = 3; + } + + message Error { + ErrorValue value = 1; + } + + map errors = 1; +} + +/* + * EXPERIMENTAL: A request to access the session options for the current server session. + * + * The existing session is referenced via a cookie header or similar (see + * SetSessionOptionsRequest above); it is an error to make this request with a missing, + * invalid, or expired session cookie header or other implementation-defined session + * reference token. + */ +message GetSessionOptionsRequest { +} + +/* + * EXPERIMENTAL: The result containing the current server session options. + */ +message GetSessionOptionsResult { + map session_options = 1; +} + +/* + * Request message for the "Close Session" action. + * + * The exiting session is referenced via a cookie header. + */ +message CloseSessionRequest { +} + +/* + * The result of closing a session. + */ +message CloseSessionResult { + enum Status { + // Protobuf deserialization fallback value: The session close status is unknown or + // not recognized. Servers should avoid using this value (send a NOT_FOUND error if + // the requested session is not known or expired). Clients can retry the request. + UNSPECIFIED = 0; + // The session close request is complete. Subsequent requests with + // the same session produce a NOT_FOUND error. + CLOSED = 1; + // The session close request is in progress. The client may retry + // the close request. + CLOSING = 2; + // The session is not closeable. The client should not retry the + // close request. + NOT_CLOSEABLE = 3; + } + + Status status = 1; +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CloseSessionRequest.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CloseSessionRequest.java new file mode 100644 index 0000000000000..29eb3664f6286 --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CloseSessionRequest.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import org.apache.arrow.flight.impl.Flight; + +/** A request to close/invalidate a server session context. */ +public class CloseSessionRequest { + public CloseSessionRequest() { + } + + CloseSessionRequest(Flight.CloseSessionRequest proto) { + } + + Flight.CloseSessionRequest toProtocol() { + return Flight.CloseSessionRequest.getDefaultInstance(); + } + + /** + * Get the serialized form of this protocol message. + * + *

Intended to help interoperability by allowing non-Flight services to still return Flight types. + */ + public ByteBuffer serialize() { + return ByteBuffer.wrap(toProtocol().toByteArray()); + } + + /** + * Parse the serialized form of this protocol message. + * + *

Intended to help interoperability by allowing Flight clients to obtain stream info from non-Flight services. + * + * @param serialized The serialized form of the message, as returned by {@link #serialize()}. + * @return The deserialized message. + * @throws IOException if the serialized form is invalid. + */ + public static CloseSessionRequest deserialize(ByteBuffer serialized) throws IOException { + return new CloseSessionRequest(Flight.CloseSessionRequest.parseFrom(serialized)); + } +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CloseSessionResult.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CloseSessionResult.java new file mode 100644 index 0000000000000..c3710a14b108a --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CloseSessionResult.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import org.apache.arrow.flight.impl.Flight; + +/** The result of attempting to close/invalidate a server session context. */ +public class CloseSessionResult { + /** + * Close operation result status values. + */ + public enum Status { + /** + * The session close status is unknown. Servers should avoid using this value + * (send a NOT_FOUND error if the requested session is not known). Clients can + * retry the request. + */ + UNSPECIFIED, + /** + * The session close request is complete. + */ + CLOSED, + /** + * The session close request is in progress. The client may retry the request. + */ + CLOSING, + /** + * The session is not closeable. + */ + NOT_CLOSABLE, + ; + + public static Status fromProtocol(Flight.CloseSessionResult.Status proto) { + return values()[proto.getNumber()]; + } + + public Flight.CloseSessionResult.Status toProtocol() { + return Flight.CloseSessionResult.Status.values()[ordinal()]; + } + } + + private final Status status; + + public CloseSessionResult(Status status) { + this.status = status; + } + + CloseSessionResult(Flight.CloseSessionResult proto) { + status = Status.fromProtocol(proto.getStatus()); + if (status == null) { + // Unreachable + throw new IllegalArgumentException(""); + } + } + + public Status getStatus() { + return status; + } + + Flight.CloseSessionResult toProtocol() { + Flight.CloseSessionResult.Builder b = Flight.CloseSessionResult.newBuilder(); + b.setStatus(status.toProtocol()); + return b.build(); + } + + /** + * Get the serialized form of this protocol message. + * + *

Intended to help interoperability by allowing non-Flight services to still return Flight types. + */ + public ByteBuffer serialize() { + return ByteBuffer.wrap(toProtocol().toByteArray()); + } + + /** + * Parse the serialized form of this protocol message. + * + *

Intended to help interoperability by allowing Flight clients to obtain stream info from non-Flight services. + * + * @param serialized The serialized form of the message, as returned by {@link #serialize()}. + * @return The deserialized message. + * @throws IOException if the serialized form is invalid. + */ + public static CloseSessionResult deserialize(ByteBuffer serialized) throws IOException { + return new CloseSessionResult(Flight.CloseSessionResult.parseFrom(serialized)); + } + +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java index 422ed01c394d1..6310f21d574df 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -568,6 +568,102 @@ public FlightEndpoint renewFlightEndpoint(RenewFlightEndpointRequest request, Ca return result; } + /** + * Set server session option(s) by name/value. + * + * Sessions are generally persisted via HTTP cookies. + * + * @param request The session options to set on the server. + * @param options Call options. + * @return The result containing per-value error statuses, if any. + */ + public SetSessionOptionsResult setSessionOptions(SetSessionOptionsRequest request, CallOption... options) { + Action action = new Action(FlightConstants.SET_SESSION_OPTIONS.getType(), request.serialize().array()); + Iterator results = doAction(action, options); + if (!results.hasNext()) { + throw CallStatus.INTERNAL + .withDescription("Server did not return a response") + .toRuntimeException(); + } + + SetSessionOptionsResult result; + try { + result = SetSessionOptionsResult.deserialize(ByteBuffer.wrap(results.next().getBody())); + } catch (IOException e) { + throw CallStatus.INTERNAL + .withDescription("Failed to parse server response: " + e) + .withCause(e) + .toRuntimeException(); + } + results.forEachRemaining((ignored) -> { + }); + return result; + } + + /** + * Get the current server session options. + * + * The session is generally accessed via an HTTP cookie. + * + * @param request The (empty) GetSessionOptionsRequest. + * @param options Call options. + * @return The result containing the set of session options configured on the server. + */ + public GetSessionOptionsResult getSessionOptions(GetSessionOptionsRequest request, CallOption... options) { + Action action = new Action(FlightConstants.GET_SESSION_OPTIONS.getType(), request.serialize().array()); + Iterator results = doAction(action, options); + if (!results.hasNext()) { + throw CallStatus.INTERNAL + .withDescription("Server did not return a response") + .toRuntimeException(); + } + + GetSessionOptionsResult result; + try { + result = GetSessionOptionsResult.deserialize(ByteBuffer.wrap(results.next().getBody())); + } catch (IOException e) { + throw CallStatus.INTERNAL + .withDescription("Failed to parse server response: " + e) + .withCause(e) + .toRuntimeException(); + } + results.forEachRemaining((ignored) -> { + }); + return result; + } + + /** + * Close/invalidate the current server session. + * + * The session is generally accessed via an HTTP cookie. + * + * @param request The (empty) CloseSessionRequest. + * @param options Call options. + * @return The result containing the status of the close operation. + */ + public CloseSessionResult closeSession(CloseSessionRequest request, CallOption... options) { + Action action = new Action(FlightConstants.CLOSE_SESSION.getType(), request.serialize().array()); + Iterator results = doAction(action, options); + if (!results.hasNext()) { + throw CallStatus.INTERNAL + .withDescription("Server did not return a response") + .toRuntimeException(); + } + + CloseSessionResult result; + try { + result = CloseSessionResult.deserialize(ByteBuffer.wrap(results.next().getBody())); + } catch (IOException e) { + throw CallStatus.INTERNAL + .withDescription("Failed to parse server response: " + e) + .withCause(e) + .toRuntimeException(); + } + results.forEachRemaining((ignored) -> { + }); + return result; + } + /** * Interface for writers to an Arrow data stream. */ diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightConstants.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightConstants.java index 2a240abad6d95..4456e3dae4949 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightConstants.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightConstants.java @@ -35,4 +35,18 @@ public interface FlightConstants { "Extend expiration time of the given FlightEndpoint.\n" + "Request Message: RenewFlightEndpointRequest\n" + "Response Message: Renewed FlightEndpoint"); + + ActionType SET_SESSION_OPTIONS = new ActionType("SetSessionOptions", + "Set client session options by name/value pairs.\n" + + "Request Message: SetSessionOptionsRequest\n" + + "Response Message: SetSessionOptionsResult"); + + ActionType GET_SESSION_OPTIONS = new ActionType("GetSessionOptions", + "Get current client session options\n" + + "Request Message: GetSessionOptionsRequest\n" + + "Response Message: GetSessionOptionsResult"); + ActionType CLOSE_SESSION = new ActionType("CloseSession", + "Explicitly close/invalidate the cookie-specified client session.\n" + + "Request Message: CloseSessionRequest\n" + + "Response Message: CloseSessionResult"); } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/GetSessionOptionsRequest.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/GetSessionOptionsRequest.java new file mode 100644 index 0000000000000..9d63e59027aac --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/GetSessionOptionsRequest.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import org.apache.arrow.flight.impl.Flight; + +/** + * A request to get current session options. + */ +public class GetSessionOptionsRequest { + public GetSessionOptionsRequest() { + } + + GetSessionOptionsRequest(Flight.GetSessionOptionsRequest proto) { + } + + Flight.GetSessionOptionsRequest toProtocol() { + return Flight.GetSessionOptionsRequest.getDefaultInstance(); + } + + /** + * Get the serialized form of this protocol message. + * + *

Intended to help interoperability by allowing non-Flight services to still return Flight types. + */ + public ByteBuffer serialize() { + return ByteBuffer.wrap(toProtocol().toByteArray()); + } + + /** + * Parse the serialized form of this protocol message. + * + *

Intended to help interoperability by allowing Flight clients to obtain stream info from non-Flight services. + * + * @param serialized The serialized form of the message, as returned by {@link #serialize()}. + * @return The deserialized message. + * @throws IOException if the serialized form is invalid. + */ + public static GetSessionOptionsRequest deserialize(ByteBuffer serialized) throws IOException { + return new GetSessionOptionsRequest(Flight.GetSessionOptionsRequest.parseFrom(serialized)); + } +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/GetSessionOptionsResult.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/GetSessionOptionsResult.java new file mode 100644 index 0000000000000..c777bd39bd032 --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/GetSessionOptionsResult.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.arrow.flight.impl.Flight; + +/** A request to view the currently-set options for the current server session. */ +public class GetSessionOptionsResult { + private final Map sessionOptions; + + public GetSessionOptionsResult(Map sessionOptions) { + this.sessionOptions = Collections.unmodifiableMap(new HashMap(sessionOptions)); + } + + GetSessionOptionsResult(Flight.GetSessionOptionsResult proto) { + sessionOptions = Collections.unmodifiableMap( + proto.getSessionOptionsMap().entrySet().stream().collect(Collectors.toMap( + Map.Entry::getKey, (e) -> SessionOptionValueFactory.makeSessionOptionValue(e.getValue())))); + } + + /** + * Get the session options map contained in the request. + * + * @return An immutable view of the session options map. + */ + public Map getSessionOptions() { + return sessionOptions; + } + + Flight.GetSessionOptionsResult toProtocol() { + Flight.GetSessionOptionsResult.Builder b = Flight.GetSessionOptionsResult.newBuilder(); + b.putAllSessionOptions(sessionOptions.entrySet().stream().collect(Collectors.toMap( + Map.Entry::getKey, (e) -> e.getValue().toProtocol()))); + return b.build(); + } + + /** + * Get the serialized form of this protocol message. + * + *

Intended to help interoperability by allowing non-Flight services to still return Flight types. + */ + public ByteBuffer serialize() { + return ByteBuffer.wrap(toProtocol().toByteArray()); + } + + /** + * Parse the serialized form of this protocol message. + * + *

Intended to help interoperability by allowing Flight clients to obtain stream info from non-Flight services. + * + * @param serialized The serialized form of the message, as returned by {@link #serialize()}. + * @return The deserialized message. + * @throws IOException if the serialized form is invalid. + */ + public static GetSessionOptionsResult deserialize(ByteBuffer serialized) throws IOException { + return new GetSessionOptionsResult(Flight.GetSessionOptionsResult.parseFrom(serialized)); + } +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/NoOpSessionOptionValueVisitor.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/NoOpSessionOptionValueVisitor.java new file mode 100644 index 0000000000000..c951cce0ed42d --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/NoOpSessionOptionValueVisitor.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +/** + * A helper to facilitate easier anonymous subclass declaration. + * + * Implementations need only override callbacks for types they wish to do something with. + * + * @param Return type of the visit operation. + */ +public class NoOpSessionOptionValueVisitor implements SessionOptionValueVisitor { + /** + * A callback to handle SessionOptionValue containing a String. + */ + public T visit(String value) { + return null; + } + + /** + * A callback to handle SessionOptionValue containing a boolean. + */ + public T visit(boolean value) { + return null; + } + + /** + * A callback to handle SessionOptionValue containing a long. + */ + public T visit(long value) { + return null; + } + + /** + * A callback to handle SessionOptionValue containing a double. + */ + public T visit(double value) { + return null; + } + + /** + * A callback to handle SessionOptionValue containing an array of String. + */ + public T visit(String[] value) { + return null; + } + + /** + * A callback to handle SessionOptionValue containing no value. + * + * By convention, an attempt to set a valueless SessionOptionValue should + * attempt to unset or clear the named option value on the server. + */ + public T visit(Void value) { + return null; + } +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ServerSessionMiddleware.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ServerSessionMiddleware.java new file mode 100644 index 0000000000000..7091caa5e98bc --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ServerSessionMiddleware.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +/** + * Middleware for handling Flight SQL Sessions including session cookie handling. + * + * Currently experimental. + */ +public class ServerSessionMiddleware implements FlightServerMiddleware { + Factory factory; + boolean existingSession; + private Session session; + private String closedSessionId = null; + + public static final String sessionCookieName = "arrow_flight_session_id"; + + /** + * Factory for managing and accessing ServerSessionMiddleware. + */ + public static class Factory implements FlightServerMiddleware.Factory { + private final ConcurrentMap sessionStore = + new ConcurrentHashMap<>(); + private final Callable idGenerator; + + /** + * Construct a factory for ServerSessionMiddleware. + * + * Factory manages and accesses persistent sessions based on HTTP cookies. + * + * @param idGenerator A Callable returning unique session id Strings. + */ + public Factory(Callable idGenerator) { + this.idGenerator = idGenerator; + } + + private synchronized Session createNewSession() { + String id; + try { + id = idGenerator.call(); + } catch (Exception ignored) { + // Most impls aren't going to throw so don't make caller handle a nonexistent checked exception + throw CallStatus.INTERNAL.withDescription("Session creation error").toRuntimeException(); + } + + Session newSession = new Session(id); + if (sessionStore.putIfAbsent(id, newSession) != null) { + // Collision, should never happen + throw CallStatus.INTERNAL.withDescription("Session creation error").toRuntimeException(); + } + return newSession; + } + + private void closeSession(String id) { + if (sessionStore.remove(id) == null) { + throw CallStatus.NOT_FOUND.withDescription("Session id '" + id + "' not found.").toRuntimeException(); + } + } + + @Override + public ServerSessionMiddleware onCallStarted(CallInfo callInfo, CallHeaders incomingHeaders, + RequestContext context) { + String sessionId = null; + + final Iterable it = incomingHeaders.getAll("cookie"); + if (it != null) { + findIdCookie: + for (final String headerValue : it) { + for (final String cookie : headerValue.split(" ;")) { + final String[] cookiePair = cookie.split("="); + if (cookiePair.length != 2) { + // Soft failure: Ignore invalid cookie list field + break; + } + + if (sessionCookieName.equals(cookiePair[0]) && cookiePair[1].length() > 0) { + sessionId = cookiePair[1]; + break findIdCookie; + } + } + } + } + + if (sessionId == null) { + // No session cookie, create middleware instance without session. + return new ServerSessionMiddleware(this, incomingHeaders, null); + } + + Session session = sessionStore.get(sessionId); + // Cookie provided by caller, but invalid + if (session == null) { + // Can't soft-fail/proceed here, clients will get unexpected behaviour without options they thought were set. + throw CallStatus.NOT_FOUND.withDescription("Invalid " + sessionCookieName + " cookie.").toRuntimeException(); + } + + return new ServerSessionMiddleware(this, incomingHeaders, session); + } + } + + /** + * A thread-safe container for named SessionOptionValues. + */ + public static class Session { + public final String id; + private ConcurrentMap sessionData = + new ConcurrentHashMap(); + + /** + * Construct a new Session with the given id. + * + * @param id The Session's id string, which is used as the session cookie value. + */ + private Session(String id) { + this.id = id; + } + + /** Get session option by name, or null if it does not exist. */ + public SessionOptionValue getSessionOption(String name) { + return sessionData.get(name); + } + + /** Get an immutable copy of the session options map. */ + public Map getSessionOptions() { + return Collections.unmodifiableMap(new HashMap(sessionData)); + } + + /** Set session option by name to given value. */ + public void setSessionOption(String name, SessionOptionValue value) { + sessionData.put(name, value); + } + + /** Idempotently remove name from this session. */ + public void eraseSessionOption(String name) { + sessionData.remove(name); + } + } + + private final CallHeaders headers; + + private ServerSessionMiddleware(ServerSessionMiddleware.Factory factory, + CallHeaders incomingHeaders, Session session) { + this.factory = factory; + headers = incomingHeaders; + this.session = session; + existingSession = (session != null); + } + + /** + * Check if there is an open session associated with this call. + * + * @return True iff there is an open session associated with this call. + */ + public boolean hasSession() { + return session != null; + } + + /** + * Get the existing or new session value map for this call. + * + * @return The session option value map, or null in case of an id generation collision. + */ + public synchronized Session getSession() { + if (session == null) { + session = factory.createNewSession(); + } + + return session; + } + + /** + * Close the current session. + * + * It is an error to call this without a valid session specified via cookie or equivalent. + * */ + public synchronized void closeSession() { + if (session == null) { + throw CallStatus.NOT_FOUND.withDescription("No session found for the current call.").toRuntimeException(); + } + factory.closeSession(session.id); + closedSessionId = session.id; + session = null; + } + + public CallHeaders getCallHeaders() { + return headers; + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + if (!existingSession && session != null) { + outgoingHeaders.insert("set-cookie", sessionCookieName + "=" + session.id); + } + if (closedSessionId != null) { + outgoingHeaders.insert("set-cookie", sessionCookieName + "=" + closedSessionId + "; Max-Age=0"); + } + } + + @Override + public void onCallCompleted(CallStatus status) { + } + + @Override + public void onCallErrored(Throwable err) { + } +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SessionOptionValue.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SessionOptionValue.java new file mode 100644 index 0000000000000..db22c736be182 --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SessionOptionValue.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.Arrays; + +import org.apache.arrow.flight.impl.Flight; + +/** + * A union-like container interface for supported session option value types. + */ +public abstract class SessionOptionValue { + SessionOptionValue() { + } + + /** + * Value access via a caller-provided visitor/functor. + */ + public abstract T acceptVisitor(SessionOptionValueVisitor v); + + Flight.SessionOptionValue toProtocol() { + Flight.SessionOptionValue.Builder b = Flight.SessionOptionValue.newBuilder(); + SessionOptionValueToProtocolVisitor visitor = new SessionOptionValueToProtocolVisitor(b); + this.acceptVisitor(visitor); + return b.build(); + } + + /** Check whether the SessionOptionValue is empty/valueless. */ + public boolean isEmpty() { + return false; + } + + private class SessionOptionValueToProtocolVisitor implements SessionOptionValueVisitor { + final Flight.SessionOptionValue.Builder b; + + SessionOptionValueToProtocolVisitor(Flight.SessionOptionValue.Builder b) { + this.b = b; + } + + @Override + public Void visit(String value) { + b.setStringValue(value); + return null; + } + + @Override + public Void visit(boolean value) { + b.setBoolValue(value); + return null; + } + + @Override + public Void visit(long value) { + b.setInt64Value(value); + return null; + } + + @Override + public Void visit(double value) { + b.setDoubleValue(value); + return null; + } + + @Override + public Void visit(String[] value) { + Flight.SessionOptionValue.StringListValue.Builder pbSLVBuilder = + Flight.SessionOptionValue.StringListValue.newBuilder(); + pbSLVBuilder.addAllValues(Arrays.asList(value)); + b.setStringListValue(pbSLVBuilder.build()); + return null; + } + + @Override + public Void visit(Void ignored) { + b.clearOptionValue(); + return null; + } + } +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SessionOptionValueFactory.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SessionOptionValueFactory.java new file mode 100644 index 0000000000000..47c82fa7bb7fd --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SessionOptionValueFactory.java @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; + +import org.apache.arrow.flight.impl.Flight; + +/** Abstract factory for concrete SessionOptionValue instances. */ +public class SessionOptionValueFactory { + public static SessionOptionValue makeSessionOptionValue(String value) { + return new SessionOptionValueString(value); + } + + public static SessionOptionValue makeSessionOptionValue(boolean value) { + return new SessionOptionValueBoolean(value); + } + + public static SessionOptionValue makeSessionOptionValue(long value) { + return new SessionOptionValueLong(value); + } + + public static SessionOptionValue makeSessionOptionValue(double value) { + return new SessionOptionValueDouble(value); + } + + public static SessionOptionValue makeSessionOptionValue(String[] value) { + return new SessionOptionValueStringList(value); + } + + public static SessionOptionValue makeEmptySessionOptionValue() { + return new SessionOptionValueEmpty(); + } + + /** Construct a SessionOptionValue from its Protobuf object representation. */ + public static SessionOptionValue makeSessionOptionValue(Flight.SessionOptionValue proto) { + switch (proto.getOptionValueCase()) { + case STRING_VALUE: + return new SessionOptionValueString(proto.getStringValue()); + case BOOL_VALUE: + return new SessionOptionValueBoolean(proto.getBoolValue()); + case INT64_VALUE: + return new SessionOptionValueLong(proto.getInt64Value()); + case DOUBLE_VALUE: + return new SessionOptionValueDouble(proto.getDoubleValue()); + case STRING_LIST_VALUE: + // Using ByteString::toByteArray() here otherwise we still somehow get `ByteArray`s with broken .equals(String) + return new SessionOptionValueStringList(proto.getStringListValue().getValuesList().asByteStringList().stream() + .map((e) -> new String(e.toByteArray(), StandardCharsets.UTF_8)).toArray(String[]::new)); + case OPTIONVALUE_NOT_SET: + return new SessionOptionValueEmpty(); + default: + // Unreachable + throw new IllegalArgumentException(""); + } + } + + private static class SessionOptionValueString extends SessionOptionValue { + private final String value; + + SessionOptionValueString(String value) { + this.value = value; + } + + @Override + public T acceptVisitor(SessionOptionValueVisitor v) { + return v.visit(value); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SessionOptionValueString that = (SessionOptionValueString) o; + return value.equals(that.value); + } + + @Override + public int hashCode() { + return value.hashCode(); + } + + @Override + public String toString() { + return '"' + value + '"'; + } + } + + private static class SessionOptionValueBoolean extends SessionOptionValue { + private final boolean value; + + SessionOptionValueBoolean(boolean value) { + this.value = value; + } + + @Override + public T acceptVisitor(SessionOptionValueVisitor v) { + return v.visit(value); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SessionOptionValueBoolean that = (SessionOptionValueBoolean) o; + return value == that.value; + } + + @Override + public int hashCode() { + return Boolean.hashCode(value); + } + + @Override + public String toString() { + return String.valueOf(value); + } + } + + private static class SessionOptionValueLong extends SessionOptionValue { + private final long value; + + SessionOptionValueLong(long value) { + this.value = value; + } + + @Override + public T acceptVisitor(SessionOptionValueVisitor v) { + return v.visit(value); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SessionOptionValueLong that = (SessionOptionValueLong) o; + return value == that.value; + } + + @Override + public int hashCode() { + return Long.hashCode(value); + } + + @Override + public String toString() { + return String.valueOf(value); + } + } + + private static class SessionOptionValueDouble extends SessionOptionValue { + private final double value; + + SessionOptionValueDouble(double value) { + this.value = value; + } + + @Override + public T acceptVisitor(SessionOptionValueVisitor v) { + return v.visit(value); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SessionOptionValueDouble that = (SessionOptionValueDouble) o; + return value == that.value; + } + + @Override + public int hashCode() { + return Double.hashCode(value); + } + + @Override + public String toString() { + return String.valueOf(value); + } + } + + private static class SessionOptionValueStringList extends SessionOptionValue { + private final String[] value; + + SessionOptionValueStringList(String[] value) { + this.value = value.clone(); + } + + @Override + public T acceptVisitor(SessionOptionValueVisitor v) { + return v.visit(value); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SessionOptionValueStringList that = (SessionOptionValueStringList) o; + return Arrays.deepEquals(value, that.value); + } + + @Override + public int hashCode() { + return Arrays.deepHashCode(value); + } + + @Override + public String toString() { + if (value.length == 0) { + return "[]"; + } + return "[\"" + String.join("\", \"", value) + "\"]"; + } + } + + private static class SessionOptionValueEmpty extends SessionOptionValue { + @Override + public T acceptVisitor(SessionOptionValueVisitor v) { + return v.visit((Void) null); + } + + @Override + public boolean isEmpty() { + return true; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + return true; + } + + @Override + public int hashCode() { + return SessionOptionValueEmpty.class.hashCode(); + } + + @Override + public String toString() { + return ""; + } + } +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SessionOptionValueVisitor.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SessionOptionValueVisitor.java new file mode 100644 index 0000000000000..f2178224a0d29 --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SessionOptionValueVisitor.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +/** + * A visitor interface to access SessionOptionValue's contained value. + * + * @param Return type of the visit operation. + */ +public interface SessionOptionValueVisitor { + /** + * A callback to handle SessionOptionValue containing a String. + */ + T visit(String value); + + /** + * A callback to handle SessionOptionValue containing a boolean. + */ + T visit(boolean value); + + /** + * A callback to handle SessionOptionValue containing a long. + */ + T visit(long value); + + /** + * A callback to handle SessionOptionValue containing a double. + */ + T visit(double value); + + /** + * A callback to handle SessionOptionValue containing an array of String. + */ + T visit(String[] value); + + /** + * A callback to handle SessionOptionValue containing no value. + * + * By convention, an attempt to set a valueless SessionOptionValue should + * attempt to unset or clear the named option value on the server. + */ + T visit(Void value); +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SetSessionOptionsRequest.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SetSessionOptionsRequest.java new file mode 100644 index 0000000000000..8a5253e682162 --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SetSessionOptionsRequest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.arrow.flight.impl.Flight; + +/** A request to set option(s) in an existing or implicitly-created server session. */ +public class SetSessionOptionsRequest { + private final Map sessionOptions; + + public SetSessionOptionsRequest(Map sessionOptions) { + this.sessionOptions = Collections.unmodifiableMap(new HashMap(sessionOptions)); + } + + SetSessionOptionsRequest(Flight.SetSessionOptionsRequest proto) { + sessionOptions = Collections.unmodifiableMap( + proto.getSessionOptionsMap().entrySet().stream().collect(Collectors.toMap( + Map.Entry::getKey, (e) -> SessionOptionValueFactory.makeSessionOptionValue(e.getValue())))); + } + + /** + * Get the session option map from the request. + * + * @return An immutable view of the session options map. + */ + public Map getSessionOptions() { + return Collections.unmodifiableMap(sessionOptions); + } + + Flight.SetSessionOptionsRequest toProtocol() { + Flight.SetSessionOptionsRequest.Builder b = Flight.SetSessionOptionsRequest.newBuilder(); + b.putAllSessionOptions(sessionOptions.entrySet().stream().collect(Collectors.toMap( + Map.Entry::getKey, (e) -> e.getValue().toProtocol()))); + return b.build(); + } + + /** + * Get the serialized form of this protocol message. + * + *

Intended to help interoperability by allowing non-Flight services to still return Flight types. + */ + public ByteBuffer serialize() { + return ByteBuffer.wrap(toProtocol().toByteArray()); + } + + /** + * Parse the serialized form of this protocol message. + * + *

Intended to help interoperability by allowing Flight clients to obtain stream info from non-Flight services. + * + * @param serialized The serialized form of the message, as returned by {@link #serialize()}. + * @return The deserialized message. + * @throws IOException if the serialized form is invalid. + */ + public static SetSessionOptionsRequest deserialize(ByteBuffer serialized) throws IOException { + return new SetSessionOptionsRequest(Flight.SetSessionOptionsRequest.parseFrom(serialized)); + } + +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SetSessionOptionsResult.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SetSessionOptionsResult.java new file mode 100644 index 0000000000000..14d53cc6767e0 --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SetSessionOptionsResult.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.arrow.flight.impl.Flight; + +/** The result of attempting to set a set of session options. */ +public class SetSessionOptionsResult { + /** Error status value for per-option errors. */ + public enum ErrorValue { + /** + * The status of setting the option is unknown. Servers should avoid using this value + * (send a NOT_FOUND error if the requested session is not known). Clients can retry + * the request. + */ + UNSPECIFIED, + /** + * The given session option name is invalid. + */ + INVALID_NAME, + /** + * The session option value or type is invalid. + */ + INVALID_VALUE, + /** + * The session option cannot be set. + */ + ERROR, + ; + + static ErrorValue fromProtocol(Flight.SetSessionOptionsResult.ErrorValue s) { + return values()[s.getNumber()]; + } + + Flight.SetSessionOptionsResult.ErrorValue toProtocol() { + return Flight.SetSessionOptionsResult.ErrorValue.values()[ordinal()]; + } + } + + /** Per-option extensible error response container. */ + public static class Error { + public ErrorValue value; + + public Error(ErrorValue value) { + this.value = value; + } + + Error(Flight.SetSessionOptionsResult.Error e) { + value = ErrorValue.fromProtocol(e.getValue()); + } + + Flight.SetSessionOptionsResult.Error toProtocol() { + Flight.SetSessionOptionsResult.Error.Builder b = Flight.SetSessionOptionsResult.Error.newBuilder(); + b.setValue(value.toProtocol()); + return b.build(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Error that = (Error) o; + return value == that.value; + } + + @Override + public int hashCode() { + return value.hashCode(); + } + } + + private final Map errors; + + public SetSessionOptionsResult(Map errors) { + this.errors = Collections.unmodifiableMap(new HashMap(errors)); + } + + SetSessionOptionsResult(Flight.SetSessionOptionsResult proto) { + errors = Collections.unmodifiableMap(proto.getErrors().entrySet().stream().collect( + Collectors.toMap(Map.Entry::getKey, (e) -> new Error(e.getValue())))); + } + + /** Report whether the error map has nonzero length. */ + public boolean hasErrors() { + return errors.size() > 0; + } + + /** + * Get the error status map from the result object. + * + * @return An immutable view of the error status map. + */ + public Map getErrors() { + return errors; + } + + Flight.SetSessionOptionsResult toProtocol() { + Flight.SetSessionOptionsResult.Builder b = Flight.SetSessionOptionsResult.newBuilder(); + b.putAllErrors(errors.entrySet().stream().collect(Collectors.toMap( + Map.Entry::getKey, + (e) -> e.getValue().toProtocol()))); + return b.build(); + } + + /** + * Get the serialized form of this protocol message. + * + *

Intended to help interoperability by allowing non-Flight services to still return Flight types. + */ + public ByteBuffer serialize() { + return ByteBuffer.wrap(toProtocol().toByteArray()); + } + + /** + * Parse the serialized form of this protocol message. + * + *

Intended to help interoperability by allowing Flight clients to obtain stream info from non-Flight services. + * + * @param serialized The serialized form of the message, as returned by {@link #serialize()}. + * @return The deserialized message. + * @throws IOException if the serialized form is invalid. + */ + public static SetSessionOptionsResult deserialize(ByteBuffer serialized) throws IOException { + return new SetSessionOptionsResult(Flight.SetSessionOptionsResult.parseFrom(serialized)); + } +} diff --git a/java/flight/flight-integration-tests/pom.xml b/java/flight/flight-integration-tests/pom.xml index 46a3fe34ea203..e3ffe0fb8961e 100644 --- a/java/flight/flight-integration-tests/pom.xml +++ b/java/flight/flight-integration-tests/pom.xml @@ -49,6 +49,10 @@ com.google.protobuf protobuf-java + + com.google.guava + guava + commons-cli commons-cli diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java index 26629c650e30f..6878c22c5ccdc 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java @@ -49,6 +49,8 @@ private Scenarios() { scenarios.put("poll_flight_info", PollFlightInfoScenario::new); scenarios.put("flight_sql", FlightSqlScenario::new); scenarios.put("flight_sql:extension", FlightSqlExtensionScenario::new); + scenarios.put("app_metadata_flight_info_endpoint", AppMetadataFlightInfoEndpointScenario::new); + scenarios.put("session_options", SessionOptionsScenario::new); } private static Scenarios getInstance() { diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/SessionOptionsProducer.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/SessionOptionsProducer.java new file mode 100644 index 0000000000000..f29028547c452 --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/SessionOptionsProducer.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.integration.tests; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.CloseSessionRequest; +import org.apache.arrow.flight.CloseSessionResult; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightServerMiddleware; +import org.apache.arrow.flight.GetSessionOptionsRequest; +import org.apache.arrow.flight.GetSessionOptionsResult; +import org.apache.arrow.flight.ServerSessionMiddleware; +import org.apache.arrow.flight.SessionOptionValue; +import org.apache.arrow.flight.SessionOptionValueFactory; +import org.apache.arrow.flight.SetSessionOptionsRequest; +import org.apache.arrow.flight.SetSessionOptionsResult; +import org.apache.arrow.flight.sql.NoOpFlightSqlProducer; + +/** The server used for testing Sessions. + *

+ * SetSessionOptions(), GetSessionOptions(), and CloseSession() operate on a + * simple SessionOptionValue store. + */ +final class SessionOptionsProducer extends NoOpFlightSqlProducer { + private static final SessionOptionValue invalidOptionValue = + SessionOptionValueFactory.makeSessionOptionValue("lol_invalid"); + private final FlightServerMiddleware.Key sessionMiddlewareKey; + + SessionOptionsProducer(FlightServerMiddleware.Key sessionMiddlewareKey) { + this.sessionMiddlewareKey = sessionMiddlewareKey; + } + + @Override + public void setSessionOptions(SetSessionOptionsRequest request, CallContext context, + StreamListener listener) { + Map errors = new HashMap(); + + ServerSessionMiddleware middleware = context.getMiddleware(sessionMiddlewareKey); + ServerSessionMiddleware.Session session = middleware.getSession(); + for (Map.Entry entry : request.getSessionOptions().entrySet()) { + // Blacklisted option name + if (entry.getKey().equals("lol_invalid")) { + errors.put(entry.getKey(), + new SetSessionOptionsResult.Error(SetSessionOptionsResult.ErrorValue.INVALID_NAME)); + continue; + } + // Blacklisted option value + // Recommend using a visitor to check polymorphic equality, but this check is easy + if (entry.getValue().equals(invalidOptionValue)) { + errors.put(entry.getKey(), + new SetSessionOptionsResult.Error(SetSessionOptionsResult.ErrorValue.INVALID_VALUE)); + continue; + } + // Business as usual: + if (entry.getValue().isEmpty()) { + session.eraseSessionOption(entry.getKey()); + continue; + } + session.setSessionOption(entry.getKey(), entry.getValue()); + } + listener.onNext(new SetSessionOptionsResult(errors)); + listener.onCompleted(); + } + + @Override + public void getSessionOptions(GetSessionOptionsRequest request, CallContext context, + StreamListener listener) { + ServerSessionMiddleware middleware = context.getMiddleware(sessionMiddlewareKey); + if (!middleware.hasSession()) { + // Attempt to get options without an existing session + listener.onError(CallStatus.NOT_FOUND.withDescription("No current server session").toRuntimeException()); + return; + } + final Map sessionOptions = middleware.getSession().getSessionOptions(); + listener.onNext(new GetSessionOptionsResult(sessionOptions)); + listener.onCompleted(); + } + + @Override + public void closeSession(CloseSessionRequest request, CallContext context, + StreamListener listener) { + ServerSessionMiddleware middleware = context.getMiddleware(sessionMiddlewareKey); + try { + middleware.closeSession(); + } catch (FlightRuntimeException fre) { + listener.onError(fre); + return; + } + listener.onNext(new CloseSessionResult(CloseSessionResult.Status.CLOSED)); + listener.onCompleted(); + } +} diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/SessionOptionsScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/SessionOptionsScenario.java new file mode 100644 index 0000000000000..c150cfa6ef137 --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/SessionOptionsScenario.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.integration.tests; + +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightServerMiddleware; +import org.apache.arrow.flight.GetSessionOptionsRequest; +import org.apache.arrow.flight.GetSessionOptionsResult; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.ServerSessionMiddleware; +import org.apache.arrow.flight.SessionOptionValue; +import org.apache.arrow.flight.SessionOptionValueFactory; +import org.apache.arrow.flight.SetSessionOptionsRequest; +import org.apache.arrow.flight.SetSessionOptionsResult; +import org.apache.arrow.flight.client.ClientCookieMiddleware; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.memory.BufferAllocator; + +import com.google.common.collect.ImmutableMap; + +/** + * Scenario to exercise Session Options functionality. + */ +final class SessionOptionsScenario implements Scenario { + private final FlightServerMiddleware.Key key = + FlightServerMiddleware.Key.of("sessionmiddleware"); + + @Override + public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception { + return new SessionOptionsProducer(key); + } + + @Override + public void buildServer(FlightServer.Builder builder) { + AtomicInteger counter = new AtomicInteger(1000); + builder.middleware(key, new ServerSessionMiddleware.Factory(() -> String.valueOf(counter.getAndIncrement()))); + } + + @Override + public void client(BufferAllocator allocator, Location location, FlightClient ignored) throws Exception { + final ClientCookieMiddleware.Factory factory = new ClientCookieMiddleware.Factory(); + try (final FlightClient flightClient = FlightClient.builder(allocator, location).intercept(factory).build()) { + final FlightSqlClient client = new FlightSqlClient(flightClient); + + // Set + SetSessionOptionsRequest req1 = new SetSessionOptionsRequest(ImmutableMap.builder() + .put("foolong", SessionOptionValueFactory.makeSessionOptionValue(123L)) + .put("bardouble", SessionOptionValueFactory.makeSessionOptionValue(456.0)) + .put("lol_invalid", SessionOptionValueFactory.makeSessionOptionValue("this won't get set")) + .put("key_with_invalid_value", SessionOptionValueFactory.makeSessionOptionValue("lol_invalid")) + .put("big_ol_string_list", SessionOptionValueFactory.makeSessionOptionValue( + new String[]{"a", "b", "sea", "dee", " ", " ", "geee", "(づ。◕‿‿◕。)づ"})) + .build()); + SetSessionOptionsResult res1 = client.setSessionOptions(req1); + // Some errors + IntegrationAssertions.assertEquals(ImmutableMap.builder() + .put("lol_invalid", new SetSessionOptionsResult.Error(SetSessionOptionsResult.ErrorValue.INVALID_NAME)) + .put("key_with_invalid_value", new SetSessionOptionsResult.Error( + SetSessionOptionsResult.ErrorValue.INVALID_VALUE)) + .build(), + res1.getErrors()); + // Some set, some omitted due to above errors + GetSessionOptionsResult res2 = client.getSessionOptions(new GetSessionOptionsRequest()); + IntegrationAssertions.assertEquals(ImmutableMap.builder() + .put("foolong", SessionOptionValueFactory.makeSessionOptionValue(123L)) + .put("bardouble", SessionOptionValueFactory.makeSessionOptionValue(456.0)) + .put("big_ol_string_list", SessionOptionValueFactory.makeSessionOptionValue( + new String[]{"a", "b", "sea", "dee", " ", " ", "geee", "(づ。◕‿‿◕。)づ"})) + .build(), + res2.getSessionOptions()); + // Update + client.setSessionOptions(new SetSessionOptionsRequest(ImmutableMap.builder() + // Delete + .put("foolong", SessionOptionValueFactory.makeEmptySessionOptionValue()) + // Update + .put("big_ol_string_list", + SessionOptionValueFactory.makeSessionOptionValue("a,b,sea,dee, , ,geee,(づ。◕‿‿◕。)づ")) + .build())); + GetSessionOptionsResult res4 = client.getSessionOptions(new GetSessionOptionsRequest()); + IntegrationAssertions.assertEquals(ImmutableMap.builder() + .put("bardouble", SessionOptionValueFactory.makeSessionOptionValue(456.0)) + .put("big_ol_string_list", + SessionOptionValueFactory.makeSessionOptionValue("a,b,sea,dee, , ,geee,(づ。◕‿‿◕。)づ")) + .build(), + res4.getSessionOptions()); + } + } +} diff --git a/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java index cf65e16fac06f..f814427567ae9 100644 --- a/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java +++ b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java @@ -78,6 +78,16 @@ void flightSqlExtension() throws Exception { testScenario("flight_sql:extension"); } + @Test + void appMetadataFlightInfoEndpoint() throws Exception { + testScenario("app_metadata_flight_info_endpoint"); + } + + @Test + void sessionOptions() throws Exception { + testScenario("session_options"); + } + void testScenario(String scenarioName) throws Exception { try (final BufferAllocator allocator = new RootAllocator()) { final FlightServer.Builder builder = FlightServer.builder() diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CloseSessionResultListener.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CloseSessionResultListener.java new file mode 100644 index 0000000000000..e1a5b369fe16c --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CloseSessionResultListener.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql; + +import org.apache.arrow.flight.CloseSessionResult; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.Result; + +/** Typed StreamListener for closeSession. */ +public class CloseSessionResultListener implements FlightProducer.StreamListener { + private final FlightProducer.StreamListener listener; + + CloseSessionResultListener(FlightProducer.StreamListener listener) { + this.listener = listener; + } + + @Override + public void onNext(CloseSessionResult val) { + listener.onNext(new Result(val.serialize().array())); + } + + @Override + public void onError(Throwable t) { + listener.onError(t); + } + + @Override + public void onCompleted() { + listener.onCompleted(); + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java index e72354513013e..8366c162559af 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java @@ -61,15 +61,21 @@ import org.apache.arrow.flight.CallStatus; import org.apache.arrow.flight.CancelFlightInfoRequest; import org.apache.arrow.flight.CancelFlightInfoResult; +import org.apache.arrow.flight.CloseSessionRequest; +import org.apache.arrow.flight.CloseSessionResult; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightEndpoint; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.GetSessionOptionsRequest; +import org.apache.arrow.flight.GetSessionOptionsResult; import org.apache.arrow.flight.PutResult; import org.apache.arrow.flight.RenewFlightEndpointRequest; import org.apache.arrow.flight.Result; import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.SetSessionOptionsRequest; +import org.apache.arrow.flight.SetSessionOptionsResult; import org.apache.arrow.flight.SyncPutListener; import org.apache.arrow.flight.Ticket; import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult; @@ -917,6 +923,18 @@ public FlightEndpoint renewFlightEndpoint(RenewFlightEndpointRequest request, Ca return client.renewFlightEndpoint(request, options); } + public SetSessionOptionsResult setSessionOptions(SetSessionOptionsRequest request, CallOption... options) { + return client.setSessionOptions(request, options); + } + + public GetSessionOptionsResult getSessionOptions(GetSessionOptionsRequest request, CallOption... options) { + return client.getSessionOptions(request, options); + } + + public CloseSessionResult closeSession(CloseSessionRequest request, CallOption... options) { + return client.closeSession(request, options); + } + @Override public void close() throws Exception { AutoCloseables.close(client); diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java index e2d79129c1fc9..1b3d8c2b487de 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java @@ -56,16 +56,22 @@ import org.apache.arrow.flight.CallStatus; import org.apache.arrow.flight.CancelFlightInfoRequest; import org.apache.arrow.flight.CancelStatus; +import org.apache.arrow.flight.CloseSessionRequest; +import org.apache.arrow.flight.CloseSessionResult; import org.apache.arrow.flight.FlightConstants; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightEndpoint; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightProducer; import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.GetSessionOptionsRequest; +import org.apache.arrow.flight.GetSessionOptionsResult; import org.apache.arrow.flight.PutResult; import org.apache.arrow.flight.RenewFlightEndpointRequest; import org.apache.arrow.flight.Result; import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.SetSessionOptionsRequest; +import org.apache.arrow.flight.SetSessionOptionsResult; import org.apache.arrow.flight.Ticket; import org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest; import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest; @@ -383,6 +389,42 @@ default void doAction(CallContext context, Action action, StreamListener return; } renewFlightEndpoint(request, context, new FlightEndpointListener(listener)); + } else if (actionType.equals(FlightConstants.SET_SESSION_OPTIONS.getType())) { + final SetSessionOptionsRequest request; + try { + request = SetSessionOptionsRequest.deserialize(ByteBuffer.wrap(action.getBody())); + } catch (IOException e) { + listener.onError(CallStatus.INTERNAL + .withDescription("Could not unpack SetSessionOptionsRequest: " + e) + .withCause(e) + .toRuntimeException()); + return; + } + setSessionOptions(request, context, new SetSessionOptionsResultListener(listener)); + } else if (actionType.equals(FlightConstants.GET_SESSION_OPTIONS.getType())) { + final GetSessionOptionsRequest request; + try { + request = GetSessionOptionsRequest.deserialize(ByteBuffer.wrap(action.getBody())); + } catch (IOException e) { + listener.onError(CallStatus.INTERNAL + .withDescription("Could not unpack GetSessionOptionsRequest: " + e) + .withCause(e) + .toRuntimeException()); + return; + } + getSessionOptions(request, context, new GetSessionOptionsResultListener(listener)); + } else if (actionType.equals(FlightConstants.CLOSE_SESSION.getType())) { + final CloseSessionRequest request; + try { + request = CloseSessionRequest.deserialize(ByteBuffer.wrap(action.getBody())); + } catch (IOException e) { + listener.onError(CallStatus.INTERNAL + .withDescription("Could not unpack CloseSessionRequest: " + e) + .withCause(e) + .toRuntimeException()); + return; + } + closeSession(request, context, new CloseSessionResultListener(listener)); } else { throw CallStatus.INVALID_ARGUMENT .withDescription("Unrecognized request: " + action.getType()) @@ -472,6 +514,43 @@ public void onCompleted() { }); } + /** + * Set server session options(s). + * + * @param request The session options to set. For *DBC driver compatibility, servers + * should support converting values from strings. + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + default void setSessionOptions(SetSessionOptionsRequest request, CallContext context, + StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + + /** + * Get server session option(s). + * + * @param request The (empty) GetSessionOptionsRequest. + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + default void getSessionOptions(GetSessionOptionsRequest request, CallContext context, + StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + + /** + * Close/invalidate the session. + * + * @param request The (empty) CloseSessionRequest. + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + default void closeSession(CloseSessionRequest request, CallContext context, + StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + /** * Creates a prepared statement on the server and returns a handle and metadata for in a * {@link ActionCreatePreparedStatementResult} object in a {@link Result} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/GetSessionOptionsResultListener.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/GetSessionOptionsResultListener.java new file mode 100644 index 0000000000000..4fdffd076243c --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/GetSessionOptionsResultListener.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql; + +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.GetSessionOptionsResult; +import org.apache.arrow.flight.Result; + +/** Typed StreamListener for getSessionOptions. */ +public class GetSessionOptionsResultListener implements FlightProducer.StreamListener { + private final FlightProducer.StreamListener listener; + + GetSessionOptionsResultListener(FlightProducer.StreamListener listener) { + this.listener = listener; + } + + @Override + public void onNext(GetSessionOptionsResult val) { + listener.onNext(new Result(val.serialize().array())); + } + + @Override + public void onError(Throwable t) { + listener.onError(t); + } + + @Override + public void onCompleted() { + listener.onCompleted(); + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SetSessionOptionsResultListener.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SetSessionOptionsResultListener.java new file mode 100644 index 0000000000000..230be2bf1b316 --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SetSessionOptionsResultListener.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql; + +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SetSessionOptionsResult; + +/** Typed StreamListener for setSessionOptions. */ +public class SetSessionOptionsResultListener implements FlightProducer.StreamListener { + private final FlightProducer.StreamListener listener; + + SetSessionOptionsResultListener(FlightProducer.StreamListener listener) { + this.listener = listener; + } + + @Override + public void onNext(SetSessionOptionsResult val) { + listener.onNext(new Result(val.serialize().array())); + } + + @Override + public void onError(Throwable t) { + listener.onError(t); + } + + @Override + public void onCompleted() { + listener.onCompleted(); + } +} diff --git a/testing b/testing index 47f7b56b25683..ad82a736c170e 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit 47f7b56b25683202c1fd957668e13f2abafc0f12 +Subproject commit ad82a736c170e97b7c8c035ebd8a801c17eec170