From 6b9b940246aa26204fe8434d57c6f0f6e10d59e7 Mon Sep 17 00:00:00 2001
From: Jayant Shrivastava
Date: Tue, 28 Mar 2023 14:01:15 -0400
Subject: [PATCH] util/parquet: create parquet writer library

This change implements a `Writer` struct in the new `util/parquet` package.
The `Writer` takes datum rows and writes them to an `io.Writer` sink using a
configurable parquet version (defaults to v2.6).

The package implements several features internally required to write in the
parquet format:
- schema creation
- row group / column page management
- encoding/decoding of CRDB datums to and from parquet datums

Currently, the writer only supports the types found in the TPCC workload,
namely INT, DECIMAL, STRING, UUID, TIMESTAMP and BOOL.

This change also adds a benchmark and tests which verify the correctness of
the writer, as well as test utilities for reading datums from parquet files.

Informs: https://github.com/cockroachdb/cockroach/issues/99028
Epic: None

Release note: None
---
 pkg/BUILD.bazel                       |   4 +
 pkg/ccl/changefeedccl/BUILD.bazel     |   1 -
 pkg/sql/sem/tree/datum.go             |  10 +
 pkg/util/parquet/BUILD.bazel          |  50 +++++
 pkg/util/parquet/decoders.go          | 104 ++++++++++
 pkg/util/parquet/schema.go            | 209 ++++++++++++++++++++
 pkg/util/parquet/testutils.go         | 152 +++++++++++++++
 pkg/util/parquet/write_functions.go   | 143 ++++++++++++++
 pkg/util/parquet/writer.go            | 176 +++++++++++++++++
 pkg/util/parquet/writer_bench_test.go |  65 +++++++
 pkg/util/parquet/writer_test.go       | 268 ++++++++++++++++++++++++++
 11 files changed, 1181 insertions(+), 1 deletion(-)
 create mode 100644 pkg/util/parquet/BUILD.bazel
 create mode 100644 pkg/util/parquet/decoders.go
 create mode 100644 pkg/util/parquet/schema.go
 create mode 100644 pkg/util/parquet/testutils.go
 create mode 100644 pkg/util/parquet/write_functions.go
 create mode 100644 pkg/util/parquet/writer.go
 create mode 100644 pkg/util/parquet/writer_bench_test.go
 create mode 100644 pkg/util/parquet/writer_test.go

diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel
index 5f6fd4af88f9..76d718b6429f 100644
--- a/pkg/BUILD.bazel
+++ b/pkg/BUILD.bazel
@@ -628,6 +628,7 @@ ALL_TESTS = [
     "//pkg/util/netutil/addr:addr_test",
     "//pkg/util/netutil:netutil_test",
     "//pkg/util/optional:optional_test",
+    "//pkg/util/parquet:parquet_test",
     "//pkg/util/pprofutil:pprofutil_test",
     "//pkg/util/pretty:pretty_test",
     "//pkg/util/protoutil:protoutil_test",
@@ -2204,6 +2205,8 @@ GO_TARGETS = [
     "//pkg/util/netutil:netutil_test",
     "//pkg/util/optional:optional",
     "//pkg/util/optional:optional_test",
+    "//pkg/util/parquet:parquet",
+    "//pkg/util/parquet:parquet_test",
     "//pkg/util/pprofutil:pprofutil",
     "//pkg/util/pprofutil:pprofutil_test",
     "//pkg/util/pretty:pretty",
@@ -3292,6 +3295,7 @@ GET_X_DATA_TARGETS = [
     "//pkg/util/netutil:get_x_data",
     "//pkg/util/netutil/addr:get_x_data",
     "//pkg/util/optional:get_x_data",
+    "//pkg/util/parquet:get_x_data",
     "//pkg/util/pprofutil:get_x_data",
     "//pkg/util/pretty:get_x_data",
     "//pkg/util/protoutil:get_x_data",
diff --git a/pkg/ccl/changefeedccl/BUILD.bazel b/pkg/ccl/changefeedccl/BUILD.bazel
index 242fe97bfd3c..480f7c7cdd91 100644
--- a/pkg/ccl/changefeedccl/BUILD.bazel
+++ b/pkg/ccl/changefeedccl/BUILD.bazel
@@ -147,7 +147,6 @@ go_library(
         "//pkg/util/timeutil",
         "//pkg/util/tracing",
         "//pkg/util/uuid",
-        "@com_github_apache_arrow_go_v11//parquet",
         "@com_github_cockroachdb_apd_v3//:apd",
         "@com_github_cockroachdb_errors//:errors",
         "@com_github_cockroachdb_logtags//:logtags",
diff --git a/pkg/sql/sem/tree/datum.go b/pkg/sql/sem/tree/datum.go
index 166b5f530aa4..ef1bb8c1eacc 100644
---
a/pkg/sql/sem/tree/datum.go +++ b/pkg/sql/sem/tree/datum.go @@ -1038,6 +1038,16 @@ func MustBeDDecimal(e Expr) DDecimal { panic(errors.AssertionFailedf("expected *DDecimal, found %T", e)) } +// AsDDecimal attempts to retrieve a DDecimal from an Expr, returning a DDecimal and +// a flag signifying whether the assertion was successful. +func AsDDecimal(e Expr) (*DDecimal, bool) { + switch t := e.(type) { + case *DDecimal: + return t, true + } + return nil, false +} + // ParseDDecimal parses and returns the *DDecimal Datum value represented by the // provided string, or an error if parsing is unsuccessful. func ParseDDecimal(s string) (*DDecimal, error) { diff --git a/pkg/util/parquet/BUILD.bazel b/pkg/util/parquet/BUILD.bazel new file mode 100644 index 000000000000..55e18c649c7f --- /dev/null +++ b/pkg/util/parquet/BUILD.bazel @@ -0,0 +1,50 @@ +load("//build/bazelutil/unused_checker:unused.bzl", "get_x_data") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "parquet", + srcs = [ + "decoders.go", + "schema.go", + "testutils.go", + "write_functions.go", + "writer.go", + ], + importpath = "github.com/cockroachdb/cockroach/pkg/util/parquet", + visibility = ["//visibility:public"], + deps = [ + "//pkg/sql/pgwire/pgcode", + "//pkg/sql/pgwire/pgerror", + "//pkg/sql/sem/tree", + "//pkg/sql/types", + "//pkg/util/uuid", + "@com_github_apache_arrow_go_v11//parquet", + "@com_github_apache_arrow_go_v11//parquet/file", + "@com_github_apache_arrow_go_v11//parquet/schema", + "@com_github_cockroachdb_errors//:errors", + "@com_github_lib_pq//oid", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + ], +) + +go_test( + name = "parquet_test", + srcs = [ + "writer_bench_test.go", + "writer_test.go", + ], + args = ["-test.timeout=295s"], + embed = [":parquet"], + deps = [ + "//pkg/sql/randgen", + "//pkg/sql/sem/tree", + "//pkg/sql/types", + "//pkg/util/timeutil", + "//pkg/util/uuid", + "@com_github_apache_arrow_go_v11//parquet/file", + "@com_github_stretchr_testify//require", + ], +) + +get_x_data(name = "get_x_data") diff --git a/pkg/util/parquet/decoders.go b/pkg/util/parquet/decoders.go new file mode 100644 index 000000000000..bafd70c8bea9 --- /dev/null +++ b/pkg/util/parquet/decoders.go @@ -0,0 +1,104 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package parquet + +import ( + "time" + + "github.com/apache/arrow/go/v11/parquet" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/cockroachdb/errors" +) + +// decoder is used to store typedDecoders of various types in the same +// schema definition. 
+type decoder interface{} + +type typedDecoder[T parquetDatatypes] interface { + decoder + decode(v T) (tree.Datum, error) +} + +func decode[T parquetDatatypes](dec decoder, v T) (tree.Datum, error) { + td, ok := dec.(typedDecoder[T]) + if !ok { + return nil, errors.AssertionFailedf("expected typedDecoder[%T], but found %T", v, dec) + } + return td.decode(v) +} + +type boolDecoder struct{} + +func (boolDecoder) decode(v bool) (tree.Datum, error) { + return tree.MakeDBool(tree.DBool(v)), nil +} + +type stringDecoder struct{} + +func (stringDecoder) decode(v parquet.ByteArray) (tree.Datum, error) { + return tree.NewDString(string(v)), nil +} + +type int64Decoder struct{} + +func (int64Decoder) decode(v int64) (tree.Datum, error) { + return tree.NewDInt(tree.DInt(v)), nil +} + +type int32Decoder struct{} + +func (int32Decoder) decode(v int32) (tree.Datum, error) { + return tree.NewDInt(tree.DInt(v)), nil +} + +type decimalDecoder struct{} + +func (decimalDecoder) decode(v parquet.ByteArray) (tree.Datum, error) { + return tree.ParseDDecimal(string(v)) +} + +type timestampDecoder struct{} + +func (timestampDecoder) decode(v parquet.ByteArray) (tree.Datum, error) { + dtStr := string(v) + d, dependsOnCtx, err := tree.ParseDTimestamp(nil, dtStr, time.Microsecond) + if dependsOnCtx { + return nil, errors.New("TimestampTZ depends on context") + } + if err != nil { + return nil, err + } + // Converts the timezone from "loc(+0000)" to "UTC", which are equivalent. + d.Time = d.Time.UTC() + return d, nil +} + +type uUIDDecoder struct{} + +func (uUIDDecoder) decode(v parquet.FixedLenByteArray) (tree.Datum, error) { + uid, err := uuid.FromBytes(v) + if err != nil { + return nil, err + } + return tree.NewDUuid(tree.DUuid{UUID: uid}), nil +} + +// Defeat the linter's unused lint errors. +func init() { + var _, _ = boolDecoder{}.decode(false) + var _, _ = stringDecoder{}.decode(parquet.ByteArray{}) + var _, _ = int32Decoder{}.decode(0) + var _, _ = int64Decoder{}.decode(0) + var _, _ = decimalDecoder{}.decode(parquet.ByteArray{}) + var _, _ = timestampDecoder{}.decode(parquet.ByteArray{}) + var _, _ = uUIDDecoder{}.decode(parquet.FixedLenByteArray{}) +} diff --git a/pkg/util/parquet/schema.go b/pkg/util/parquet/schema.go new file mode 100644 index 000000000000..f8b79972a4c1 --- /dev/null +++ b/pkg/util/parquet/schema.go @@ -0,0 +1,209 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package parquet + +import ( + "math" + + "github.com/apache/arrow/go/v11/parquet" + "github.com/apache/arrow/go/v11/parquet/schema" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/cockroachdb/errors" + "github.com/lib/pq/oid" +) + +// A schema field is an internal identifier for schema nodes used by the parquet library. +// A value of -1 will let the library auto-assign values. This does not affect reading +// or writing parquet files. 
+const defaultSchemaFieldID = int32(-1)
+
+// The parquet library utilizes a type length of -1 for all types
+// except for the type parquet.FixedLenByteArray, in which case the type
+// length is the length of the array. See comment on (*schema.PrimitiveNode).TypeLength().
+const defaultTypeLength = -1
+
+// A column stores the parquet schema node, the write function, the decoder,
+// and the CRDB type for a single column.
+type column struct {
+	node      schema.Node
+	colWriter writeFn
+	decoder   decoder
+	typ       *types.T
+}
+
+// A SchemaDefinition stores a parquet schema.
+type SchemaDefinition struct {
+	// The index of a column when reading or writing parquet files
+	// will correspond to the column's index in this array.
+	cols []column
+
+	// The schema is a root node with terminal children nodes which represent
+	// primitive types such as int or bool. The individual columns can be
+	// traversed using schema.Column(i). The children are indexed from [0,
+	// len(cols)).
+	schema *schema.Schema
+}
+
+// NewSchema generates a SchemaDefinition.
+//
+// Columns in the returned SchemaDefinition will match
+// the order they appear in the supplied slices.
+func NewSchema(columnNames []string, columnTypes []*types.T) (*SchemaDefinition, error) {
+	if len(columnTypes) != len(columnNames) {
+		return nil, errors.AssertionFailedf("the number of column names must match the number of column types")
+	}
+
+	cols := make([]column, 0)
+	fields := make([]schema.Node, 0)
+
+	for i := 0; i < len(columnNames); i++ {
+		parquetCol, err := makeColumn(columnNames[i], columnTypes[i])
+		if err != nil {
+			return nil, err
+		}
+		cols = append(cols, parquetCol)
+		fields = append(fields, parquetCol.node)
+	}
+
+	groupNode, err := schema.NewGroupNode("schema", parquet.Repetitions.Required,
+		fields, defaultSchemaFieldID)
+	if err != nil {
+		return nil, err
+	}
+	return &SchemaDefinition{
+		cols:   cols,
+		schema: schema.NewSchema(groupNode),
+	}, nil
+}
+
+// makeColumn constructs a column.
+func makeColumn(colName string, typ *types.T) (column, error) {
+	// Setting parquet.Repetitions.Optional makes parquet interpret all columns as nullable.
+	// When writing data, we will specify a definition level of 0 (null) or 1 (not null).
+	// See https://blog.twitter.com/engineering/en_us/a/2013/dremel-made-simple-with-parquet
+	// for more information regarding definition levels.
+ defaultRepetitions := parquet.Repetitions.Optional + + result := column{} + var err error + switch typ.Family() { + case types.BoolFamily: + result.node = schema.NewBooleanNode(colName, defaultRepetitions, defaultSchemaFieldID) + result.colWriter = writeBool + result.decoder = boolDecoder{} + result.typ = types.Bool + return result, nil + case types.StringFamily: + result.node, err = schema.NewPrimitiveNodeLogical(colName, + defaultRepetitions, schema.StringLogicalType{}, parquet.Types.ByteArray, + defaultTypeLength, defaultSchemaFieldID) + + if err != nil { + return result, err + } + result.colWriter = writeString + result.decoder = stringDecoder{} + result.typ = types.String + return result, nil + case types.IntFamily: + // Note: integer datums are always signed: https://www.cockroachlabs.com/docs/stable/int.html + if typ.Oid() == oid.T_int8 { + result.node, err = schema.NewPrimitiveNodeLogical(colName, + defaultRepetitions, schema.NewIntLogicalType(64, true), + parquet.Types.Int64, defaultTypeLength, + defaultSchemaFieldID) + if err != nil { + return result, err + } + result.colWriter = writeInt64 + result.decoder = int64Decoder{} + result.typ = types.Int + return result, nil + } + + result.node = schema.NewInt32Node(colName, defaultRepetitions, defaultSchemaFieldID) + result.colWriter = writeInt32 + result.decoder = int32Decoder{} + result.typ = types.Int4 + return result, nil + case types.DecimalFamily: + // According to PostgresSQL docs, scale or precision of 0 implies max + // precision and scale. This code assumes that CRDB matches this behavior. + // https://www.postgresql.org/docs/10/datatype-numeric.html + precision := typ.Precision() + scale := typ.Scale() + if typ.Precision() == 0 { + precision = math.MaxInt32 + } + if typ.Scale() == 0 { + // Scale cannot exceed precision, so we do not set it to math.MaxInt32. + // This is relevant for cases when the precision is nonzero, but the scale is 0. + scale = precision + } + + result.node, err = schema.NewPrimitiveNodeLogical(colName, + defaultRepetitions, schema.NewDecimalLogicalType(precision, + scale), parquet.Types.ByteArray, defaultTypeLength, + defaultSchemaFieldID) + if err != nil { + return result, err + } + result.colWriter = writeDecimal + result.decoder = decimalDecoder{} + result.typ = types.Decimal + return result, nil + case types.UuidFamily: + result.node, err = schema.NewPrimitiveNodeLogical(colName, + defaultRepetitions, schema.UUIDLogicalType{}, + parquet.Types.FixedLenByteArray, uuid.Size, defaultSchemaFieldID) + if err != nil { + return result, err + } + result.colWriter = writeUUID + result.decoder = uUIDDecoder{} + result.typ = types.Uuid + return result, nil + case types.TimestampFamily: + // Note that all timestamp datums are in UTC: https://www.cockroachlabs.com/docs/stable/timestamp.html + result.node, err = schema.NewPrimitiveNodeLogical(colName, + defaultRepetitions, schema.StringLogicalType{}, parquet.Types.ByteArray, + defaultTypeLength, defaultSchemaFieldID) + if err != nil { + return result, err + } + + result.colWriter = writeTimestamp + result.decoder = timestampDecoder{} + result.typ = types.Timestamp + return result, nil + + // TODO(#99028): implement support for the remaining types. 
+ // case types.INetFamily: + // case types.JsonFamily: + // case types.FloatFamily: + // case types.BytesFamily: + // case types.BitFamily: + // case types.EnumFamily: + // case types.Box2DFamily: + // case types.GeographyFamily: + // case types.GeometryFamily: + // case types.DateFamily: + // case types.TimeFamily: + // case types.TimeTZFamily: + // case types.IntervalFamily: + // case types.TimestampTZFamily: + // case types.ArrayFamily: + default: + return result, pgerror.Newf(pgcode.FeatureNotSupported, "parquet export does not support the %v type yet", typ.Family()) + } +} diff --git a/pkg/util/parquet/testutils.go b/pkg/util/parquet/testutils.go new file mode 100644 index 000000000000..31b6629a833b --- /dev/null +++ b/pkg/util/parquet/testutils.go @@ -0,0 +1,152 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package parquet + +import ( + "math" + "os" + "testing" + + "github.com/apache/arrow/go/v11/parquet" + "github.com/apache/arrow/go/v11/parquet/file" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ReadFileAndVerifyDatums reads the parquetFile and first asserts its metadata +// matches numRows, numCols, and other options specified by the config. Then, it reads +// the file and asserts that it's data matches writtenDatums. +func ReadFileAndVerifyDatums( + t *testing.T, + parquetFile string, + numRows int, + numCols int, + cfg *Config, + sch *SchemaDefinition, + writtenDatums [][]tree.Datum, +) { + f, err := os.Open(parquetFile) + require.NoError(t, err) + + reader, err := file.NewParquetReader(f) + require.NoError(t, err) + + assert.Equal(t, reader.NumRows(), int64(numRows)) + assert.Equal(t, reader.MetaData().Schema.NumColumns(), numCols) + + numRowGroups := int(math.Ceil(float64(numRows) / float64(cfg.maxRowGroupLength))) + assert.EqualValues(t, numRowGroups, reader.NumRowGroups()) + + readDatums := make([][]tree.Datum, numRows) + for i := 0; i < numRows; i++ { + readDatums[i] = make([]tree.Datum, numCols) + } + + startingRowIdx := 0 + for rg := 0; rg < reader.NumRowGroups(); rg++ { + rgr := reader.RowGroup(rg) + rowsInRowGroup := rgr.NumRows() + defLevels := make([]int16, rowsInRowGroup) + + for colIdx := 0; colIdx < numCols; colIdx++ { + col, err := rgr.Column(colIdx) + require.NoError(t, err) + + dec := sch.cols[colIdx].decoder + + switch col.Type() { + case parquet.Types.Boolean: + values := make([]bool, rowsInRowGroup) + readBatchHelper(t, col, rowsInRowGroup, values, defLevels) + decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels) + case parquet.Types.Int32: + values := make([]int32, numRows) + readBatchHelper(t, col, rowsInRowGroup, values, defLevels) + decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels) + case parquet.Types.Int64: + values := make([]int64, rowsInRowGroup) + readBatchHelper(t, col, rowsInRowGroup, values, defLevels) + decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels) + case parquet.Types.Int96: + panic("unimplemented") + case parquet.Types.Float: + panic("unimplemented") + 
case parquet.Types.Double: + panic("unimplemented") + case parquet.Types.ByteArray: + values := make([]parquet.ByteArray, rowsInRowGroup) + readBatchHelper(t, col, rowsInRowGroup, values, defLevels) + decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels) + case parquet.Types.FixedLenByteArray: + values := make([]parquet.FixedLenByteArray, rowsInRowGroup) + readBatchHelper(t, col, rowsInRowGroup, values, defLevels) + decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels) + } + } + startingRowIdx += int(rowsInRowGroup) + } + require.NoError(t, reader.Close()) + + for i := 0; i < numRows; i++ { + for j := 0; j < numCols; j++ { + assert.Equal(t, writtenDatums[i][j], readDatums[i][j]) + } + } +} + +func readBatchHelper[T parquetDatatypes]( + t *testing.T, r file.ColumnChunkReader, expectedrowsInRowGroup int64, values []T, defLvls []int16, +) { + numRead, err := readBatch(r, expectedrowsInRowGroup, values, defLvls) + require.NoError(t, err) + require.Equal(t, numRead, expectedrowsInRowGroup) +} + +type batchReader[T parquetDatatypes] interface { + ReadBatch(batchSize int64, values []T, defLvls []int16, repLvls []int16) (total int64, valuesRead int, err error) +} + +func readBatch[T parquetDatatypes]( + r file.ColumnChunkReader, batchSize int64, values []T, defLvls []int16, +) (int64, error) { + br, ok := r.(batchReader[T]) + if !ok { + return 0, errors.AssertionFailedf("expected batchReader of type %T, but found %T instead", values, r) + } + numRowsRead, _, err := br.ReadBatch(batchSize, values, defLvls, nil) + return numRowsRead, err +} + +func decodeValuesIntoDatumsHelper[T parquetDatatypes]( + t *testing.T, + datums [][]tree.Datum, + colIdx int, + startingRowIdx int, + dec decoder, + values []T, + defLevels []int16, +) { + var err error + // If the defLevel of a datum is 0, parquet will not write it to the column. + // Use valueReadIdx to only read from the front of the values array, where datums are defined. + valueReadIdx := 0 + for rowOffset, defLevel := range defLevels { + d := tree.DNull + if defLevel != 0 { + d, err = decode(dec, values[valueReadIdx]) + require.NoError(t, err) + valueReadIdx++ + } + datums[startingRowIdx+rowOffset][colIdx] = d + } +} diff --git a/pkg/util/parquet/write_functions.go b/pkg/util/parquet/write_functions.go new file mode 100644 index 000000000000..853073be06e1 --- /dev/null +++ b/pkg/util/parquet/write_functions.go @@ -0,0 +1,143 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package parquet + +import ( + "github.com/apache/arrow/go/v11/parquet" + "github.com/apache/arrow/go/v11/parquet/file" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/errors" +) + +// nonNilDefLevel represents a definition level of 1, meaning that the value is non-nil. +// nilDefLevel represents a definition level of 0, meaning that the value is nil. +// +// For more info on definition levels, refer to +// https://github.com/apache/parquet-format/blob/master/README.md#nested-encoding. 
+var nonNilDefLevel = []int16{1} +var nilDefLevel = []int16{0} + +// A writeFn encodes a datum and writes using the provided column chunk writer. +type writeFn func(d tree.Datum, w file.ColumnChunkWriter) error + +func writeInt32(d tree.Datum, w file.ColumnChunkWriter) error { + if d == tree.DNull { + return writeNilBatch[int32](w) + } + di, ok := tree.AsDInt(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DInt, found %T", d) + } + return writeBatch[int32](w, int32(di)) +} + +func writeInt64(d tree.Datum, w file.ColumnChunkWriter) error { + if d == tree.DNull { + return writeNilBatch[int64](w) + } + di, ok := tree.AsDInt(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DInt, found %T", d) + } + return writeBatch[int64](w, int64(di)) +} + +func writeBool(d tree.Datum, w file.ColumnChunkWriter) error { + if d == tree.DNull { + return writeNilBatch[bool](w) + } + di, ok := tree.AsDBool(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DBool, found %T", d) + } + return writeBatch[bool](w, bool(di)) +} + +func writeString(d tree.Datum, w file.ColumnChunkWriter) error { + if d == tree.DNull { + return writeNilBatch[parquet.ByteArray](w) + } + di, ok := tree.AsDString(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DString, found %T", d) + } + return writeBatch[parquet.ByteArray](w, parquet.ByteArray(di)) +} + +func writeTimestamp(d tree.Datum, w file.ColumnChunkWriter) error { + if d == tree.DNull { + return writeNilBatch[parquet.ByteArray](w) + } + + _, ok := tree.AsDTimestamp(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DTimestamp, found %T", d) + } + + fmtCtx := tree.NewFmtCtx(tree.FmtBareStrings) + d.Format(fmtCtx) + + return writeBatch[parquet.ByteArray](w, parquet.ByteArray(fmtCtx.CloseAndGetString())) +} + +func writeUUID(d tree.Datum, w file.ColumnChunkWriter) error { + if d == tree.DNull { + return writeNilBatch[parquet.FixedLenByteArray](w) + } + + di, ok := tree.AsDUuid(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DUuid, found %T", d) + } + return writeBatch[parquet.FixedLenByteArray](w, di.UUID.GetBytes()) +} + +func writeDecimal(d tree.Datum, w file.ColumnChunkWriter) error { + if d == tree.DNull { + return writeNilBatch[parquet.ByteArray](w) + } + di, ok := tree.AsDDecimal(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DDecimal, found %T", d) + } + return writeBatch[parquet.ByteArray](w, parquet.ByteArray(di.String())) +} + +// parquetDatatypes are the physical types used in the parquet library. +type parquetDatatypes interface { + bool | int32 | int64 | parquet.ByteArray | parquet.FixedLenByteArray +} + +// batchWriter is an interface representing parquet column chunk writers such as +// file.Int64ColumnChunkWriter and file.BooleanColumnChunkWriter. 
+type batchWriter[T parquetDatatypes] interface {
+	WriteBatch(values []T, defLevels, repLevels []int16) (valueOffset int64, err error)
+}
+
+func writeBatch[T parquetDatatypes](w file.ColumnChunkWriter, v T) (err error) {
+	bw, ok := w.(batchWriter[T])
+	if !ok {
+		return errors.AssertionFailedf("expected batchWriter of type %T, but found %T instead", []T(nil), w)
+	}
+	_, err = bw.WriteBatch([]T{v}, nonNilDefLevel, nil)
+	return err
+}
+
+func writeNilBatch[T parquetDatatypes](w file.ColumnChunkWriter) (err error) {
+	bw, ok := w.(batchWriter[T])
+	if !ok {
+		return errors.AssertionFailedf("expected batchWriter of type %T, but found %T instead", []T(nil), w)
+	}
+	_, err = bw.WriteBatch([]T(nil), nilDefLevel, nil)
+	return err
+}
diff --git a/pkg/util/parquet/writer.go b/pkg/util/parquet/writer.go
new file mode 100644
index 000000000000..b0bf3ac931f9
--- /dev/null
+++ b/pkg/util/parquet/writer.go
@@ -0,0 +1,176 @@
+// Copyright 2023 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package parquet
+
+import (
+	"io"
+
+	"github.com/apache/arrow/go/v11/parquet"
+	"github.com/apache/arrow/go/v11/parquet/file"
+	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
+	"github.com/cockroachdb/errors"
+)
+
+// Config stores configurable options for the Writer.
+type Config struct {
+	maxRowGroupLength int64
+	version           parquet.Version
+}
+
+func newConfig() *Config {
+	return &Config{
+		maxRowGroupLength: parquet.DefaultMaxRowGroupLen,
+		version:           parquet.V2_6,
+	}
+}
+
+type option interface {
+	apply(c *Config) error
+}
+
+// WithMaxRowGroupLength specifies the maximum number of rows to include
+// in a row group when writing data.
+type WithMaxRowGroupLength int64
+
+func (l WithMaxRowGroupLength) apply(c *Config) error {
+	if l <= 0 {
+		return errors.AssertionFailedf("max group length must be greater than 0")
+	}
+
+	c.maxRowGroupLength = int64(l)
+	return nil
+}
+
+// WithVersion specifies the parquet version to use when writing data.
+// Valid options are "v1.0", "v2.4", and "v2.6".
+type WithVersion string
+
+func (v WithVersion) apply(c *Config) error {
+	if _, ok := allowedVersions[string(v)]; !ok {
+		return errors.AssertionFailedf("invalid version string")
+	}
+
+	c.version = allowedVersions[string(v)]
+	return nil
+}
+
+var allowedVersions = map[string]parquet.Version{
+	"v1.0": parquet.V1_0,
+	"v2.4": parquet.V2_4,
+	"v2.6": parquet.V2_6,
+}
+
+// A Writer writes datums into an io.Writer sink. The Writer should be Close()ed
+// before attempting to read from the output sink so parquet metadata is
+// available.
+type Writer struct {
+	sch    *SchemaDefinition
+	writer *file.Writer
+	cfg    *Config
+
+	currentRowGroupSize   int64
+	currentRowGroupWriter file.BufferedRowGroupWriter
+}
+
+// NewWriter constructs a new Writer which outputs to
+// the given sink.
+// +// TODO(#99028): maxRowGroupSize should be a configuration option, along with +// compression schemes, allocator, batch size, page size etc +func NewWriter(sch *SchemaDefinition, sink io.Writer, opts ...option) (*Writer, error) { + cfg := newConfig() + for _, opt := range opts { + err := opt.apply(cfg) + if err != nil { + return nil, err + } + } + + parquetOpts := []parquet.WriterProperty{parquet.WithCreatedBy("cockroachdb"), + parquet.WithVersion(cfg.version)} + props := parquet.NewWriterProperties(parquetOpts...) + writer := file.NewParquetWriter(sink, sch.schema.Root(), file.WithWriterProps(props)) + + return &Writer{ + sch: sch, + writer: writer, + cfg: cfg, + }, nil +} + +func (w *Writer) writeDatumToColChunk(d tree.Datum, colIdx int) error { + if colIdx >= len(w.sch.cols) { + return errors.AssertionFailedf("column index %d out of bounds for"+ + " array of size %d", colIdx, len(w.sch.cols)) + } + // Note that EquivalentOrNull only allows null equivalence if the receiver is null. + if !d.ResolvedType().EquivalentOrNull(w.sch.cols[colIdx].typ, false) { + return errors.AssertionFailedf("expected datum of type %s, but found datum"+ + " of type: %s", w.sch.cols[colIdx].typ.Name(), d.ResolvedType().Name()) + } + + cw, err := w.currentRowGroupWriter.Column(colIdx) + if err != nil { + return err + } + + err = w.sch.cols[colIdx].colWriter(d, cw) + if err != nil { + return err + } + return nil +} + +// SchemaDefinition returns the SchemaDefinition for this writer. +func (w *Writer) SchemaDefinition() *SchemaDefinition { + return w.sch +} + +// Config returns the Config for this writer. +func (w *Writer) Config() *Config { + return w.cfg +} + +// AddData writes a row. There is no guarantee that the row will +// immediately be flushed to the output sink. +// +// Datums should be in the same order as specified in the +// SchemaDefinition of the Writer. +func (w *Writer) AddData(datums []tree.Datum) error { + if w.currentRowGroupWriter == nil { + w.currentRowGroupWriter = w.writer.AppendBufferedRowGroup() + } else if w.currentRowGroupSize == w.cfg.maxRowGroupLength { + if err := w.currentRowGroupWriter.Close(); err != nil { + return err + } + w.currentRowGroupWriter = w.writer.AppendBufferedRowGroup() + w.currentRowGroupSize = 0 + } + + for cIdx := 0; cIdx < len(datums); cIdx++ { + if err := w.writeDatumToColChunk(datums[cIdx], cIdx); err != nil { + return err + } + } + w.currentRowGroupSize += 1 + return nil +} + +// Close closes the writer and flushes any buffered data to the sink. +// If the sink implements io.WriteCloser, it will be closed by this method. +func (w *Writer) Close() error { + if w.currentRowGroupWriter != nil { + if err := w.currentRowGroupWriter.Close(); err != nil { + return err + } + } + return w.writer.Close() +} diff --git a/pkg/util/parquet/writer_bench_test.go b/pkg/util/parquet/writer_bench_test.go new file mode 100644 index 000000000000..080f6bf3493a --- /dev/null +++ b/pkg/util/parquet/writer_bench_test.go @@ -0,0 +1,65 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package parquet + +import ( + "fmt" + "math/rand" + "os" + "testing" + + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/stretchr/testify/require" +) + +// BenchmarkParquetWriter benchmarks the Writer.AddData operation. +func BenchmarkParquetWriter(b *testing.B) { + rng := rand.New(rand.NewSource(timeutil.Now().UnixNano())) + + // Create a row size of 2KiB. + numCols := 16 + datumSizeBytes := 128 + sch := newColSchema(numCols) + for i := 0; i < numCols; i++ { + sch.columnTypes[i] = types.String + sch.columnNames[i] = fmt.Sprintf("col%d", i) + } + datums := make([]tree.Datum, numCols) + for i := 0; i < numCols; i++ { + p := make([]byte, datumSizeBytes) + _, _ = rng.Read(p) + tree.NewDBytes(tree.DBytes(p)) + datums[i] = tree.NewDString(string(p)) + } + + fileName := "BenchmarkParquetWriter" + f, err := os.CreateTemp("", fileName) + require.NoError(b, err) + + schemaDef, err := NewSchema(sch.columnNames, sch.columnTypes) + require.NoError(b, err) + + writer, err := NewWriter(schemaDef, f) + require.NoError(b, err) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + err := writer.AddData(datums) + require.NoError(b, err) + } + + err = writer.Close() + require.NoError(b, err) +} diff --git a/pkg/util/parquet/writer_test.go b/pkg/util/parquet/writer_test.go new file mode 100644 index 000000000000..5f27ff3a95d3 --- /dev/null +++ b/pkg/util/parquet/writer_test.go @@ -0,0 +1,268 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package parquet + +import ( + "fmt" + "math/rand" + "os" + "testing" + "time" + + "github.com/apache/arrow/go/v11/parquet/file" + "github.com/cockroachdb/cockroach/pkg/sql/randgen" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/stretchr/testify/require" +) + +type colSchema struct { + columnNames []string + columnTypes []*types.T +} + +func newColSchema(numCols int) *colSchema { + return &colSchema{ + columnNames: make([]string, numCols), + columnTypes: make([]*types.T, numCols), + } +} + +var supportedTypes = []*types.T{ + types.Int, + types.Bool, + types.String, + types.Decimal, + types.Uuid, + types.Timestamp, +} + +func makeRandDatums(numRows int, sch *colSchema, rng *rand.Rand) [][]tree.Datum { + datums := make([][]tree.Datum, numRows) + for i := 0; i < numRows; i++ { + datums[i] = make([]tree.Datum, len(sch.columnTypes)) + for j := 0; j < len(sch.columnTypes); j++ { + datums[i][j] = randgen.RandDatum(rng, sch.columnTypes[j], true) + } + } + return datums +} + +func makeRandSchema(numCols int, rng *rand.Rand) *colSchema { + sch := newColSchema(numCols) + for i := 0; i < numCols; i++ { + sch.columnTypes[i] = supportedTypes[rng.Intn(len(supportedTypes))] + sch.columnNames[i] = fmt.Sprintf("%s%d", sch.columnTypes[i].Name(), i) + } + return sch +} + +func TestRandomDatums(t *testing.T) { + seed := rand.NewSource(timeutil.Now().UnixNano()) + rng := rand.New(seed) + t.Logf("random seed %d", seed.Int63()) + + numRows := 25 + numCols := 10 + maxRowGroupSize := int64(4) + + sch := makeRandSchema(numCols, rng) + datums := makeRandDatums(numRows, sch, rng) + + fileName := "TestRandomDatums" + f, err := os.CreateTemp("", fileName) + require.NoError(t, err) + + schemaDef, err := NewSchema(sch.columnNames, sch.columnTypes) + require.NoError(t, err) + + writer, err := NewWriter(schemaDef, f, WithMaxRowGroupLength(maxRowGroupSize)) + require.NoError(t, err) + + for _, row := range datums { + err := writer.AddData(row) + require.NoError(t, err) + } + + err = writer.Close() + require.NoError(t, err) + + ReadFileAndVerifyDatums(t, f.Name(), numRows, numCols, writer.Config(), writer.SchemaDefinition(), datums) +} + +func TestBasicDatums(t *testing.T) { + for _, tc := range []struct { + name string + sch *colSchema + datums func() ([][]tree.Datum, error) + }{ + { + name: "bool", + sch: &colSchema{ + columnTypes: []*types.T{types.Bool, types.Bool, types.Bool}, + columnNames: []string{"a", "b", "c"}, + }, + datums: func() ([][]tree.Datum, error) { + return [][]tree.Datum{ + {tree.DBoolFalse, tree.DBoolTrue, tree.DNull}, + }, nil + }, + }, + { + name: "string", + sch: &colSchema{ + columnTypes: []*types.T{types.String, types.String, types.String}, + columnNames: []string{"a", "b", "c"}, + }, + datums: func() ([][]tree.Datum, error) { + return [][]tree.Datum{ + {tree.NewDString("a"), tree.NewDString(""), tree.DNull}}, nil + }, + }, + { + name: "timestamp", + sch: &colSchema{ + columnTypes: []*types.T{types.Timestamp, types.Timestamp, types.Timestamp}, + columnNames: []string{"a", "b", "c"}, + }, + datums: func() ([][]tree.Datum, error) { + return [][]tree.Datum{ + { + tree.MustMakeDTimestamp(timeutil.Now(), time.Microsecond), + tree.MustMakeDTimestamp(timeutil.Now(), time.Microsecond), + tree.DNull, + }, + }, nil + }, + }, + { + name: "int", + sch: &colSchema{ + columnTypes: []*types.T{types.Int4, types.Int, types.Int, types.Int2, 
types.Int2}, + columnNames: []string{"a", "b", "c", "d", "e"}, + }, + datums: func() ([][]tree.Datum, error) { + return [][]tree.Datum{ + {tree.NewDInt(1 << 16), tree.NewDInt(1 << 32), + tree.NewDInt(-1 * (1 << 32)), tree.NewDInt(12), tree.DNull}, + }, nil + }, + }, + { + name: "decimal", + sch: &colSchema{ + columnTypes: []*types.T{types.Decimal, types.Decimal, types.Decimal, types.Decimal}, + columnNames: []string{"a", "b", "c", "d"}, + }, + datums: func() ([][]tree.Datum, error) { + var err error + datums := make([]tree.Datum, 4) + if datums[0], err = tree.ParseDDecimal("-1.222"); err != nil { + return nil, err + } + if datums[1], err = tree.ParseDDecimal("-inf"); err != nil { + return nil, err + } + if datums[2], err = tree.ParseDDecimal("inf"); err != nil { + return nil, err + } + if datums[3], err = tree.ParseDDecimal("nan"); err != nil { + return nil, err + } + return [][]tree.Datum{datums}, nil + }, + }, + { + name: "uuid", + sch: &colSchema{ + columnTypes: []*types.T{types.Uuid, types.Uuid}, + columnNames: []string{"a", "b"}, + }, + datums: func() ([][]tree.Datum, error) { + uid, err := uuid.FromString("acde070d-8c4c-4f0d-9d8a-162843c10333") + if err != nil { + return nil, err + } + return [][]tree.Datum{ + {tree.NewDUuid(tree.DUuid{UUID: uid}), tree.DNull}, + }, nil + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + datums, err := tc.datums() + require.NoError(t, err) + numRows := len(datums) + numCols := len(datums[0]) + maxRowGroupSize := int64(2) + + fileName := "TestBasicDatums" + f, err := os.CreateTemp("", fileName) + require.NoError(t, err) + + schemaDef, err := NewSchema(tc.sch.columnNames, tc.sch.columnTypes) + require.NoError(t, err) + + writer, err := NewWriter(schemaDef, f, WithMaxRowGroupLength(maxRowGroupSize)) + require.NoError(t, err) + + for _, row := range datums { + err := writer.AddData(row) + require.NoError(t, err) + } + + err = writer.Close() + require.NoError(t, err) + + ReadFileAndVerifyDatums(t, f.Name(), numRows, numCols, writer.Config(), writer.SchemaDefinition(), datums) + }) + } +} + +func TestVersions(t *testing.T) { + schemaDef, err := NewSchema([]string{}, []*types.T{}) + require.NoError(t, err) + + for version := range allowedVersions { + fileName := "TestVersions" + f, err := os.CreateTemp("", fileName) + require.NoError(t, err) + + writer, err := NewWriter(schemaDef, f, WithVersion(version)) + require.NoError(t, err) + + err = writer.Close() + require.NoError(t, err) + + f, err = os.Open(f.Name()) + require.NoError(t, err) + + reader, err := file.NewParquetReader(f) + require.NoError(t, err) + + require.Equal(t, reader.MetaData().Version(), writer.Config().version) + + err = reader.Close() + require.NoError(t, err) + + err = os.Remove(f.Name()) + require.NoError(t, err) + } + + fileName := "TestVersions" + f, err := os.CreateTemp("", fileName) + require.NoError(t, err) + + _, err = NewWriter(schemaDef, f, WithVersion("invalid")) + require.Error(t, err) +}
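For reviewers who want to try the new package, the following is a minimal usage sketch based only on the API added in this patch (NewSchema, NewWriter, the WithMaxRowGroupLength and WithVersion options, AddData, and Close). The file name, column names, and values are illustrative and not part of the change.

package main

import (
	"os"

	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
	"github.com/cockroachdb/cockroach/pkg/sql/types"
	"github.com/cockroachdb/cockroach/pkg/util/parquet"
)

func main() {
	// Column order in the schema determines column order in the output file.
	schemaDef, err := parquet.NewSchema(
		[]string{"id", "name"},
		[]*types.T{types.Int, types.String},
	)
	if err != nil {
		panic(err)
	}

	// Any io.Writer works as a sink; a local file is used here for illustration.
	f, err := os.Create("example.parquet")
	if err != nil {
		panic(err)
	}

	writer, err := parquet.NewWriter(schemaDef, f,
		parquet.WithMaxRowGroupLength(8),
		parquet.WithVersion("v2.6"),
	)
	if err != nil {
		panic(err)
	}

	// Datums must appear in the same order as the columns in the schema.
	// tree.DNull writes a SQL NULL for any column.
	if err := writer.AddData([]tree.Datum{tree.NewDInt(1), tree.NewDString("a")}); err != nil {
		panic(err)
	}
	if err := writer.AddData([]tree.Datum{tree.NewDInt(2), tree.DNull}); err != nil {
		panic(err)
	}

	// Close flushes the buffered row group and writes the parquet footer;
	// the file is not readable as parquet until this runs.
	if err := writer.Close(); err != nil {
		panic(err)
	}
}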
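The definition-level comments in write_functions.go and testutils.go describe how NULLs round-trip: a null datum is written as an empty batch with definition level 0, non-null values are written densely with definition level 1, and the reader must re-expand them into row order. Below is a small, self-contained sketch of that read-side expansion, mirroring decodeValuesIntoDatumsHelper but with a plain int64 slice standing in for decoded datums; it is illustrative only and not part of the patch.

// rebuildColumn re-aligns densely packed column values with their rows using
// the definition levels returned by a parquet column reader: defLevels[i] == 0
// means row i is NULL and consumed no entry from values.
func rebuildColumn(defLevels []int16, values []int64) []*int64 {
	out := make([]*int64, len(defLevels))
	valueReadIdx := 0
	for rowOffset, defLevel := range defLevels {
		if defLevel != 0 {
			v := values[valueReadIdx]
			out[rowOffset] = &v
			valueReadIdx++
		}
	}
	// Example: defLevels [1, 0, 1] and values [7, 9] yield [&7, nil, &9].
	return out
}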