Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-40155: [Go][FlightRPC][FlightSQL] Implement Session Management #40284

Merged
merged 9 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/src/arrow/flight/sql/server_session_middleware.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class ServerSessionMiddlewareImpl : public ServerSessionMiddleware {

Status CloseSession() override {
const std::lock_guard<std::shared_mutex> l(mutex_);
if (static_cast<bool>(session_)) {
if (!static_cast<bool>(session_)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was this a test that was failing that this fixes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes the C++ server scenario was failing when I added a test to the Go client scenario that attempts to close the currently open session. I'm not especially familiar with C++, but my understanding is that this was previously saying that "if the session DOES exist, then it is an error to close it" which seems backwards. Is my understanding accurate?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joellubi is correct here and the fix also checks out. Looks like from my commit history that I typoed this while hastily updating the code to work around and omit functionality affected by #40071. This used to be covered in integration testing but was temporarily removed as the aforementioned issue breaks the ability (in C++) to correctly invalidate sessions (on the server side) when the CloseSession Action is called.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Notably if you continue on with the same client cookie middleware in any integration test after closing the session from the client, calls are going to get failed out by the C++ server session middleware because the client cookie middleware will submit a now-invalid session token (that the server was unable to invalidate) with the request headers. This is fixable pending #40071.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joellubi Can you add a corresponding fix to the client middleware that matches #40071 which @indigophox mentioned? Or at minimum, file a follow-up issue for this so we don't forget it

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zeroshade I think that in the case @indigophox mentioned the client middleware is actually behaving correctly. The problem with the C++ server middleware is that it's unable to invalidate the cookie for an existing session (i.e. it never sends Set-Cookie: arrow_flight_session_id=<session-id>; Max-Age=0) so the client keeps the cookie stored, as it should. Then the next time the client makes any request it sends the cookie (i.e. Cookie: arrow_flight_session_id=<session-id>) along with it, but the server has completely forgotten about that session and considers the token invalid.

Copy link
Member Author

@joellubi joellubi Mar 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@indigophox As I was writing this, the following idea occurred to me:

What if, in the meantime while the C++ issue is getting resolved, we changed the C++ implementation to respond with CloseSessionResult_NOT_CLOSEABLE to all CloseSessionRequest's? Admittedly that's not exactly what I imagined that result status would end up being used for, but it's kind of accurate right now and most importantly it gives some information to clients to be able recover the edge case gracefully.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joellubi ultimately this is scoped to the implementation, so anyone building an app against this can do as they please. The shipped ServerSessionMiddleware just handles the internals; the actual RPC handler can do whatever the app author wants. I am hoping we can close on #40071 fairly soon (before 16.0.0 code freeze?) so the full functionality can be restored. I'll resume the discussion over there and hopefully that can get wrapped up with the existing solution or another one.

My other thought for an if-we-have-to-go-down-that-road workaround is for the middleware to internally flag the token as invalidated, so it can expire it next time it's presented instead of outright treating it as invalid, or alternatively (say you're calling SetSessionOptions) just clobber it with a new token for a new session as appropriate. But again this is dumping work into a temporary hack while we sort out the root issue with middleware handling behaviour.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a good discussion potentially but I agree that this is an issue scoped to the implementation, which doesn't seem to be an issue for us on the Go side currently. I'm going to merge this as it looks pretty well covered and handled to me here but feel free to continue this discussion even with this PR merged if desired.

return Status::Invalid("Nonexistent session cannot be closed.");
}
ARROW_RETURN_NOT_OK(factory_->CloseSession(session_id_));
Expand Down
2 changes: 1 addition & 1 deletion dev/archery/archery/integration/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True,
Scenario(
"session_options",
description="Ensure Flight SQL Sessions work as expected.",
skip_testers={"JS", "C#", "Rust", "Go"}
skip_testers={"JS", "C#", "Rust"}
),
Scenario(
"poll_flight_info",
Expand Down
90 changes: 60 additions & 30 deletions go/arrow/flight/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ type Client interface {
// in order to use the Handshake endpoints of the service.
Authenticate(context.Context, ...grpc.CallOption) error
AuthenticateBasicToken(ctx context.Context, username string, password string, opts ...grpc.CallOption) (context.Context, error)
CancelFlightInfo(ctx context.Context, request *CancelFlightInfoRequest, opts ...grpc.CallOption) (CancelFlightInfoResult, error)
CancelFlightInfo(ctx context.Context, request *CancelFlightInfoRequest, opts ...grpc.CallOption) (*CancelFlightInfoResult, error)
Close() error
RenewFlightEndpoint(ctx context.Context, request *RenewFlightEndpointRequest, opts ...grpc.CallOption) (*FlightEndpoint, error)
SetSessionOptions(ctx context.Context, request *SetSessionOptionsRequest, opts ...grpc.CallOption) (*SetSessionOptionsResult, error)
GetSessionOptions(ctx context.Context, request *GetSessionOptionsRequest, opts ...grpc.CallOption) (*GetSessionOptionsResult, error)
CloseSession(ctx context.Context, request *CloseSessionRequest, opts ...grpc.CallOption) (*CloseSessionResult, error)
// join the interface from the FlightServiceClient instead of re-defining all
// the endpoints here.
FlightServiceClient
Expand Down Expand Up @@ -364,26 +367,14 @@ func ReadUntilEOF(stream FlightService_DoActionClient) error {
}
}

func (c *client) CancelFlightInfo(ctx context.Context, request *CancelFlightInfoRequest, opts ...grpc.CallOption) (result CancelFlightInfoResult, err error) {
var action flight.Action
action.Type = CancelFlightInfoActionType
action.Body, err = proto.Marshal(request)
if err != nil {
return
}
stream, err := c.DoAction(ctx, &action, opts...)
if err != nil {
return
}
res, err := stream.Recv()
func (c *client) CancelFlightInfo(ctx context.Context, request *CancelFlightInfoRequest, opts ...grpc.CallOption) (*CancelFlightInfoResult, error) {
var result CancelFlightInfoResult
err := handleAction(ctx, c, CancelFlightInfoActionType, request, &result, opts...)
if err != nil {
return
}
if err = proto.Unmarshal(res.Body, &result); err != nil {
return
return nil, err
}
err = ReadUntilEOF(stream)
return

return &result, err
}

func (c *client) Close() error {
Expand All @@ -395,29 +386,68 @@ func (c *client) Close() error {
}

func (c *client) RenewFlightEndpoint(ctx context.Context, request *RenewFlightEndpointRequest, opts ...grpc.CallOption) (*FlightEndpoint, error) {
var err error
var action flight.Action
action.Type = RenewFlightEndpointActionType
action.Body, err = proto.Marshal(request)
var result FlightEndpoint
err := handleAction(ctx, c, RenewFlightEndpointActionType, request, &result, opts...)
if err != nil {
return nil, err
}
stream, err := c.DoAction(ctx, &action, opts...)

return &result, err
}

func (c *client) SetSessionOptions(ctx context.Context, request *SetSessionOptionsRequest, opts ...grpc.CallOption) (*SetSessionOptionsResult, error) {
var result SetSessionOptionsResult
err := handleAction(ctx, c, SetSessionOptionsActionType, request, &result, opts...)
if err != nil {
return nil, err
}
res, err := stream.Recv()

return &result, err
}

func (c *client) GetSessionOptions(ctx context.Context, request *GetSessionOptionsRequest, opts ...grpc.CallOption) (*GetSessionOptionsResult, error) {
var result GetSessionOptionsResult
err := handleAction(ctx, c, GetSessionOptionsActionType, request, &result, opts...)
if err != nil {
return nil, err
}
var renewedEndpoint FlightEndpoint
err = proto.Unmarshal(res.Body, &renewedEndpoint)

return &result, err
}

func (c *client) CloseSession(ctx context.Context, request *CloseSessionRequest, opts ...grpc.CallOption) (*CloseSessionResult, error) {
var result CloseSessionResult
err := handleAction(ctx, c, CloseSessionActionType, request, &result, opts...)
if err != nil {
return nil, err
}
err = ReadUntilEOF(stream)

return &result, err
}

func handleAction[T, U proto.Message](ctx context.Context, client FlightServiceClient, name string, request T, response U, opts ...grpc.CallOption) error {
var (
action flight.Action
err error
)

action.Type = name
action.Body, err = proto.Marshal(request)
if err != nil {
return nil, err
return err
}
return &renewedEndpoint, nil
stream, err := client.DoAction(ctx, &action, opts...)
if err != nil {
return err
}
res, err := stream.Recv()
if err != nil {
return err
}
err = proto.Unmarshal(res.Body, response)
if err != nil {
return err
}

return ReadUntilEOF(stream)
}
14 changes: 13 additions & 1 deletion go/arrow/flight/flightsql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -584,14 +584,26 @@ func (c *Client) CancelQuery(ctx context.Context, info *flight.FlightInfo, opts
return
}

func (c *Client) CancelFlightInfo(ctx context.Context, request *flight.CancelFlightInfoRequest, opts ...grpc.CallOption) (flight.CancelFlightInfoResult, error) {
func (c *Client) CancelFlightInfo(ctx context.Context, request *flight.CancelFlightInfoRequest, opts ...grpc.CallOption) (*flight.CancelFlightInfoResult, error) {
return c.Client.CancelFlightInfo(ctx, request, opts...)
}

func (c *Client) RenewFlightEndpoint(ctx context.Context, request *flight.RenewFlightEndpointRequest, opts ...grpc.CallOption) (*flight.FlightEndpoint, error) {
return c.Client.RenewFlightEndpoint(ctx, request, opts...)
}

func (c *Client) SetSessionOptions(ctx context.Context, request *flight.SetSessionOptionsRequest, opts ...grpc.CallOption) (*flight.SetSessionOptionsResult, error) {
return c.Client.SetSessionOptions(ctx, request, opts...)
}

func (c *Client) GetSessionOptions(ctx context.Context, request *flight.GetSessionOptionsRequest, opts ...grpc.CallOption) (*flight.GetSessionOptionsResult, error) {
return c.Client.GetSessionOptions(ctx, request, opts...)
}

func (c *Client) CloseSession(ctx context.Context, request *flight.CloseSessionRequest, opts ...grpc.CallOption) (*flight.CloseSessionResult, error) {
return c.Client.CloseSession(ctx, request, opts...)
}

func (c *Client) BeginTransaction(ctx context.Context, opts ...grpc.CallOption) (*Txn, error) {
request := &pb.ActionBeginTransactionRequest{}
action, err := packAction(BeginTransactionActionType, request)
Expand Down
25 changes: 20 additions & 5 deletions go/arrow/flight/flightsql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,31 @@ func (m *FlightServiceClientMock) AuthenticateBasicToken(_ context.Context, user
return args.Get(0).(context.Context), args.Error(1)
}

func (m *FlightServiceClientMock) CancelFlightInfo(ctx context.Context, request *flight.CancelFlightInfoRequest, opts ...grpc.CallOption) (flight.CancelFlightInfoResult, error) {
func (m *FlightServiceClientMock) CancelFlightInfo(ctx context.Context, request *flight.CancelFlightInfoRequest, opts ...grpc.CallOption) (*flight.CancelFlightInfoResult, error) {
args := m.Called(request, opts)
return args.Get(0).(flight.CancelFlightInfoResult), args.Error(1)
return args.Get(0).(*flight.CancelFlightInfoResult), args.Error(1)
}

func (m *FlightServiceClientMock) RenewFlightEndpoint(ctx context.Context, request *flight.RenewFlightEndpointRequest, opts ...grpc.CallOption) (*flight.FlightEndpoint, error) {
args := m.Called(request, opts)
return args.Get(0).(*flight.FlightEndpoint), args.Error(1)
}

func (m *FlightServiceClientMock) SetSessionOptions(ctx context.Context, request *flight.SetSessionOptionsRequest, opts ...grpc.CallOption) (*flight.SetSessionOptionsResult, error) {
args := m.Called(request, opts)
return args.Get(0).(*flight.SetSessionOptionsResult), args.Error(1)
}

func (m *FlightServiceClientMock) GetSessionOptions(ctx context.Context, request *flight.GetSessionOptionsRequest, opts ...grpc.CallOption) (*flight.GetSessionOptionsResult, error) {
args := m.Called(request, opts)
return args.Get(0).(*flight.GetSessionOptionsResult), args.Error(1)
}

func (m *FlightServiceClientMock) CloseSession(ctx context.Context, request *flight.CloseSessionRequest, opts ...grpc.CallOption) (*flight.CloseSessionResult, error) {
args := m.Called(request, opts)
return args.Get(0).(*flight.CloseSessionResult), args.Error(1)
}

func (m *FlightServiceClientMock) Close() error {
return m.Called().Error(0)
}
Expand Down Expand Up @@ -639,10 +654,10 @@ func (s *FlightSqlClientSuite) TestCancelFlightInfo() {
mockedCancelResult := flight.CancelFlightInfoResult{
Status: flight.CancelStatusCancelled,
}
s.mockClient.On("CancelFlightInfo", &request, s.callOpts).Return(mockedCancelResult, nil)
s.mockClient.On("CancelFlightInfo", &request, s.callOpts).Return(&mockedCancelResult, nil)
cancelResult, err := s.sqlClient.CancelFlightInfo(context.TODO(), &request, s.callOpts...)
s.NoError(err)
s.Equal(mockedCancelResult, cancelResult)
s.Equal(&mockedCancelResult, cancelResult)
}

func (s *FlightSqlClientSuite) TestRenewFlightEndpoint() {
Expand Down Expand Up @@ -671,7 +686,7 @@ func (s *FlightSqlClientSuite) TestPreparedStatementLoadFromResult() {
result := &pb.ActionCreatePreparedStatementResult{
PreparedStatementHandle: []byte(query),
}

parameterSchemaResult := arrow.NewSchema([]arrow.Field{{Name: "p_id", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil)
result.ParameterSchema = flight.SerializeSchema(parameterSchemaResult, memory.DefaultAllocator)
datasetSchemaResult := arrow.NewSchema([]arrow.Field{{Name: "ds_id", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil)
Expand Down
81 changes: 81 additions & 0 deletions go/arrow/flight/flightsql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,18 @@ func (BaseServer) EndSavepoint(context.Context, ActionEndSavepointRequest) error
return status.Error(codes.Unimplemented, "EndSavepoint not implemented")
}

func (BaseServer) SetSessionOptions(context.Context, *flight.SetSessionOptionsRequest) (*flight.SetSessionOptionsResult, error) {
return nil, status.Error(codes.Unimplemented, "SetSessionOptions not implemented")
}

func (BaseServer) GetSessionOptions(context.Context, *flight.GetSessionOptionsRequest) (*flight.GetSessionOptionsResult, error) {
return nil, status.Error(codes.Unimplemented, "GetSessionOptions not implemented")
}

func (BaseServer) CloseSession(context.Context, *flight.CloseSessionRequest) (*flight.CloseSessionResult, error) {
return nil, status.Error(codes.Unimplemented, "CloseSession not implemented")
}

// Server is the required interface for a FlightSQL server. It is implemented by
// BaseServer which must be embedded in any implementation. The default
// implementation by BaseServer for each of these (except GetSqlInfo)
Expand Down Expand Up @@ -676,6 +688,12 @@ type Server interface {
PollFlightInfoSubstraitPlan(context.Context, StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.PollInfo, error)
// PollFlightInfoPreparedStatement handles polling for query execution.
PollFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error)
// SetSessionOptions sets option(s) for the current server session.
SetSessionOptions(context.Context, *flight.SetSessionOptionsRequest) (*flight.SetSessionOptionsResult, error)
// GetSessionOptions gets option(s) for the current server session.
GetSessionOptions(context.Context, *flight.GetSessionOptionsRequest) (*flight.GetSessionOptionsResult, error)
// CloseSession closes/invalidates the current server session.
CloseSession(context.Context, *flight.CloseSessionRequest) (*flight.CloseSessionResult, error)

mustEmbedBaseServer()
}
Expand Down Expand Up @@ -1262,6 +1280,69 @@ func (f *flightSqlServer) DoAction(cmd *flight.Action, stream flight.FlightServi
}

return stream.Send(&pb.Result{})
case flight.SetSessionOptionsActionType:
var (
request flight.SetSessionOptionsRequest
err error
)

if err = proto.Unmarshal(cmd.Body, &request); err != nil {
return status.Errorf(codes.InvalidArgument, "unable to unmarshal SetSessionOptionsRequest: %s", err.Error())
}

response, err := f.srv.SetSessionOptions(stream.Context(), &request)
if err != nil {
return err
}

out := &pb.Result{}
out.Body, err = proto.Marshal(response)
if err != nil {
return err
}
return stream.Send(out)
case flight.GetSessionOptionsActionType:
var (
request flight.GetSessionOptionsRequest
err error
)

if err = proto.Unmarshal(cmd.Body, &request); err != nil {
return status.Errorf(codes.InvalidArgument, "unable to unmarshal GetSessionOptionsRequest: %s", err.Error())
}

response, err := f.srv.GetSessionOptions(stream.Context(), &request)
if err != nil {
return err
}

out := &pb.Result{}
out.Body, err = proto.Marshal(response)
if err != nil {
return err
}
return stream.Send(out)
case flight.CloseSessionActionType:
var (
request flight.CloseSessionRequest
err error
)

if err = proto.Unmarshal(cmd.Body, &request); err != nil {
return status.Errorf(codes.InvalidArgument, "unable to unmarshal CloseSessionRequest: %s", err.Error())
}

response, err := f.srv.CloseSession(stream.Context(), &request)
if err != nil {
return err
}

out := &pb.Result{}
out.Body, err = proto.Marshal(response)
if err != nil {
return err
}
return stream.Send(out)
default:
return status.Error(codes.InvalidArgument, "the defined request is invalid.")
}
Expand Down
Loading
Loading