Skip to content

Commit

Permalink
GH-37720: [Go][FlightSQL] Implement stateless prepared statements
Browse files Browse the repository at this point in the history
  • Loading branch information
erratic-pattern committed Mar 1, 2024
1 parent 0dbbd43 commit 76533a9
Show file tree
Hide file tree
Showing 7 changed files with 895 additions and 928 deletions.
10 changes: 5 additions & 5 deletions go/arrow/flight/flightsql/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -844,16 +844,16 @@ func (s *MockServer) CreatePreparedStatement(ctx context.Context, req flightsql.
}, nil
}

func (s *MockServer) DoPutPreparedStatementQuery(ctx context.Context, qry flightsql.PreparedStatementQuery, r flight.MessageReader, w flight.MetadataWriter) error {
func (s *MockServer) DoPutPreparedStatementQuery(ctx context.Context, qry flightsql.PreparedStatementQuery, r flight.MessageReader, w flight.MetadataWriter) ([]byte, error) {
if s.ExpectedPreparedStatementSchema != nil {
if !s.ExpectedPreparedStatementSchema.Equal(r.Schema()) {
return errors.New("parameter schema: unexpected")
return nil, errors.New("parameter schema: unexpected")
}
return nil
return qry.GetPreparedStatementHandle(), nil
}

if s.PreparedStatementParameterSchema != nil && !s.PreparedStatementParameterSchema.Equal(r.Schema()) {
return fmt.Errorf("parameter schema: %w", arrow.ErrInvalid)
return nil, fmt.Errorf("parameter schema: %w", arrow.ErrInvalid)
}

// GH-35328: it's rare, but this function can complete execution and return
Expand All @@ -867,7 +867,7 @@ func (s *MockServer) DoPutPreparedStatementQuery(ctx context.Context, qry flight
for r.Next() {
}

return nil
return qry.GetPreparedStatementHandle(), nil
}

func (s *MockServer) DoGetStatement(ctx context.Context, ticket flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
Expand Down
8 changes: 4 additions & 4 deletions go/arrow/flight/flightsql/example/sqlite_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -618,21 +618,21 @@ func getParamsForStatement(rdr flight.MessageReader) (params [][]interface{}, er
return params, rdr.Err()
}

func (s *SQLiteFlightSQLServer) DoPutPreparedStatementQuery(_ context.Context, cmd flightsql.PreparedStatementQuery, rdr flight.MessageReader, _ flight.MetadataWriter) error {
func (s *SQLiteFlightSQLServer) DoPutPreparedStatementQuery(_ context.Context, cmd flightsql.PreparedStatementQuery, rdr flight.MessageReader, _ flight.MetadataWriter) ([]byte, error) {
val, ok := s.prepared.Load(string(cmd.GetPreparedStatementHandle()))
if !ok {
return status.Error(codes.InvalidArgument, "prepared statement not found")
return nil, status.Error(codes.InvalidArgument, "prepared statement not found")
}

stmt := val.(Statement)
args, err := getParamsForStatement(rdr)
if err != nil {
return status.Errorf(codes.Internal, "error gathering parameters for prepared statement query: %s", err.Error())
return nil, status.Errorf(codes.Internal, "error gathering parameters for prepared statement query: %s", err.Error())
}

stmt.params = args
s.prepared.Store(string(cmd.GetPreparedStatementHandle()), stmt)
return nil
return cmd.GetPreparedStatementHandle(), nil
}

func (s *SQLiteFlightSQLServer) DoPutPreparedStatementUpdate(ctx context.Context, cmd flightsql.PreparedStatementUpdate, rdr flight.MessageReader) (int64, error) {
Expand Down
17 changes: 13 additions & 4 deletions go/arrow/flight/flightsql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,8 +499,8 @@ func (BaseServer) DoPutCommandSubstraitPlan(context.Context, StatementSubstraitP
return 0, status.Error(codes.Unimplemented, "DoPutCommandSubstraitPlan not implemented")
}

func (BaseServer) DoPutPreparedStatementQuery(context.Context, PreparedStatementQuery, flight.MessageReader, flight.MetadataWriter) error {
return status.Error(codes.Unimplemented, "DoPutPreparedStatementQuery not implemented")
func (BaseServer) DoPutPreparedStatementQuery(context.Context, PreparedStatementQuery, flight.MessageReader, flight.MetadataWriter) ([]byte, error) {
return nil, status.Error(codes.Unimplemented, "DoPutPreparedStatementQuery not implemented")
}

func (BaseServer) DoPutPreparedStatementUpdate(context.Context, PreparedStatementUpdate, flight.MessageReader) (int64, error) {
Expand Down Expand Up @@ -650,7 +650,7 @@ type Server interface {
// Currently anything written to the writer will be ignored. It is in the
// interface for potential future enhancements to avoid having to change
// the interface in the future.
DoPutPreparedStatementQuery(context.Context, PreparedStatementQuery, flight.MessageReader, flight.MetadataWriter) error
DoPutPreparedStatementQuery(context.Context, PreparedStatementQuery, flight.MessageReader, flight.MetadataWriter) ([]byte, error)
// DoPutPreparedStatementUpdate executes an update SQL Prepared statement
// for the specified statement handle. The reader allows providing a sequence
// of uploaded record batches to bind the parameters to. Returns the number
Expand Down Expand Up @@ -954,7 +954,16 @@ func (f *flightSqlServer) DoPut(stream flight.FlightService_DoPutServer) error {
}
return stream.Send(out)
case *pb.CommandPreparedStatementQuery:
return f.srv.DoPutPreparedStatementQuery(stream.Context(), cmd, rdr, &putMetadataWriter{stream})
handle, err := f.srv.DoPutPreparedStatementQuery(stream.Context(), cmd, rdr, &putMetadataWriter{stream})
if err != nil {
return err
}
result := pb.DoPutPreparedStatementResult{PreparedStatementHandle: handle}
out := &flight.PutResult{}
if out.AppMetadata, err = proto.Marshal(&result); err != nil {
return status.Errorf(codes.Internal, "failed to marshal PutResult: %s", err.Error())
}
return stream.Send(out)
case *pb.CommandPreparedStatementUpdate:
recordCount, err := f.srv.DoPutPreparedStatementUpdate(stream.Context(), cmd, rdr)
if err != nil {
Expand Down
Loading

0 comments on commit 76533a9

Please sign in to comment.