Skip to content

Commit

Permalink
update flightsql tests
Browse files Browse the repository at this point in the history
  • Loading branch information
erratic-pattern committed Apr 13, 2024
1 parent 8657814 commit 64088a2
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions go/arrow/flight/flightsql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,24 +408,26 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecute() {

func (s *FlightSqlClientSuite) TestPreparedStatementExecuteParamBinding() {
const query = "query"
const handle = "handle"
const updatedHandle = "updated handle"

// create and close actions
cmd := &pb.ActionCreatePreparedStatementRequest{Query: query}
action := getAction(cmd)
action.Type = flightsql.CreatePreparedStatementActionType
closeAct := getAction(&pb.ActionClosePreparedStatementRequest{PreparedStatementHandle: []byte(query)})
closeAct := getAction(&pb.ActionClosePreparedStatementRequest{PreparedStatementHandle: []byte(updatedHandle)})
closeAct.Type = flightsql.ClosePreparedStatementActionType

// results from createprepared statement
result := &pb.ActionCreatePreparedStatementResult{
PreparedStatementHandle: []byte(query),
actionResult := &pb.ActionCreatePreparedStatementResult{
PreparedStatementHandle: []byte(handle),
}
schema := arrow.NewSchema([]arrow.Field{{Name: "id", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil)
result.ParameterSchema = flight.SerializeSchema(schema, memory.DefaultAllocator)
actionResult.ParameterSchema = flight.SerializeSchema(schema, memory.DefaultAllocator)

// mocked client stream
var out anypb.Any
out.MarshalFrom(result)
out.MarshalFrom(actionResult)
data, _ := proto.Marshal(&out)

createRsp := &mockDoActionClient{}
Expand All @@ -443,7 +445,12 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteParamBinding() {
s.mockClient.On("DoAction", flightsql.CreatePreparedStatementActionType, action.Body, s.callOpts).Return(createRsp, nil)
s.mockClient.On("DoAction", flightsql.ClosePreparedStatementActionType, closeAct.Body, s.callOpts).Return(closeRsp, nil)

expectedDesc := getDesc(&pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)})
expectedDesc := getDesc(&pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(handle)})

// mocked DoPut result
doPutPreparedStatementResult := &pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(updatedHandle)}
resdata, _ := proto.Marshal(doPutPreparedStatementResult)
putResult := &pb.PutResult{ AppMetadata: resdata }

// mocked client stream for DoPut
mockedPut := &mockDoPutClient{}
Expand All @@ -452,29 +459,30 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteParamBinding() {
return proto.Equal(expectedDesc, fd.FlightDescriptor)
})).Return(nil).Twice() // first sends schema message, second sends data
mockedPut.On("CloseSend").Return(nil)
mockedPut.On("Recv").Return((*pb.PutResult)(nil), nil)
mockedPut.On("Recv").Return(putResult, nil)

infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)}
infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(handle)}
desc := getDesc(infoCmd)
s.mockClient.On("GetFlightInfo", desc.Type, desc.Cmd, s.callOpts).Return(&emptyFlightInfo, nil)

prepared, err := s.sqlClient.Prepare(context.TODO(), query, s.callOpts...)
s.NoError(err)
defer prepared.Close(context.TODO(), s.callOpts...)

s.Equal(string(prepared.Handle()), "query")
s.Equal(string(prepared.Handle()), handle)

paramSchema := prepared.ParameterSchema()
rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, paramSchema, strings.NewReader(`[{"id": 1}]`))
s.NoError(err)
defer rec.Release()

s.Equal(string(prepared.Handle()), "query")
s.Equal(string(prepared.Handle()), handle)

prepared.SetParameters(rec)
info, err := prepared.Execute(context.TODO(), s.callOpts...)
s.NoError(err)
s.Equal(&emptyFlightInfo, info)
s.Equal(string(prepared.Handle()), updatedHandle)
}

func (s *FlightSqlClientSuite) TestPreparedStatementExecuteReaderBinding() {
Expand Down Expand Up @@ -516,6 +524,11 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteReaderBinding() {

expectedDesc := getDesc(&pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)})

// mocked DoPut result
doPutPreparedStatementResult := &pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(query)}
resdata, _ := proto.Marshal(doPutPreparedStatementResult)
putResult := &pb.PutResult{ AppMetadata: resdata }

// mocked client stream for DoPut
mockedPut := &mockDoPutClient{}
s.mockClient.On("DoPut", s.callOpts).Return(mockedPut, nil)
Expand All @@ -528,7 +541,7 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteReaderBinding() {
return fd.FlightDescriptor == nil
})).Return(nil).Times(3)
mockedPut.On("CloseSend").Return(nil)
mockedPut.On("Recv").Return((*pb.PutResult)(nil), nil)
mockedPut.On("Recv").Return(putResult, nil)

infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)}
desc := getDesc(infoCmd)
Expand Down

0 comments on commit 64088a2

Please sign in to comment.