From adf8c1a1f4fb958e669c2495411d69016087e758 Mon Sep 17 00:00:00 2001 From: Jayant Shrivastava Date: Wed, 19 Apr 2023 11:29:25 -0400 Subject: [PATCH] util/parquet: add support for arrays This change extends and refactors the util/parquet library to be able to read and write arrays. Release note: None Informs: https://github.com/cockroachdb/cockroach/issues/99028 Epic: https://cockroachlabs.atlassian.net/browse/CRDB-15071 --- pkg/util/parquet/schema.go | 95 +++++++++---- pkg/util/parquet/testutils.go | 162 ++++++++++++++++------- pkg/util/parquet/write_functions.go | 184 +++++++++++++++++++++----- pkg/util/parquet/writer.go | 8 +- pkg/util/parquet/writer_bench_test.go | 2 +- pkg/util/parquet/writer_test.go | 42 +++++- 6 files changed, 367 insertions(+), 126 deletions(-) diff --git a/pkg/util/parquet/schema.go b/pkg/util/parquet/schema.go index 591f3ba87c41..e7476e2163a1 100644 --- a/pkg/util/parquet/schema.go +++ b/pkg/util/parquet/schema.go @@ -23,6 +23,11 @@ import ( "github.com/lib/pq/oid" ) +// Setting parquet.Repetitions.Optional makes parquet a column nullable. When +// writing a datum, we will always specify a definition level to indicate if the +// datum is null or not. See comments on nonNilDefLevel or nilDefLevel for more info. +var defaultRepetitions = parquet.Repetitions.Optional + // A schema field is an internal identifier for schema nodes used by the parquet library. // A value of -1 will let the library auto-assign values. This does not affect reading // or writing parquet files. @@ -35,10 +40,11 @@ const defaultTypeLength = -1 // A column stores column metadata. type column struct { - node schema.Node - colWriter writeFn - decoder decoder - typ *types.T + node schema.Node + writeInvoker writeInvoker + writeFn writeFn + decoder decoder + typ *types.T } // A SchemaDefinition stores a parquet schema. @@ -67,7 +73,7 @@ func NewSchema(columnNames []string, columnTypes []*types.T) (*SchemaDefinition, fields := make([]schema.Node, 0) for i := 0; i < len(columnNames); i++ { - parquetCol, err := makeColumn(columnNames[i], columnTypes[i]) + parquetCol, err := makeColumn(columnNames[i], columnTypes[i], defaultRepetitions) if err != nil { return nil, err } @@ -87,50 +93,48 @@ func NewSchema(columnNames []string, columnTypes []*types.T) (*SchemaDefinition, } // makeColumn constructs a column. -func makeColumn(colName string, typ *types.T) (column, error) { - // Setting parquet.Repetitions.Optional makes parquet interpret all columns as nullable. - // When writing data, we will specify a definition level of 0 (null) or 1 (not null). - // See https://blog.twitter.com/engineering/en_us/a/2013/dremel-made-simple-with-parquet - // for more information regarding definition levels. 
- defaultRepetitions := parquet.Repetitions.Optional - +func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (column, error) { result := column{typ: typ} var err error switch typ.Family() { case types.BoolFamily: - result.node = schema.NewBooleanNode(colName, defaultRepetitions, defaultSchemaFieldID) - result.colWriter = writeBool + result.node = schema.NewBooleanNode(colName, repetitions, defaultSchemaFieldID) + result.writeInvoker = writeScalar + result.writeFn = writeBool result.decoder = boolDecoder{} result.typ = types.Bool return result, nil case types.StringFamily: result.node, err = schema.NewPrimitiveNodeLogical(colName, - defaultRepetitions, schema.StringLogicalType{}, parquet.Types.ByteArray, + repetitions, schema.StringLogicalType{}, parquet.Types.ByteArray, defaultTypeLength, defaultSchemaFieldID) if err != nil { return result, err } - result.colWriter = writeString + result.writeInvoker = writeScalar + result.writeFn = writeString result.decoder = stringDecoder{} return result, nil case types.IntFamily: // Note: integer datums are always signed: https://www.cockroachlabs.com/docs/stable/int.html if typ.Oid() == oid.T_int8 { result.node, err = schema.NewPrimitiveNodeLogical(colName, - defaultRepetitions, schema.NewIntLogicalType(64, true), + repetitions, schema.NewIntLogicalType(64, true), parquet.Types.Int64, defaultTypeLength, defaultSchemaFieldID) if err != nil { return result, err } - result.colWriter = writeInt64 + result.writeInvoker = writeScalar + result.writeFn = writeInt64 result.decoder = int64Decoder{} return result, nil } - result.node = schema.NewInt32Node(colName, defaultRepetitions, defaultSchemaFieldID) - result.colWriter = writeInt32 + result.node = schema.NewInt32Node(colName, repetitions, defaultSchemaFieldID) + result.writeInvoker = writeScalar + result.writeFn = writeInt32 result.decoder = int32Decoder{} return result, nil case types.DecimalFamily: @@ -149,37 +153,71 @@ func makeColumn(colName string, typ *types.T) (column, error) { } result.node, err = schema.NewPrimitiveNodeLogical(colName, - defaultRepetitions, schema.NewDecimalLogicalType(precision, + repetitions, schema.NewDecimalLogicalType(precision, scale), parquet.Types.ByteArray, defaultTypeLength, defaultSchemaFieldID) if err != nil { return result, err } - result.colWriter = writeDecimal + result.writeInvoker = writeScalar + result.writeFn = writeDecimal result.decoder = decimalDecoder{} return result, nil case types.UuidFamily: result.node, err = schema.NewPrimitiveNodeLogical(colName, - defaultRepetitions, schema.UUIDLogicalType{}, + repetitions, schema.UUIDLogicalType{}, parquet.Types.FixedLenByteArray, uuid.Size, defaultSchemaFieldID) if err != nil { return result, err } - result.colWriter = writeUUID + result.writeInvoker = writeScalar + result.writeFn = writeUUID result.decoder = uUIDDecoder{} return result, nil case types.TimestampFamily: // Note that all timestamp datums are in UTC: https://www.cockroachlabs.com/docs/stable/timestamp.html result.node, err = schema.NewPrimitiveNodeLogical(colName, - defaultRepetitions, schema.StringLogicalType{}, parquet.Types.ByteArray, + repetitions, schema.StringLogicalType{}, parquet.Types.ByteArray, defaultTypeLength, defaultSchemaFieldID) if err != nil { return result, err } - - result.colWriter = writeTimestamp + result.writeInvoker = writeScalar + result.writeFn = writeTimestamp result.decoder = timestampDecoder{} return result, nil + case types.ArrayFamily: + // Arrays for type T are represented by the following: + // 
message schema { -- toplevel schema + // optional group a (LIST) { -- list column + // repeated group list { + // optional T element; + // } + // } + // } + // Representing arrays this way makes it easier to differentiate NULL, [NULL], + // and [] when encoding. + // There is more info about encoding arrays here: + // https://arrow.apache.org/blog/2022/10/08/arrow-parquet-encoding-part-2/ + elementCol, err := makeColumn("element", typ.ArrayContents(), parquet.Repetitions.Optional) + if err != nil { + return result, err + } + innerListFields := []schema.Node{elementCol.node} + innerListNode, err := schema.NewGroupNode("list", parquet.Repetitions.Repeated, innerListFields, defaultSchemaFieldID) + if err != nil { + return result, err + } + outerListFields := []schema.Node{innerListNode} + result.node, err = schema.NewGroupNodeLogical(colName, parquet.Repetitions.Optional, outerListFields, schema.ListLogicalType{}, defaultSchemaFieldID) + if err != nil { + return result, err + } + result.decoder = elementCol.decoder + result.writeInvoker = writeArray + result.writeFn = elementCol.writeFn + result.typ = elementCol.typ + return result, nil // TODO(#99028): implement support for the remaining types. // case types.INetFamily: @@ -196,8 +234,7 @@ func makeColumn(colName string, typ *types.T) (column, error) { // case types.TimeTZFamily: // case types.IntervalFamily: // case types.TimestampTZFamily: - // case types.ArrayFamily: default: - return result, pgerror.Newf(pgcode.FeatureNotSupported, "parquet export does not support the %v type yet", typ.Family()) + return result, pgerror.Newf(pgcode.FeatureNotSupported, "parquet export does not support the %v type", typ.Family()) } } diff --git a/pkg/util/parquet/testutils.go b/pkg/util/parquet/testutils.go index e0b22be47ed6..44428bc27bec 100644 --- a/pkg/util/parquet/testutils.go +++ b/pkg/util/parquet/testutils.go @@ -18,6 +18,7 @@ import ( "github.com/apache/arrow/go/v11/parquet" "github.com/apache/arrow/go/v11/parquet/file" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -55,27 +56,35 @@ func ReadFileAndVerifyDatums( for rg := 0; rg < reader.NumRowGroups(); rg++ { rgr := reader.RowGroup(rg) rowsInRowGroup := rgr.NumRows() - defLevels := make([]int16, rowsInRowGroup) for colIdx := 0; colIdx < numCols; colIdx++ { col, err := rgr.Column(colIdx) require.NoError(t, err) dec := writer.sch.cols[colIdx].decoder + typ := writer.sch.cols[colIdx].typ + + // Based on how we define schemas, we can detect an array by seeing if the + // primitive col reader has a max repetition level of 1. See comments above + // arrayEntryRepLevel for more info. 
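 // For example, an INT[] column named "a" is read as the single leaf column
 // "a.list.element", whose descriptor reports MaxRepetitionLevel() == 1 (and
 // MaxDefinitionLevel() == 3), while a scalar INT column reports
 // MaxRepetitionLevel() == 0.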
+ isArray := col.Descriptor().MaxRepetitionLevel() == 1 switch col.Type() { case parquet.Types.Boolean: - values := make([]bool, rowsInRowGroup) - readBatchHelper(t, col, rowsInRowGroup, values, defLevels) - decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels) + colDatums, read, err := readBatch(col, make([]bool, 1), dec, typ, isArray) + require.NoError(t, err) + require.Equal(t, rowsInRowGroup, read) + decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx) case parquet.Types.Int32: - values := make([]int32, numRows) - readBatchHelper(t, col, rowsInRowGroup, values, defLevels) - decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels) + colDatums, read, err := readBatch(col, make([]int32, 1), dec, typ, isArray) + require.NoError(t, err) + require.Equal(t, rowsInRowGroup, read) + decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx) case parquet.Types.Int64: - values := make([]int64, rowsInRowGroup) - readBatchHelper(t, col, rowsInRowGroup, values, defLevels) - decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels) + colDatums, read, err := readBatch(col, make([]int64, 1), dec, typ, isArray) + require.NoError(t, err) + require.Equal(t, rowsInRowGroup, read) + decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx) case parquet.Types.Int96: panic("unimplemented") case parquet.Types.Float: @@ -83,13 +92,15 @@ func ReadFileAndVerifyDatums( case parquet.Types.Double: panic("unimplemented") case parquet.Types.ByteArray: - values := make([]parquet.ByteArray, rowsInRowGroup) - readBatchHelper(t, col, rowsInRowGroup, values, defLevels) - decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels) + colDatums, read, err := readBatch(col, make([]parquet.ByteArray, 1), dec, typ, isArray) + require.NoError(t, err) + require.Equal(t, rowsInRowGroup, read) + decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx) case parquet.Types.FixedLenByteArray: - values := make([]parquet.FixedLenByteArray, rowsInRowGroup) - readBatchHelper(t, col, rowsInRowGroup, values, defLevels) - decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels) + colDatums, read, err := readBatch(col, make([]parquet.FixedLenByteArray, 1), dec, typ, isArray) + require.NoError(t, err) + require.Equal(t, rowsInRowGroup, read) + decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx) } } startingRowIdx += int(rowsInRowGroup) @@ -98,54 +109,107 @@ func ReadFileAndVerifyDatums( for i := 0; i < numRows; i++ { for j := 0; j < numCols; j++ { - assert.Equal(t, writtenDatums[i][j], readDatums[i][j]) + validateDatum(t, writtenDatums[i][j], readDatums[i][j]) } } } -func readBatchHelper[T parquetDatatypes]( - t *testing.T, r file.ColumnChunkReader, expectedrowsInRowGroup int64, values []T, defLvls []int16, -) { - numRead, err := readBatch(r, expectedrowsInRowGroup, values, defLvls) - require.NoError(t, err) - require.Equal(t, numRead, expectedrowsInRowGroup) -} - type batchReader[T parquetDatatypes] interface { ReadBatch(batchSize int64, values []T, defLvls []int16, repLvls []int16) (total int64, valuesRead int, err error) } +// readBatch reads all the datums in a row group for a column. 
func readBatch[T parquetDatatypes]( - r file.ColumnChunkReader, batchSize int64, values []T, defLvls []int16, -) (int64, error) { + r file.ColumnChunkReader, valueAlloc []T, dec decoder, typ *types.T, isArray bool, +) (tree.Datums, int64, error) { br, ok := r.(batchReader[T]) if !ok { - return 0, errors.AssertionFailedf("expected batchReader of type %T, but found %T instead", values, r) + return nil, 0, errors.AssertionFailedf("expected batchReader for type %T, but found %T instead", valueAlloc, r) } - numRowsRead, _, err := br.ReadBatch(batchSize, values, defLvls, nil) - return numRowsRead, err + + result := make([]tree.Datum, 0) + defLevels := [1]int16{} + repLevels := [1]int16{} + + for { + numRowsRead, _, err := br.ReadBatch(1, valueAlloc, defLevels[:], repLevels[:]) + if err != nil { + return nil, 0, err + } + if numRowsRead == 0 { + break + } + + if isArray { + // Replevel 0 indicates the start of a new array. + if repLevels[0] == 0 { + // Replevel 0, Deflevel 0 represents a NULL array. + if defLevels[0] == 0 { + result = append(result, tree.DNull) + continue + } + arrDatum := tree.NewDArray(typ) + result = append(result, arrDatum) + // Replevel 0, Deflevel 1 represents an array which is empty. + if defLevels[0] == 1 { + continue + } + } + currentArrayDatum := result[len(result)-1].(*tree.DArray) + // Deflevel 2 represents a null value in an array. + if defLevels[0] == 2 { + currentArrayDatum.Array = append(currentArrayDatum.Array, tree.DNull) + continue + } + // Deflevel 3 represents a non-null datum in an array. + d, err := decode(dec, valueAlloc[0]) + if err != nil { + return nil, 0, err + } + currentArrayDatum.Array = append(currentArrayDatum.Array, d) + } else { + // Deflevel 0 represents a null value + // Deflevel 1 represents a non-null value + d := tree.DNull + if defLevels[0] != 0 { + d, err = decode(dec, valueAlloc[0]) + if err != nil { + return nil, 0, err + } + } + result = append(result, d) + } + } + + return result, int64(len(result)), nil } -func decodeValuesIntoDatumsHelper[T parquetDatatypes]( - t *testing.T, - datums [][]tree.Datum, - colIdx int, - startingRowIdx int, - dec decoder, - values []T, - defLevels []int16, +func decodeValuesIntoDatumsHelper( + colDatums []tree.Datum, datumRows [][]tree.Datum, colIdx int, startingRowIdx int, ) { - var err error - // If the defLevel of a datum is 0, parquet will not write it to the column. - // Use valueReadIdx to only read from the front of the values array, where datums are defined. - valueReadIdx := 0 - for rowOffset, defLevel := range defLevels { - d := tree.DNull - if defLevel != 0 { - d, err = decode(dec, values[valueReadIdx]) - require.NoError(t, err) - valueReadIdx++ + for rowOffset, datum := range colDatums { + datumRows[startingRowIdx+rowOffset][colIdx] = datum + } +} + +// validateDatum validates that the "contents" of the expected datum matches the +// contents of the actual datum. For example, when validating two arrays, we +// only compare the datums in the arrays. We do not compare CRDB-specific +// metadata fields such as (tree.DArray).HasNulls or (tree.DArray).HasNonNulls. +// +// The reason for this special comparison that the parquet format is presently +// used in an export use case, so we only need to test roundtripability with +// data end users see. We do not need to check roundtripability of internal CRDB data. 
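//
// For example, the following pair (values shown informally) passes validation
// even though the CRDB-specific metadata differs:
//
//	expected: DArray{Array: {1, NULL}, HasNulls: true}
//	actual:   DArray{Array: {1, NULL}}          -- the reader does not set HasNulls
//
// Both arrays have length 2 and element-wise equal datums, which is all that
// is compared.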
+func validateDatum(t *testing.T, expected tree.Datum, actual tree.Datum) { + switch expected.ResolvedType().Family() { + case types.ArrayFamily: + arr1 := expected.(*tree.DArray).Array + arr2 := actual.(*tree.DArray).Array + require.Equal(t, len(arr1), len(arr2)) + for i := 0; i < len(arr1); i++ { + validateDatum(t, arr1[i], arr2[i]) } - datums[startingRowIdx+rowOffset][colIdx] = d + default: + require.Equal(t, expected, actual) } } diff --git a/pkg/util/parquet/write_functions.go b/pkg/util/parquet/write_functions.go index 021452fd2e90..a4e049d92168 100644 --- a/pkg/util/parquet/write_functions.go +++ b/pkg/util/parquet/write_functions.go @@ -44,53 +44,164 @@ type batchAlloc struct { fixedLenByteArrayBatch [1]parquet.FixedLenByteArray } +// The following variables are used when writing datums which are not in arrays. +// // nonNilDefLevel represents a definition level of 1, meaning that the value is non-nil. // nilDefLevel represents a definition level of 0, meaning that the value is nil. +// Any corresponding repetition level should be 0 as nonzero repetition levels are only valid for +// arrays in this library. // // For more info on definition levels, refer to -// https://github.com/apache/parquet-format/blob/master/README.md#nested-encoding. +// https://arrow.apache.org/blog/2022/10/05/arrow-parquet-encoding-part-1/ var nonNilDefLevel = []int16{1} var nilDefLevel = []int16{0} +// The following variables are used when writing datums which are in arrays. This explanation +// is valid for the array schema constructed in makeColumn. +// +// In summary: +// - def level 0 means the array is null +// - def level 1 means the array is not null, but is empty. +// - def level 2 means the array is not null, and contains a null datum +// - def level 3 means the array is not null, and contains a non-null datum +// - rep level 0 indicates the start of a new array (which may be null or non-null depending on the def level) +// - rep level 1 indicates that we are writing to an existing array +// +// Examples: +// +// Null Array +// d := tree.DNull +// writeFn(tree.DNull, ..., defLevels = [0], repLevels = [0]) +// +// Empty Array +// d := tree.NewDArray(types.Int) +// d.Array = tree.Datums{} +// writeFn(tree.DNull, ..., defLevels = [1], repLevels = [0]) +// +// # Multiple, Typical Arrays +// +// d := tree.NewDArray(types.Int) +// d.Array = tree.Datums{1, 2, NULL, 3, 4} +// d2 := tree.NewDArray(types.Int) +// d2.Array = tree.Datums{1, 1} +// writeFn(d.Array[0], ..., defLevels = [3], repLevels = [0]) -- repLevel 0 indicates the start of an array +// writeFn(d.Array[1], ..., defLevels = [3], repLevels = [1]) -- repLevel 1 writes the datum in the array +// writeFn(tree.DNull, ..., defLevels = [2], repLevels = [1]) -- defLevel 2 indicates a null datum +// writeFn(d.Array[3], ..., defLevels = [3], repLevels = [1]) +// writeFn(d.Array[4], ..., defLevels = [3], repLevels = [1]) +// +// writeFn(d2.Array[0], ..., defLevels = [3], repLevels = [0]) -- repLevel 0 indicates the start of a new array +// writeFn(d2.Array[1], ..., defLevels = [3], repLevels = [1]) +// +// For more info on definition levels and repetition levels, refer to +// https://arrow.apache.org/blog/2022/10/08/arrow-parquet-encoding-part-2/ +var newEntryRepLevel = []int16{0} +var arrayEntryRepLevel = []int16{1} +var nilArrayDefLevel = []int16{0} +var zeroLengthArrayDefLevel = []int16{1} +var arrayEntryNilDefLevel = []int16{2} +var arrayEntryNonNilDefLevel = []int16{3} + +func writeScalar(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, wFn 
writeFn) error { + if d == tree.DNull { + if err := wFn(tree.DNull, w, a, nilDefLevel, newEntryRepLevel); err != nil { + return err + } + return nil + } + + if err := wFn(d, w, a, nonNilDefLevel, newEntryRepLevel); err != nil { + return err + } + return nil +} + +func writeArray(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, wFn writeFn) error { + if d == tree.DNull { + return wFn(tree.DNull, w, a, nilArrayDefLevel, newEntryRepLevel) + } + di, ok := tree.AsDArray(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DArray, found %T", d) + } + if len(di.Array) == 0 { + return wFn(tree.DNull, w, a, zeroLengthArrayDefLevel, newEntryRepLevel) + } + + repLevel := newEntryRepLevel + for i, childDatum := range di.Array { + if i == 1 { + repLevel = arrayEntryRepLevel + } + if childDatum == tree.DNull { + if err := wFn(childDatum, w, a, arrayEntryNilDefLevel, repLevel); err != nil { + return err + } + } else { + if err := wFn(childDatum, w, a, arrayEntryNonNilDefLevel, repLevel); err != nil { + return err + } + } + } + return nil +} + +// A writeInvoker invokes a writeFn with the correct repetition levels and definition levels. +type writeInvoker func(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, wFn writeFn) error + // A writeFn encodes a datum and writes it using the provided column chunk writer. -type writeFn func(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc) error +// +// The caller is responsible for ensuring that the parameters are correct, since they are co-dependent; +// see executeWriteFn and defLevel/repLevel commentary above. +type writeFn func(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16) error -func writeInt32(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc) error { +func writeInt32( + d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16, +) error { if d == tree.DNull { - return writeNilBatch[int32](w) + return writeBatch[int32](w, a.int32Batch[:], defLevels, repLevels) } di, ok := tree.AsDInt(d) if !ok { return pgerror.Newf(pgcode.DatatypeMismatch, "expected DInt, found %T", d) } - return writeBatch[int32](w, a.int32Batch[:], int32(di)) + a.int32Batch[0] = int32(di) + return writeBatch[int32](w, a.int32Batch[:], defLevels, repLevels) } -func writeInt64(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc) error { +func writeInt64( + d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16, +) error { if d == tree.DNull { - return writeNilBatch[int64](w) + return writeBatch[int64](w, a.int64Batch[:], defLevels, repLevels) } di, ok := tree.AsDInt(d) if !ok { return pgerror.Newf(pgcode.DatatypeMismatch, "expected DInt, found %T", d) } - return writeBatch[int64](w, a.int64Batch[:], int64(di)) + a.int64Batch[0] = int64(di) + return writeBatch[int64](w, a.int64Batch[:], defLevels, repLevels) } -func writeBool(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc) error { +func writeBool( + d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16, +) error { if d == tree.DNull { - return writeNilBatch[bool](w) + return writeBatch[bool](w, a.boolBatch[:], defLevels, repLevels) } di, ok := tree.AsDBool(d) if !ok { return pgerror.Newf(pgcode.DatatypeMismatch, "expected DBool, found %T", d) } - return writeBatch[bool](w, a.boolBatch[:], bool(di)) + a.boolBatch[0] = bool(di) + return writeBatch[bool](w, a.boolBatch[:], defLevels, repLevels) } -func writeString(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc) error { 
+func writeString( + d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16, +) error { if d == tree.DNull { - return writeNilBatch[parquet.ByteArray](w) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) } di, ok := tree.AsDString(d) if !ok { @@ -101,7 +212,8 @@ func writeString(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc) error { if err != nil { return err } - return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], b) + a.byteArrayBatch[0] = b + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) } // unsafeGetBytes returns []byte in the underlying string, @@ -131,9 +243,11 @@ func unsafeGetBytes(s string) ([]byte, error) { return (*[maxStrLen]byte)(p)[:len(s):len(s)], nil } -func writeTimestamp(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc) error { +func writeTimestamp( + d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16, +) error { if d == tree.DNull { - return writeNilBatch[parquet.ByteArray](w) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) } _, ok := tree.AsDTimestamp(d) @@ -144,30 +258,37 @@ func writeTimestamp(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc) error fmtCtx := tree.NewFmtCtx(tree.FmtBareStrings) d.Format(fmtCtx) - return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], parquet.ByteArray(fmtCtx.CloseAndGetString())) + a.byteArrayBatch[0] = parquet.ByteArray(fmtCtx.CloseAndGetString()) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) } -func writeUUID(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc) error { +func writeUUID( + d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16, +) error { if d == tree.DNull { - return writeNilBatch[parquet.FixedLenByteArray](w) + return writeBatch[parquet.FixedLenByteArray](w, a.fixedLenByteArrayBatch[:], defLevels, repLevels) } di, ok := tree.AsDUuid(d) if !ok { return pgerror.Newf(pgcode.DatatypeMismatch, "expected DUuid, found %T", d) } - return writeBatch[parquet.FixedLenByteArray](w, a.fixedLenByteArrayBatch[:], di.UUID.GetBytes()) + a.fixedLenByteArrayBatch[0] = di.UUID.GetBytes() + return writeBatch[parquet.FixedLenByteArray](w, a.fixedLenByteArrayBatch[:], defLevels, repLevels) } -func writeDecimal(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc) error { +func writeDecimal( + d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16, +) error { if d == tree.DNull { - return writeNilBatch[parquet.ByteArray](w) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) } di, ok := tree.AsDDecimal(d) if !ok { return pgerror.Newf(pgcode.DatatypeMismatch, "expected DDecimal, found %T", d) } - return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], parquet.ByteArray(di.String())) + a.byteArrayBatch[0] = parquet.ByteArray(di.String()) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) } // parquetDatatypes are the physical types used in the parquet library. 
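To make the writeInvoker/writeFn split concrete: for an INT[] column, makeColumn sets
writeInvoker to writeArray and writeFn to writeInt64, and the writer invokes them together.
The following is a rough sketch, not part of this patch; col, cw, and ba stand in for the
column, its file.ColumnChunkWriter, and the writer's batchAlloc.

	arr := tree.NewDArray(types.Int)
	arr.Array = tree.Datums{tree.NewDInt(7), tree.DNull}

	// writeArray walks the array and calls writeInt64 once per element:
	//   writeInt64(7,          cw, ba, arrayEntryNonNilDefLevel /* [3] */, newEntryRepLevel   /* [0] */)
	//   writeInt64(tree.DNull, cw, ba, arrayEntryNilDefLevel    /* [2] */, arrayEntryRepLevel /* [1] */)
	if err := col.writeInvoker(arr, cw, ba, col.writeFn); err != nil {
		return err
	}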
@@ -181,22 +302,13 @@ type batchWriter[T parquetDatatypes] interface { WriteBatch(values []T, defLevels, repLevels []int16) (valueOffset int64, err error) } -func writeBatch[T parquetDatatypes](w file.ColumnChunkWriter, batchAlloc []T, v T) (err error) { - bw, ok := w.(batchWriter[T]) - if !ok { - return errors.AssertionFailedf("expected batchWriter of type %T, but found %T instead", []T(nil), w) - } - - batchAlloc[0] = v - _, err = bw.WriteBatch(batchAlloc, nonNilDefLevel, nil) - return err -} - -func writeNilBatch[T parquetDatatypes](w file.ColumnChunkWriter) (err error) { +func writeBatch[T parquetDatatypes]( + w file.ColumnChunkWriter, batchAlloc []T, defLevels, repLevels []int16, +) (err error) { bw, ok := w.(batchWriter[T]) if !ok { return errors.AssertionFailedf("expected batchWriter of type %T, but found %T instead", []T(nil), w) } - _, err = bw.WriteBatch([]T(nil), nilDefLevel, nil) + _, err = bw.WriteBatch(batchAlloc, defLevels, repLevels) return err } diff --git a/pkg/util/parquet/writer.go b/pkg/util/parquet/writer.go index 53971d566a9c..c08fd1e8daf5 100644 --- a/pkg/util/parquet/writer.go +++ b/pkg/util/parquet/writer.go @@ -113,8 +113,7 @@ func (w *Writer) writeDatumToColChunk(d tree.Datum, colIdx int) error { return err } - err = w.sch.cols[colIdx].colWriter(d, cw, w.ba) - if err != nil { + if err = w.sch.cols[colIdx].writeInvoker(d, cw, w.ba, w.sch.cols[colIdx].writeFn); err != nil { return err } return nil @@ -138,11 +137,6 @@ func (w *Writer) AddRow(datums []tree.Datum) error { } for idx, d := range datums { - // Note that EquivalentOrNull only allows null equivalence if the receiver is null. - if !d.ResolvedType().EquivalentOrNull(w.sch.cols[idx].typ, false) { - return errors.AssertionFailedf("expected datum of type %s, but found datum"+ - " of type: %s at column index %d", w.sch.cols[idx].typ.Name(), d.ResolvedType().Name(), idx) - } if err := w.writeDatumToColChunk(d, idx); err != nil { return err } diff --git a/pkg/util/parquet/writer_bench_test.go b/pkg/util/parquet/writer_bench_test.go index a41ebecb0c21..fe185a6dafaf 100644 --- a/pkg/util/parquet/writer_bench_test.go +++ b/pkg/util/parquet/writer_bench_test.go @@ -43,7 +43,7 @@ func BenchmarkParquetWriter(b *testing.B) { datums[i] = tree.NewDString(string(p)) } - fileName := "BenchmarkParquetWriter" + fileName := "BenchmarkParquetWriter.parquet" f, err := os.CreateTemp("", fileName) require.NoError(b, err) diff --git a/pkg/util/parquet/writer_test.go b/pkg/util/parquet/writer_test.go index 834ae7a8fcba..a076eb08f5b6 100644 --- a/pkg/util/parquet/writer_test.go +++ b/pkg/util/parquet/writer_test.go @@ -39,6 +39,8 @@ func newColSchema(numCols int) *colSchema { } } +// TODO (jayant): once all types are supported, we can use randgen.SeedTypes +// instead of this array. var supportedTypes = []*types.T{ types.Int, types.Bool, @@ -48,6 +50,20 @@ var supportedTypes = []*types.T{ types.Timestamp, } +func init() { + // Include all array types which are arrays of the scalar types above. + var arrayTypes []*types.T + for oid := range types.ArrayOids { + arrayTyp := types.OidToType[oid] + for _, typ := range supportedTypes { + if arrayTyp.InternalType.ArrayContents == typ { + arrayTypes = append(arrayTypes, arrayTyp) + } + } + } + supportedTypes = append(supportedTypes, arrayTypes...) 
+} + func makeRandDatums(numRows int, sch *colSchema, rng *rand.Rand) [][]tree.Datum { datums := make([][]tree.Datum, numRows) for i := 0; i < numRows; i++ { @@ -80,7 +96,7 @@ func TestRandomDatums(t *testing.T) { sch := makeRandSchema(numCols, rng) datums := makeRandDatums(numRows, sch, rng) - fileName := "TestRandomDatums" + fileName := "TestRandomDatums.parquet" f, err := os.CreateTemp("", fileName) require.NoError(t, err) @@ -199,6 +215,24 @@ func TestBasicDatums(t *testing.T) { }, nil }, }, + { + name: "array", + sch: &colSchema{ + columnTypes: []*types.T{types.IntArray, types.IntArray}, + columnNames: []string{"a", "b"}, + }, + datums: func() ([][]tree.Datum, error) { + da := tree.NewDArray(types.Int) + da.Array = tree.Datums{tree.NewDInt(0), tree.NewDInt(1)} + da2 := tree.NewDArray(types.Int) + da2.Array = tree.Datums{tree.NewDInt(2), tree.DNull} + da3 := tree.NewDArray(types.Int) + da3.Array = tree.Datums{} + return [][]tree.Datum{ + {da, da2}, {da3, tree.DNull}, + }, nil + }, + }, } { t.Run(tc.name, func(t *testing.T) { datums, err := tc.datums() @@ -207,7 +241,7 @@ func TestBasicDatums(t *testing.T) { numCols := len(datums[0]) maxRowGroupSize := int64(2) - fileName := "TestBasicDatums" + fileName := "TestBasicDatums.parquet" f, err := os.CreateTemp("", fileName) require.NoError(t, err) @@ -255,7 +289,7 @@ func TestInvalidWriterUsage(t *testing.T) { require.NoError(t, err) err = writer.AddRow([]tree.Datum{tree.NewDInt(0), datum}) - require.ErrorContains(t, err, "expected datum of type bool") + require.ErrorContains(t, err, "expected DBool") _ = writer.Close() }) @@ -266,7 +300,7 @@ func TestVersions(t *testing.T) { require.NoError(t, err) for version := range allowedVersions { - fileName := "TestVersions" + fileName := "TestVersions.parquet" f, err := os.CreateTemp("", fileName) require.NoError(t, err)
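For completeness, here is a hedged end-to-end usage sketch mirroring the new array test
case above. It is not part of the patch and assumes the package's writer constructor has
the shape NewWriter(sch, sink); see writer.go for the exact signature and options.

	package main

	import (
		"os"

		"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
		"github.com/cockroachdb/cockroach/pkg/sql/types"
		"github.com/cockroachdb/cockroach/pkg/util/parquet"
	)

	func main() {
		// One nullable INT[] column named "a".
		sch, err := parquet.NewSchema([]string{"a"}, []*types.T{types.IntArray})
		if err != nil {
			panic(err)
		}

		f, err := os.Create("example.parquet")
		if err != nil {
			panic(err)
		}

		// Assumed constructor: the diff shows the writer being used, not built.
		w, err := parquet.NewWriter(sch, f)
		if err != nil {
			panic(err)
		}

		arr := tree.NewDArray(types.Int)
		arr.Array = tree.Datums{tree.NewDInt(1), tree.DNull, tree.NewDInt(2)}

		// NULL arrays, empty arrays, and arrays containing NULLs all round-trip.
		for _, row := range [][]tree.Datum{{arr}, {tree.DNull}} {
			if err := w.AddRow(row); err != nil {
				panic(err)
			}
		}
		if err := w.Close(); err != nil {
			panic(err)
		}
	}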