Skip to content

Commit

Permalink
Merge pull request #306 from taniabogatsch/array-support
Browse files Browse the repository at this point in the history
ARRAY type support
  • Loading branch information
taniabogatsch authored Nov 8, 2024
2 parents c98e856 + f4413b7 commit e1b55d8
Show file tree
Hide file tree
Showing 12 changed files with 266 additions and 57 deletions.
29 changes: 29 additions & 0 deletions appender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,35 @@ func TestAppenderList(t *testing.T) {
cleanupAppender(t, c, con, a)
}

func TestAppenderArray(t *testing.T) {
t.Parallel()
c, con, a := prepareAppender(t, `CREATE TABLE test (string_array VARCHAR[3])`)

count := 10
expected := Composite[[3]string]{[3]string{"a", "b", "c"}}
for i := 0; i < count; i++ {
require.NoError(t, a.AppendRow([]string{"a", "b", "c"}))
require.NoError(t, a.AppendRow(expected.Get()))
}
require.NoError(t, a.Flush())

// Verify results.
res, err := sql.OpenDB(c).QueryContext(context.Background(), `SELECT * FROM test`)
require.NoError(t, err)

i := 0
for res.Next() {
var r Composite[[3]string]
require.NoError(t, res.Scan(&r))
require.Equal(t, expected, r)
i++
}

require.Equal(t, 2*count, i)
require.NoError(t, res.Close())
cleanupAppender(t, c, con, a)
}

func TestAppenderNested(t *testing.T) {
t.Parallel()
c, con, a := prepareAppender(t, createNestedDataTableSQL)
Expand Down
6 changes: 6 additions & 0 deletions duckdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,12 @@ func TestTypeNamesAndScanTypes(t *testing.T) {
value: Map{int32(1): "a", int32(5): "e"},
typeName: "MAP(INTEGER, VARCHAR)",
},
// DUCKDB_TYPE_ARRAY
{
sql: "SELECT ['duck', 'goose', NULL]::VARCHAR[3] AS col",
value: []any{"duck", "goose", nil},
typeName: "VARCHAR[3]",
},
// DUCKDB_TYPE_UUID
{
sql: "SELECT '53b4e983-b287-481a-94ad-6e3c90489913'::UUID AS col",
Expand Down
6 changes: 6 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ func conversionError(actual int, min int, max int) error {
return fmt.Errorf("%s: cannot convert %d, minimum: %d, maximum: %d", convertErrMsg, actual, min, max)
}

func invalidInputError(actual string, expected string) error {
return fmt.Errorf("%s: expected %s, got %s", invalidInputErrMsg, expected, actual)
}

func structFieldError(actual string, expected string) error {
return fmt.Errorf("%s: expected %s, got %s", structFieldErrMsg, expected, actual)
}
Expand Down Expand Up @@ -66,6 +70,7 @@ const (
duckdbErrMsg = "duckdb error"
castErrMsg = "cast error"
convertErrMsg = "conversion error"
invalidInputErrMsg = "invalid input"
structFieldErrMsg = "invalid STRUCT field"
columnCountErrMsg = "invalid column count"
unsupportedTypeErrMsg = "unsupported data type"
Expand Down Expand Up @@ -109,6 +114,7 @@ var (
errEmptyName = errors.New("empty name")
errInvalidDecimalWidth = fmt.Errorf("the DECIMAL with must be between 1 and %d", max_decimal_width)
errInvalidDecimalScale = errors.New("the DECIMAL scale must be less than or equal to the width")
errInvalidArraySize = errors.New("invalid ARRAY size")
errSetSQLNULLValue = errors.New("cannot write to a NULL column")

errScalarUDFCreate = errors.New("could not create scalar UDF")
Expand Down
9 changes: 8 additions & 1 deletion errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func TestErrAppender(t *testing.T) {
c, err := NewConnector("", nil)
require.NoError(t, err)

_, err = sql.OpenDB(c).Exec(`CREATE TABLE test (int_array INTEGER[2])`)
_, err = sql.OpenDB(c).Exec(`CREATE TABLE test (bit_col BIT)`)
require.NoError(t, err)

con, err := c.Connect(context.Background())
Expand Down Expand Up @@ -171,6 +171,13 @@ func TestErrAppender(t *testing.T) {
testError(t, err, errAppenderAppendRow.Error(), errUnsupportedMapKeyType.Error())
cleanupAppender(t, c, con, a)
})

t.Run(invalidInputErrMsg, func(t *testing.T) {
c, con, a := prepareAppender(t, `CREATE TABLE test (col INT[3])`)
err := a.AppendRow([]int32{1, 2})
testError(t, err, errAppenderAppendRow.Error(), invalidInputErrMsg)
cleanupAppender(t, c, con, a)
})
}

func TestErrAppend(t *testing.T) {
Expand Down
35 changes: 28 additions & 7 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ func (r *rows) ColumnTypeScanType(index int) reflect.Type {
return reflect.TypeOf(map[string]any{})
case TYPE_MAP:
return reflect.TypeOf(Map{})
case TYPE_ARRAY:
return reflect.TypeOf([]any{})
case TYPE_UUID:
return reflect.TypeOf([]byte{})
default:
Expand All @@ -138,7 +140,7 @@ func (r *rows) ColumnTypeScanType(index int) reflect.Type {
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
t := Type(C.duckdb_column_type(&r.res, C.idx_t(index)))
switch t {
case TYPE_DECIMAL, TYPE_ENUM, TYPE_LIST, TYPE_STRUCT, TYPE_MAP:
case TYPE_DECIMAL, TYPE_ENUM, TYPE_LIST, TYPE_STRUCT, TYPE_MAP, TYPE_ARRAY:
// Only allocate the logical type if necessary.
logicalType := C.duckdb_column_logical_type(&r.res, C.idx_t(index))
defer C.duckdb_destroy_logical_type(&logicalType)
Expand Down Expand Up @@ -167,25 +169,36 @@ func logicalTypeName(logicalType C.duckdb_logical_type) string {
t := Type(C.duckdb_get_type_id(logicalType))
switch t {
case TYPE_DECIMAL:
width := C.duckdb_decimal_width(logicalType)
scale := C.duckdb_decimal_scale(logicalType)
return fmt.Sprintf("DECIMAL(%d,%d)", width, scale)
return logicalTypeNameDecimal(logicalType)
case TYPE_ENUM:
// The C API does not expose ENUM names.
return "ENUM"
case TYPE_LIST:
childType := C.duckdb_list_type_child_type(logicalType)
defer C.duckdb_destroy_logical_type(&childType)
return logicalTypeName(childType) + "[]"
return logicalTypeNameList(logicalType)
case TYPE_STRUCT:
return logicalTypeNameStruct(logicalType)
case TYPE_MAP:
return logicalTypeNameMap(logicalType)
case TYPE_ARRAY:
return logicalTypeNameArray(logicalType)
default:
return typeToStringMap[t]
}
}

func logicalTypeNameDecimal(logicalType C.duckdb_logical_type) string {
width := C.duckdb_decimal_width(logicalType)
scale := C.duckdb_decimal_scale(logicalType)
return fmt.Sprintf("DECIMAL(%d,%d)", int(width), int(scale))
}

func logicalTypeNameList(logicalType C.duckdb_logical_type) string {
childType := C.duckdb_list_type_child_type(logicalType)
defer C.duckdb_destroy_logical_type(&childType)
childName := logicalTypeName(childType)
return fmt.Sprintf("%s[]", childName)
}

func logicalTypeNameStruct(logicalType C.duckdb_logical_type) string {
count := int(C.duckdb_struct_type_child_count(logicalType))
name := "STRUCT("
Expand Down Expand Up @@ -217,6 +230,14 @@ func logicalTypeNameMap(logicalType C.duckdb_logical_type) string {
return fmt.Sprintf("MAP(%s, %s)", logicalTypeName(keyType), logicalTypeName(valueType))
}

func logicalTypeNameArray(logicalType C.duckdb_logical_type) string {
size := C.duckdb_array_type_array_size(logicalType)
childType := C.duckdb_array_type_child_type(logicalType)
defer C.duckdb_destroy_logical_type(&childType)
childName := logicalTypeName(childType)
return fmt.Sprintf("%s[%d]", childName, int(size))
}

func escapeStructFieldName(s string) string {
// DuckDB escapes STRUCT field names by doubling double quotes, then wrapping in double quotes.
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
Expand Down
1 change: 0 additions & 1 deletion type.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ const (
var unsupportedTypeToStringMap = map[Type]string{
TYPE_INVALID: "INVALID",
TYPE_UHUGEINT: "UHUGEINT",
TYPE_ARRAY: "ARRAY",
TYPE_UNION: "UNION",
TYPE_BIT: "BIT",
TYPE_ANY: "ANY",
Expand Down
31 changes: 31 additions & 0 deletions type_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type baseTypeInfo struct {
structEntries []StructEntry
decimalWidth uint8
decimalScale uint8
arrayLength uint64
// The internal type for ENUM and DECIMAL values.
internalType Type
}
Expand Down Expand Up @@ -102,6 +103,8 @@ func NewTypeInfo(t Type) (TypeInfo, error) {
return nil, getError(errAPI, tryOtherFuncError(funcName(NewStructInfo)))
case TYPE_MAP:
return nil, getError(errAPI, tryOtherFuncError(funcName(NewMapInfo)))
case TYPE_ARRAY:
return nil, getError(errAPI, tryOtherFuncError(funcName(NewArrayInfo)))
case TYPE_SQLNULL:
return nil, getError(errAPI, unsupportedTypeError(typeToStringMap[t]))
}
Expand Down Expand Up @@ -232,6 +235,25 @@ func NewMapInfo(keyInfo TypeInfo, valueInfo TypeInfo) (TypeInfo, error) {
return info, nil
}

// NewArrayInfo returns ARRAY type information.
// childInfo contains the type information of the ARRAY's elements.
// size is the ARRAY's fixed size.
func NewArrayInfo(childInfo TypeInfo, size uint64) (TypeInfo, error) {
if childInfo == nil {
return nil, getError(errAPI, interfaceIsNilError("childInfo"))
}
if size == 0 {
return nil, getError(errAPI, errInvalidArraySize)
}

info := &typeInfo{
baseTypeInfo: baseTypeInfo{Type: TYPE_ARRAY, arrayLength: size},
childTypes: make([]TypeInfo, 1),
}
info.childTypes[0] = childInfo
return info, nil
}

func (info *typeInfo) logicalType() C.duckdb_logical_type {
switch info.Type {
case TYPE_BOOLEAN, TYPE_TINYINT, TYPE_SMALLINT, TYPE_INTEGER, TYPE_BIGINT, TYPE_UTINYINT, TYPE_USMALLINT,
Expand All @@ -250,6 +272,8 @@ func (info *typeInfo) logicalType() C.duckdb_logical_type {
return info.logicalStructType()
case TYPE_MAP:
return info.logicalMapType()
case TYPE_ARRAY:
return info.logicalArrayType()
}
return nil
}
Expand Down Expand Up @@ -315,6 +339,13 @@ func (info *typeInfo) logicalMapType() C.duckdb_logical_type {
return logicalType
}

func (info *typeInfo) logicalArrayType() C.duckdb_logical_type {
child := info.childTypes[0].logicalType()
logicalType := C.duckdb_create_array_type(child, C.idx_t(info.arrayLength))
C.duckdb_destroy_logical_type(&child)
return logicalType
}

func funcName(i interface{}) string {
return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name()
}
38 changes: 35 additions & 3 deletions type_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func getTypeInfos(t *testing.T, useAny bool) []testTypeInfo {
continue
}
switch k {
case TYPE_DECIMAL, TYPE_ENUM, TYPE_LIST, TYPE_STRUCT, TYPE_MAP, TYPE_SQLNULL:
case TYPE_DECIMAL, TYPE_ENUM, TYPE_LIST, TYPE_STRUCT, TYPE_MAP, TYPE_ARRAY, TYPE_SQLNULL:
continue
}
primitiveTypes = append(primitiveTypes, k)
Expand Down Expand Up @@ -160,7 +160,32 @@ func getTypeInfos(t *testing.T, useAny bool) []testTypeInfo {
},
}

testTypeInfos = append(testTypeInfos, decimalTypeInfo, enumTypeInfo, listTypeInfo, nestedListTypeInfo, structTypeInfo, nestedStructTypeInfo, mapTypeInfo)
primitiveInfo, err := NewTypeInfo(TYPE_INTEGER)
require.NoError(t, err)

info, err = NewArrayInfo(primitiveInfo, 3)
require.NoError(t, err)
arrayTypeInfo := testTypeInfo{
TypeInfo: info,
testTypeValues: testTypeValues{
input: `[4::INT, 8::INT, 16::INT]`,
output: `[4, 8, 16]`,
},
}

info, err = NewArrayInfo(arrayTypeInfo, 2)
require.NoError(t, err)
nestedArrayTypeInfo := testTypeInfo{
TypeInfo: info,
testTypeValues: testTypeValues{
input: `[[4::INT, 8::INT, 16::INT], [3::INT, 6::INT, 9::INT]]`,
output: `[[4, 8, 16], [3, 6, 9]]`,
},
}

testTypeInfos = append(testTypeInfos, decimalTypeInfo, enumTypeInfo,
listTypeInfo, nestedListTypeInfo, structTypeInfo, nestedStructTypeInfo, mapTypeInfo,
arrayTypeInfo, nestedArrayTypeInfo)
return testTypeInfos
}

Expand All @@ -178,7 +203,7 @@ func TestErrTypeInfo(t *testing.T) {
t.Parallel()

var incorrectTypes []Type
incorrectTypes = append(incorrectTypes, TYPE_DECIMAL, TYPE_ENUM, TYPE_LIST, TYPE_STRUCT, TYPE_MAP)
incorrectTypes = append(incorrectTypes, TYPE_DECIMAL, TYPE_ENUM, TYPE_LIST, TYPE_STRUCT, TYPE_MAP, TYPE_ARRAY)

for _, incorrect := range incorrectTypes {
_, err := NewTypeInfo(incorrect)
Expand Down Expand Up @@ -226,6 +251,10 @@ func TestErrTypeInfo(t *testing.T) {
nilStructEntry, err := NewStructEntry(nil, "hello")
require.NoError(t, err)

// Invalid ARRAY entry.
_, err = NewArrayInfo(validInfo, 0)
testError(t, err, errAPI.Error(), errInvalidArraySize.Error())

// Invalid interfaces.
_, err = NewListInfo(nil)
testError(t, err, errAPI.Error(), interfaceIsNilErrMsg)
Expand All @@ -247,4 +276,7 @@ func TestErrTypeInfo(t *testing.T) {
testError(t, err, errAPI.Error(), interfaceIsNilErrMsg)
_, err = NewMapInfo(validInfo, nil)
testError(t, err, errAPI.Error(), interfaceIsNilErrMsg)

_, err = NewArrayInfo(nil, 3)
testError(t, err, errAPI.Error(), interfaceIsNilErrMsg)
}
8 changes: 8 additions & 0 deletions types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type testTypesRow struct {
List_col Composite[[]int32]
Struct_col Composite[testTypesStruct]
Map_col Map
Array_col Composite[[3]int32]
Time_tz_col time.Time
Timestamp_tz_col time.Time
}
Expand Down Expand Up @@ -82,6 +83,7 @@ const testTypesTableSQL = `CREATE TABLE test (
List_col INTEGER[],
Struct_col STRUCT(A INTEGER, B VARCHAR),
Map_col MAP(INTEGER, VARCHAR),
Array_col INTEGER[3],
Time_tz_col TIMETZ,
Timestamp_tz_col TIMESTAMPTZ
)`
Expand Down Expand Up @@ -124,6 +126,9 @@ func testTypesGenerateRow[T require.TestingT](t T, i int) testTypesRow {
mapCol := Map{
int32(i): "other_longer_val",
}
arrayCol := Composite[[3]int32]{
[3]int32{int32(i), int32(i), int32(i)},
}

return testTypesRow{
i%2 == 1,
Expand Down Expand Up @@ -151,6 +156,7 @@ func testTypesGenerateRow[T require.TestingT](t T, i int) testTypesRow {
listCol,
structCol,
mapCol,
arrayCol,
timeTZ,
ts,
}
Expand Down Expand Up @@ -200,6 +206,7 @@ func testTypes[T require.TestingT](t T, c *Connector, a *Appender, expectedRows
r.List_col.Get(),
r.Struct_col.Get(),
r.Map_col,
r.Array_col.Get(),
r.Time_tz_col,
r.Timestamp_tz_col)
require.NoError(t, err)
Expand Down Expand Up @@ -239,6 +246,7 @@ func testTypes[T require.TestingT](t T, c *Connector, a *Appender, expectedRows
&r.List_col,
&r.Struct_col,
&r.Map_col,
&r.Array_col,
&r.Time_tz_col,
&r.Timestamp_tz_col)
require.NoError(t, err)
Expand Down
Loading

0 comments on commit e1b55d8

Please sign in to comment.