util/parquet: add support for arrays #101860

Merged · 1 commit · Apr 20, 2023
84 changes: 58 additions & 26 deletions pkg/util/parquet/schema.go
@@ -23,6 +23,11 @@ import (
"github.com/lib/pq/oid"
)

// Setting parquet.Repetitions.Optional makes a parquet column nullable. When
// writing a datum, we will always specify a definition level to indicate if the
// datum is null or not. See comments on nonNilDefLevel or nilDefLevel for more info.
var defaultRepetitions = parquet.Repetitions.Optional

// A schema field is an internal identifier for schema nodes used by the parquet library.
// A value of -1 will let the library auto-assign values. This does not affect reading
// or writing parquet files.
@@ -36,7 +41,7 @@ const defaultTypeLength = -1
// A column stores column metadata.
type column struct {
node schema.Node
colWriter writeFn
colWriter colWriter
decoder decoder
typ *types.T
}
@@ -67,7 +72,7 @@ func NewSchema(columnNames []string, columnTypes []*types.T) (*SchemaDefinition,
fields := make([]schema.Node, 0)

for i := 0; i < len(columnNames); i++ {
parquetCol, err := makeColumn(columnNames[i], columnTypes[i])
parquetCol, err := makeColumn(columnNames[i], columnTypes[i], defaultRepetitions)
if err != nil {
return nil, err
}
@@ -87,50 +92,44 @@ func NewSchema(columnNames []string, columnTypes []*types.T) (*SchemaDefinition,
}

// makeColumn constructs a column.
func makeColumn(colName string, typ *types.T) (column, error) {
// Setting parquet.Repetitions.Optional makes parquet interpret all columns as nullable.
// When writing data, we will specify a definition level of 0 (null) or 1 (not null).
// See https://blog.twitter.com/engineering/en_us/a/2013/dremel-made-simple-with-parquet
// for more information regarding definition levels.
defaultRepetitions := parquet.Repetitions.Optional

func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (column, error) {
result := column{typ: typ}
var err error
switch typ.Family() {
case types.BoolFamily:
result.node = schema.NewBooleanNode(colName, defaultRepetitions, defaultSchemaFieldID)
result.colWriter = writeBool
result.node = schema.NewBooleanNode(colName, repetitions, defaultSchemaFieldID)
result.colWriter = scalarWriter(writeBool)
result.decoder = boolDecoder{}
result.typ = types.Bool
return result, nil
case types.StringFamily:
result.node, err = schema.NewPrimitiveNodeLogical(colName,
defaultRepetitions, schema.StringLogicalType{}, parquet.Types.ByteArray,
repetitions, schema.StringLogicalType{}, parquet.Types.ByteArray,
defaultTypeLength, defaultSchemaFieldID)

if err != nil {
return result, err
}
result.colWriter = writeString
result.colWriter = scalarWriter(writeString)
result.decoder = stringDecoder{}
return result, nil
case types.IntFamily:
// Note: integer datums are always signed: https://www.cockroachlabs.com/docs/stable/int.html
if typ.Oid() == oid.T_int8 {
result.node, err = schema.NewPrimitiveNodeLogical(colName,
defaultRepetitions, schema.NewIntLogicalType(64, true),
repetitions, schema.NewIntLogicalType(64, true),
parquet.Types.Int64, defaultTypeLength,
defaultSchemaFieldID)
if err != nil {
return result, err
}
result.colWriter = writeInt64
result.colWriter = scalarWriter(writeInt64)
result.decoder = int64Decoder{}
return result, nil
}

result.node = schema.NewInt32Node(colName, defaultRepetitions, defaultSchemaFieldID)
result.colWriter = writeInt32
result.node = schema.NewInt32Node(colName, repetitions, defaultSchemaFieldID)
result.colWriter = scalarWriter(writeInt32)
result.decoder = int32Decoder{}
return result, nil
case types.DecimalFamily:
@@ -149,37 +148,71 @@ func makeColumn(colName string, typ *types.T) (column, error) {
}

result.node, err = schema.NewPrimitiveNodeLogical(colName,
defaultRepetitions, schema.NewDecimalLogicalType(precision,
repetitions, schema.NewDecimalLogicalType(precision,
scale), parquet.Types.ByteArray, defaultTypeLength,
defaultSchemaFieldID)
if err != nil {
return result, err
}
result.colWriter = writeDecimal
result.colWriter = scalarWriter(writeDecimal)
result.decoder = decimalDecoder{}
return result, nil
case types.UuidFamily:
result.node, err = schema.NewPrimitiveNodeLogical(colName,
defaultRepetitions, schema.UUIDLogicalType{},
repetitions, schema.UUIDLogicalType{},
parquet.Types.FixedLenByteArray, uuid.Size, defaultSchemaFieldID)
if err != nil {
return result, err
}
result.colWriter = writeUUID
result.colWriter = scalarWriter(writeUUID)
result.decoder = uUIDDecoder{}
return result, nil
case types.TimestampFamily:
// Note that all timestamp datums are in UTC: https://www.cockroachlabs.com/docs/stable/timestamp.html
result.node, err = schema.NewPrimitiveNodeLogical(colName,
defaultRepetitions, schema.StringLogicalType{}, parquet.Types.ByteArray,
repetitions, schema.StringLogicalType{}, parquet.Types.ByteArray,
defaultTypeLength, defaultSchemaFieldID)
if err != nil {
return result, err
}

result.colWriter = writeTimestamp
result.colWriter = scalarWriter(writeTimestamp)
result.decoder = timestampDecoder{}
return result, nil
case types.ArrayFamily:
// Arrays for type T are represented by the following:
// message schema { -- toplevel schema
// optional group a (LIST) { -- list column
// repeated group list {
// optional T element;
// }
// }
// }
// Representing arrays this way makes it easier to differentiate NULL, [NULL],
// and [] when encoding.
// There is more info about encoding arrays here:
// https://arrow.apache.org/blog/2022/10/08/arrow-parquet-encoding-part-2/
elementCol, err := makeColumn("element", typ.ArrayContents(), parquet.Repetitions.Optional)
if err != nil {
return result, err
}
innerListFields := []schema.Node{elementCol.node}
innerListNode, err := schema.NewGroupNode("list", parquet.Repetitions.Repeated, innerListFields, defaultSchemaFieldID)
if err != nil {
return result, err
}
outerListFields := []schema.Node{innerListNode}
result.node, err = schema.NewGroupNodeLogical(colName, parquet.Repetitions.Optional, outerListFields, schema.ListLogicalType{}, defaultSchemaFieldID)
if err != nil {
return result, err
}
result.decoder = elementCol.decoder
scalarColWriter, ok := elementCol.colWriter.(scalarWriter)
if !ok {
return result, errors.AssertionFailedf("expected scalar column writer")
}
result.colWriter = arrayWriter(scalarColWriter)
result.typ = elementCol.typ
return result, nil

// TODO(#99028): implement support for the remaining types.
// case types.INetFamily:
@@ -196,8 +229,7 @@ func makeColumn(colName string, typ *types.T) (column, error) {
// case types.TimeTZFamily:
// case types.IntervalFamily:
// case types.TimestampTZFamily:
// case types.ArrayFamily:
default:
return result, pgerror.Newf(pgcode.FeatureNotSupported, "parquet export does not support the %v type yet", typ.Family())
return result, pgerror.Newf(pgcode.FeatureNotSupported, "parquet export does not support the %v type", typ.Family())
}
}
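
Aside, not part of the diff: a minimal standalone sketch of the LIST shape that
makeColumn builds for an array column, using the same arrow/parquet schema
constructors the change calls. The column name "a" and the 64-bit signed INT
element type below are illustrative assumptions, and error handling is collapsed
into panics for brevity.

package main

import (
	"os"

	"github.com/apache/arrow/go/v11/parquet"
	"github.com/apache/arrow/go/v11/parquet/schema"
)

func main() {
	// optional T element; -- the nullable array element.
	element, err := schema.NewPrimitiveNodeLogical("element",
		parquet.Repetitions.Optional, schema.NewIntLogicalType(64, true),
		parquet.Types.Int64, -1 /* type length */, -1 /* auto field ID */)
	if err != nil {
		panic(err)
	}
	// repeated group list { ... } -- one repetition per array element.
	list, err := schema.NewGroupNode("list", parquet.Repetitions.Repeated,
		schema.FieldList{element}, -1)
	if err != nil {
		panic(err)
	}
	// optional group a (LIST) { ... } -- the nullable array column itself.
	arrayCol, err := schema.NewGroupNodeLogical("a", parquet.Repetitions.Optional,
		schema.FieldList{list}, schema.ListLogicalType{}, -1)
	if err != nil {
		panic(err)
	}
	// message schema { ... } -- the toplevel schema.
	root, err := schema.NewGroupNode("schema", parquet.Repetitions.Required,
		schema.FieldList{arrayCol}, -1)
	if err != nil {
		panic(err)
	}
	// Prints the message structure shown in the makeColumn comment above.
	schema.PrintSchema(root, os.Stdout, 2)
}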
162 changes: 113 additions & 49 deletions pkg/util/parquet/testutils.go
@@ -18,6 +18,7 @@ import (
"github.com/apache/arrow/go/v11/parquet"
"github.com/apache/arrow/go/v11/parquet/file"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -55,41 +56,51 @@ func ReadFileAndVerifyDatums(
for rg := 0; rg < reader.NumRowGroups(); rg++ {
rgr := reader.RowGroup(rg)
rowsInRowGroup := rgr.NumRows()
defLevels := make([]int16, rowsInRowGroup)

for colIdx := 0; colIdx < numCols; colIdx++ {
col, err := rgr.Column(colIdx)
require.NoError(t, err)

dec := writer.sch.cols[colIdx].decoder
typ := writer.sch.cols[colIdx].typ

// Based on how we define schemas, we can detect an array by seeing if the
// primitive col reader has a max repetition level of 1. See comments above
// arrayEntryRepLevel for more info.
isArray := col.Descriptor().MaxRepetitionLevel() == 1

switch col.Type() {
case parquet.Types.Boolean:
values := make([]bool, rowsInRowGroup)
readBatchHelper(t, col, rowsInRowGroup, values, defLevels)
decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels)
colDatums, read, err := readBatch(col, make([]bool, 1), dec, typ, isArray)
require.NoError(t, err)
require.Equal(t, rowsInRowGroup, read)
decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx)
case parquet.Types.Int32:
values := make([]int32, numRows)
readBatchHelper(t, col, rowsInRowGroup, values, defLevels)
decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels)
colDatums, read, err := readBatch(col, make([]int32, 1), dec, typ, isArray)
require.NoError(t, err)
require.Equal(t, rowsInRowGroup, read)
decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx)
case parquet.Types.Int64:
values := make([]int64, rowsInRowGroup)
readBatchHelper(t, col, rowsInRowGroup, values, defLevels)
decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels)
colDatums, read, err := readBatch(col, make([]int64, 1), dec, typ, isArray)
require.NoError(t, err)
require.Equal(t, rowsInRowGroup, read)
decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx)
case parquet.Types.Int96:
panic("unimplemented")
case parquet.Types.Float:
panic("unimplemented")
case parquet.Types.Double:
panic("unimplemented")
case parquet.Types.ByteArray:
values := make([]parquet.ByteArray, rowsInRowGroup)
readBatchHelper(t, col, rowsInRowGroup, values, defLevels)
decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels)
colDatums, read, err := readBatch(col, make([]parquet.ByteArray, 1), dec, typ, isArray)
require.NoError(t, err)
require.Equal(t, rowsInRowGroup, read)
decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx)
case parquet.Types.FixedLenByteArray:
values := make([]parquet.FixedLenByteArray, rowsInRowGroup)
readBatchHelper(t, col, rowsInRowGroup, values, defLevels)
decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels)
colDatums, read, err := readBatch(col, make([]parquet.FixedLenByteArray, 1), dec, typ, isArray)
require.NoError(t, err)
require.Equal(t, rowsInRowGroup, read)
decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx)
}
}
startingRowIdx += int(rowsInRowGroup)
@@ -98,54 +109,107 @@ func ReadFileAndVerifyDatums(

for i := 0; i < numRows; i++ {
for j := 0; j < numCols; j++ {
assert.Equal(t, writtenDatums[i][j], readDatums[i][j])
validateDatum(t, writtenDatums[i][j], readDatums[i][j])
}
}
}

func readBatchHelper[T parquetDatatypes](
t *testing.T, r file.ColumnChunkReader, expectedrowsInRowGroup int64, values []T, defLvls []int16,
) {
numRead, err := readBatch(r, expectedrowsInRowGroup, values, defLvls)
require.NoError(t, err)
require.Equal(t, numRead, expectedrowsInRowGroup)
}

type batchReader[T parquetDatatypes] interface {
ReadBatch(batchSize int64, values []T, defLvls []int16, repLvls []int16) (total int64, valuesRead int, err error)
}

// readBatch reads all the datums in a row group for a column.
func readBatch[T parquetDatatypes](
r file.ColumnChunkReader, batchSize int64, values []T, defLvls []int16,
) (int64, error) {
r file.ColumnChunkReader, valueAlloc []T, dec decoder, typ *types.T, isArray bool,
) (tree.Datums, int64, error) {
br, ok := r.(batchReader[T])
if !ok {
return 0, errors.AssertionFailedf("expected batchReader of type %T, but found %T instead", values, r)
return nil, 0, errors.AssertionFailedf("expected batchReader for type %T, but found %T instead", valueAlloc, r)
}
numRowsRead, _, err := br.ReadBatch(batchSize, values, defLvls, nil)
return numRowsRead, err

result := make([]tree.Datum, 0)
defLevels := [1]int16{}
repLevels := [1]int16{}

for {
numRowsRead, _, err := br.ReadBatch(1, valueAlloc, defLevels[:], repLevels[:])
if err != nil {
return nil, 0, err
}
if numRowsRead == 0 {
break
}

if isArray {
// Replevel 0 indicates the start of a new array.
if repLevels[0] == 0 {
// Replevel 0, Deflevel 0 represents a NULL array.
if defLevels[0] == 0 {
result = append(result, tree.DNull)
continue
}
arrDatum := tree.NewDArray(typ)
result = append(result, arrDatum)
// Replevel 0, Deflevel 1 represents an empty array.
if defLevels[0] == 1 {
continue
}
}
currentArrayDatum := result[len(result)-1].(*tree.DArray)
// Deflevel 2 represents a null value in an array.
if defLevels[0] == 2 {
currentArrayDatum.Array = append(currentArrayDatum.Array, tree.DNull)
continue
}
// Deflevel 3 represents a non-null datum in an array.
d, err := decode(dec, valueAlloc[0])
if err != nil {
return nil, 0, err
}
currentArrayDatum.Array = append(currentArrayDatum.Array, d)
} else {
// Deflevel 0 represents a null value.
// Deflevel 1 represents a non-null value.
d := tree.DNull
if defLevels[0] != 0 {
d, err = decode(dec, valueAlloc[0])
if err != nil {
return nil, 0, err
}
}
result = append(result, d)
}
}

return result, int64(len(result)), nil
}

func decodeValuesIntoDatumsHelper[T parquetDatatypes](
t *testing.T,
datums [][]tree.Datum,
colIdx int,
startingRowIdx int,
dec decoder,
values []T,
defLevels []int16,
func decodeValuesIntoDatumsHelper(
colDatums []tree.Datum, datumRows [][]tree.Datum, colIdx int, startingRowIdx int,
) {
var err error
// If the defLevel of a datum is 0, parquet will not write it to the column.
// Use valueReadIdx to only read from the front of the values array, where datums are defined.
valueReadIdx := 0
for rowOffset, defLevel := range defLevels {
d := tree.DNull
if defLevel != 0 {
d, err = decode(dec, values[valueReadIdx])
require.NoError(t, err)
valueReadIdx++
for rowOffset, datum := range colDatums {
datumRows[startingRowIdx+rowOffset][colIdx] = datum
}
}

// validateDatum validates that the "contents" of the expected datum match the
// contents of the actual datum. For example, when validating two arrays, we
// only compare the datums in the arrays. We do not compare CRDB-specific
// metadata fields such as (tree.DArray).HasNulls or (tree.DArray).HasNonNulls.
//
// The reason for this special comparison is that the parquet format is
// presently used in an export use case, so we only need to test
// roundtripability of the data end users see. We do not need to check
// roundtripability of internal CRDB data.
func validateDatum(t *testing.T, expected tree.Datum, actual tree.Datum) {
switch expected.ResolvedType().Family() {
case types.ArrayFamily:
arr1 := expected.(*tree.DArray).Array
arr2 := actual.(*tree.DArray).Array
require.Equal(t, len(arr1), len(arr2))
for i := 0; i < len(arr1); i++ {
validateDatum(t, arr1[i], arr2[i])
}
datums[startingRowIdx+rowOffset][colIdx] = d
default:
require.Equal(t, expected, actual)
}
}
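
As a worked example of the level scheme readBatch decodes, here is a
self-contained sketch (not part of the diff) that replays a hypothetical
(def level, rep level, value) stream — what writing NULL, [], [NULL], and
[1, 2] into the array column sketched earlier would produce — and rebuilds the
rows using the same branches as readBatch.

package main

import "fmt"

// leaf is a hypothetical (def level, rep level, value) triple as read for one
// leaf of an array column; val is only meaningful when def == 3.
type leaf struct {
	def, rep int16
	val      int64
}

func main() {
	stream := []leaf{
		{def: 0, rep: 0},         // NULL   -- a null array
		{def: 1, rep: 0},         // []     -- an empty array
		{def: 2, rep: 0},         // [NULL] -- an array containing one null
		{def: 3, rep: 0, val: 1}, // [1, 2] -- rep=0 starts a new array...
		{def: 3, rep: 1, val: 2}, //        -- ...rep=1 extends the same array.
	}

	var rows [][]interface{}
	for _, l := range stream {
		// Rep level 0 indicates the start of a new array (row).
		if l.rep == 0 {
			if l.def == 0 {
				rows = append(rows, nil) // def 0: a NULL array
				continue
			}
			rows = append(rows, []interface{}{}) // a new, possibly empty array
			if l.def == 1 {
				continue // def 1: the array is empty
			}
		}
		cur := rows[len(rows)-1]
		if l.def == 2 {
			cur = append(cur, nil) // def 2: a null element in the array
		} else {
			cur = append(cur, l.val) // def 3: a non-null element
		}
		rows[len(rows)-1] = cur
	}
	fmt.Println(rows) // prints [[] [] [<nil>] [1 2]]; rows[0] is a nil slice
}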