Skip to content

Commit

Permalink
fix(go/adbc): fix crash on map type
Browse files Browse the repository at this point in the history
Fixes apache#853.
  • Loading branch information
lidavidm committed Jul 5, 2023
1 parent 7bacf87 commit bd5ddd0
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 1 deletion.
103 changes: 103 additions & 0 deletions go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ func TestCookies(t *testing.T) {
suite.Run(t, &CookieTests{})
}

func TestDataType(t *testing.T) {
suite.Run(t, &DataTypeTests{})
}

// ---- AuthN Tests --------------------

type AuthnTestServer struct {
Expand Down Expand Up @@ -524,3 +528,102 @@ func (suite *CookieTests) TestCookieUsage() {
suite.Require().NoError(err)
defer reader.Release()
}

// ---- Data Type Tests --------------------
type DataTypeTestServer struct {
flightsql.BaseServer
}

func (server *DataTypeTestServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
tkt, _ := flightsql.CreateStatementQueryTicket([]byte(cmd.GetQuery()))
info := &flight.FlightInfo{
FlightDescriptor: desc,
Endpoint: []*flight.FlightEndpoint{
{Ticket: &flight.Ticket{Ticket: tkt}},
},
TotalRecords: -1,
TotalBytes: -1,
}

return info, nil
}

var (
SchemaListInt3 = arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.FixedSizeListOf(3, arrow.PrimitiveTypes.Int32), Nullable: true}}, nil)
SchemaListInt = arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.ListOf(arrow.PrimitiveTypes.Int32), Nullable: true}}, nil)
SchemaLargeListInt = arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.LargeListOf(arrow.PrimitiveTypes.Int32), Nullable: true}}, nil)
SchemaMapIntInt = arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.MapOf(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int32), Nullable: true}}, nil)
)

func (server *DataTypeTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
var schema *arrow.Schema
var record arrow.Record
var err error

cmd := string(tkt.GetStatementHandle())
switch cmd {
case "list[int, 3]":
schema = SchemaListInt3
record, _, err = array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": [1, 2, 3]}]`))
case "list[int]":
schema = SchemaListInt
record, _, err = array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": [1]}]`))
case "large_list[int]":
schema = SchemaLargeListInt
record, _, err = array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": [1]}]`))
case "map[int]int":
schema = SchemaMapIntInt
record, _, err = array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": null}]`))
default:
return nil, nil, fmt.Errorf("Unknown command: '%s'", cmd)
}

if err != nil {
return nil, nil, err
}

ch := make(chan flight.StreamChunk)
go func() {
defer close(ch)
ch <- flight.StreamChunk{
Data: record,
}
}()
return schema, ch, nil
}

type DataTypeTests struct {
ServerBasedTests
}

func (suite *DataTypeTests) SetupSuite() {
suite.DoSetupSuite(&DataTypeTestServer{}, nil, map[string]string{})
}

func (suite *DataTypeTests) DoTestCase(name string, schema *arrow.Schema) {
stmt, err := suite.cnxn.NewStatement()
suite.NoError(err)
defer stmt.Close()

suite.NoError(stmt.SetSqlQuery(name))
reader, _, err := stmt.ExecuteQuery(context.Background())
suite.NoError(err)
suite.Equal(reader.Schema(), schema)
defer reader.Release()
}

func (suite *DataTypeTests) TestListInt3() {
suite.DoTestCase("list[int, 3]", SchemaListInt3)
}

func (suite *DataTypeTests) TestLargeListInt() {
suite.DoTestCase("large_list[int]", SchemaLargeListInt)
}

func (suite *DataTypeTests) TestListInt() {
suite.DoTestCase("list[int]", SchemaListInt)
}

func (suite *DataTypeTests) TestMapIntInt() {
suite.DoTestCase("map[int]int", SchemaMapIntInt)
}
10 changes: 9 additions & 1 deletion go/adbc/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,15 @@ func removeFieldMetadata(field *arrow.Field) arrow.Field {
case *arrow.LargeListType:
fieldType = arrow.LargeListOfField(childFields[0])
case *arrow.MapType:
mapType := arrow.MapOf(childFields[0].Type, childFields[1].Type)
// XXX: arrow-go doesn't let us build a map type from fields (so
// nonstandard field names or nullability will be lost here)

// child must be struct
structType := ty.Elem().(*arrow.StructType)
// struct must have two children
keyType := structType.Field(0).Type
itemType := structType.Field(1).Type
mapType := arrow.MapOf(keyType, itemType)
mapType.KeysSorted = ty.KeysSorted
fieldType = mapType
case *arrow.SparseUnionType:
Expand Down

0 comments on commit bd5ddd0

Please sign in to comment.