Skip to content

Commit

Permalink
fix(go/adbc): fix crash on map type
Browse files Browse the repository at this point in the history
Fixes apache#853.
  • Loading branch information
lidavidm committed Jun 27, 2023
1 parent a670288 commit 7e55741
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 4 deletions.
105 changes: 105 additions & 0 deletions go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ func TestCookies(t *testing.T) {
suite.Run(t, &CookieTests{})
}

func TestDataType(t *testing.T) {
suite.Run(t, &DataTypeTests{})
}

// ---- AuthN Tests --------------------

type AuthnTestServer struct {
Expand Down Expand Up @@ -522,3 +526,104 @@ func (suite *CookieTests) TestCookieUsage() {
suite.Require().NoError(err)
defer reader.Release()
}

// ---- Data Type Tests --------------------
type DataTypeTestServer struct {
flightsql.BaseServer

cur time.Time
}

func (server *DataTypeTestServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
tkt, _ := flightsql.CreateStatementQueryTicket([]byte(cmd.GetQuery()))
info := &flight.FlightInfo{
FlightDescriptor: desc,
Endpoint: []*flight.FlightEndpoint{
{Ticket: &flight.Ticket{Ticket: tkt}},
},
TotalRecords: -1,
TotalBytes: -1,
}

return info, nil
}

var (
SchemaListInt3 = arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.FixedSizeListOf(3, arrow.PrimitiveTypes.Int32), Nullable: true}}, nil)
SchemaListInt = arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.ListOf(arrow.PrimitiveTypes.Int32), Nullable: true}}, nil)
SchemaLargeListInt = arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.LargeListOf(arrow.PrimitiveTypes.Int32), Nullable: true}}, nil)
SchemaMapIntInt = arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.MapOf(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int32), Nullable: true}}, nil)
)

func (server *DataTypeTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
var schema *arrow.Schema
var record arrow.Record
var err error

cmd := string(tkt.GetStatementHandle())
switch cmd {
case "list[int, 3]":
schema = SchemaListInt3
record, _, err = array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": [1, 2, 3]}]`))
case "list[int]":
schema = SchemaListInt
record, _, err = array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": [1]}]`))
case "large_list[int]":
schema = SchemaLargeListInt
record, _, err = array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": [1]}]`))
case "map[int]int":
schema = SchemaMapIntInt
record, _, err = array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"a": null}]`))
default:
return nil, nil, fmt.Errorf("Unknown command: '%s'", cmd)
}

if err != nil {
return nil, nil, err
}

ch := make(chan flight.StreamChunk)
go func() {
defer close(ch)
ch <- flight.StreamChunk{
Data: record,
}
}()
return schema, ch, nil
}

type DataTypeTests struct {
ServerBasedTests
}

func (suite *DataTypeTests) SetupSuite() {
suite.DoSetupSuite(&DataTypeTestServer{}, nil, map[string]string{})
}

func (suite *DataTypeTests) DoTestCase(name string, schema *arrow.Schema) {
stmt, err := suite.cnxn.NewStatement()
suite.NoError(err)
defer stmt.Close()

suite.NoError(stmt.SetSqlQuery(name))
reader, _, err := stmt.ExecuteQuery(context.Background())
suite.NoError(err)
suite.Equal(reader.Schema(), schema)
defer reader.Release()
}

func (suite *DataTypeTests) TestListInt3() {
suite.DoTestCase("list[int, 3]", SchemaListInt3)
}

func (suite *DataTypeTests) TestLargeListInt() {
suite.DoTestCase("large_list[int]", SchemaLargeListInt)
}

func (suite *DataTypeTests) TestListInt() {
suite.DoTestCase("list[int]", SchemaListInt)
}

func (suite *DataTypeTests) TestMapIntInt() {
suite.DoTestCase("map[int]int", SchemaMapIntInt)
}
14 changes: 10 additions & 4 deletions go/adbc/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,15 @@ func RemoveSchemaMetadata(schema *arrow.Schema) *arrow.Schema {
func removeFieldMetadata(field *arrow.Field) arrow.Field {
fieldType := field.Type

if nestedType, ok := field.Type.(arrow.NestedType); ok {
if ty, ok := field.Type.(*arrow.MapType); ok {
// XXX: this can't handle nonstandard field names or
// nullability, but it's impossible to do otherwise in arrow-go
key := ty.KeyField()
value := ty.ValueField()
mapType := arrow.MapOf(removeFieldMetadata(&key).Type, removeFieldMetadata(&value).Type)
mapType.KeysSorted = ty.KeysSorted
fieldType = mapType
} else if nestedType, ok := field.Type.(arrow.NestedType); ok {
childFields := make([]arrow.Field, len(nestedType.Fields()))
for i, field := range nestedType.Fields() {
childFields[i] = removeFieldMetadata(&field)
Expand All @@ -46,9 +54,7 @@ func removeFieldMetadata(field *arrow.Field) arrow.Field {
case *arrow.LargeListType:
fieldType = arrow.LargeListOfField(childFields[0])
case *arrow.MapType:
mapType := arrow.MapOf(childFields[0].Type, childFields[1].Type)
mapType.KeysSorted = ty.KeysSorted
fieldType = mapType
panic("map type should be handled above")
case *arrow.SparseUnionType:
fieldType = arrow.SparseUnionOf(childFields, ty.TypeCodes())
case *arrow.StructType:
Expand Down

0 comments on commit 7e55741

Please sign in to comment.