diff --git a/pkg/sql/sem/tree/datum.go b/pkg/sql/sem/tree/datum.go
index 4ab3a74dab5f..bb3e895c24a9 100644
--- a/pkg/sql/sem/tree/datum.go
+++ b/pkg/sql/sem/tree/datum.go
@@ -847,6 +847,20 @@ func MustBeDFloat(e Expr) DFloat {
 	panic(errors.AssertionFailedf("expected *DFloat, found %T", e))
 }
 
+// AsDFloat attempts to retrieve a DFloat from an Expr, returning a DFloat and
+// a flag signifying whether the assertion was successful. The function should
+// be used instead of direct type assertions wherever a *DFloat wrapped by a
+// *DOidWrapper is possible.
+func AsDFloat(e Expr) (*DFloat, bool) {
+	switch t := e.(type) {
+	case *DFloat:
+		return t, true
+	case *DOidWrapper:
+		return AsDFloat(t.Wrapped)
+	}
+	return nil, false
+}
+
 // NewDFloat is a helper routine to create a *DFloat initialized from its
 // argument.
 func NewDFloat(d DFloat) *DFloat {
@@ -1407,6 +1421,20 @@ func NewDCollatedString(
 	return &d, nil
 }
 
+// AsDCollatedString attempts to retrieve a DCollatedString from an Expr,
+// returning a DCollatedString and a flag signifying whether the assertion was
+// successful. The function should be used instead of direct type assertions
+// wherever a *DCollatedString wrapped by a *DOidWrapper is possible.
+func AsDCollatedString(e Expr) (DCollatedString, bool) {
+	switch t := e.(type) {
+	case *DCollatedString:
+		return *t, true
+	case *DOidWrapper:
+		return AsDCollatedString(t.Wrapped)
+	}
+	return DCollatedString{}, false
+}
+
 // AmbiguousFormat implements the Datum interface.
 func (*DCollatedString) AmbiguousFormat() bool { return false }
 
@@ -2286,6 +2314,20 @@ func MakeDTime(t timeofday.TimeOfDay) *DTime {
 	return &d
 }
 
+// AsDTime attempts to retrieve a DTime from an Expr, returning a DTime and
+// a flag signifying whether the assertion was successful. The function should
+// be used instead of direct type assertions wherever a *DTime wrapped by a
+// *DOidWrapper is possible.
+func AsDTime(e Expr) (DTime, bool) {
+	switch t := e.(type) {
+	case *DTime:
+		return *t, true
+	case *DOidWrapper:
+		return AsDTime(t.Wrapped)
+	}
+	return DTime(timeofday.FromInt(0)), false
+}
+
 // ParseDTime parses and returns the *DTime Datum value represented by the
 // provided string, or an error if parsing is unsuccessful.
 //
@@ -2434,6 +2476,20 @@ func NewDTimeTZFromLocation(t timeofday.TimeOfDay, loc *time.Location) *DTimeTZ {
 	return &DTimeTZ{timetz.MakeTimeTZFromLocation(t, loc)}
 }
 
+// AsDTimeTZ attempts to retrieve a DTimeTZ from an Expr, returning a DTimeTZ
+// and a flag signifying whether the assertion was successful. The function
+// should be used instead of direct type assertions wherever a *DTimeTZ
+// wrapped by a *DOidWrapper is possible.
+func AsDTimeTZ(e Expr) (DTimeTZ, bool) {
+	switch t := e.(type) {
+	case *DTimeTZ:
+		return *t, true
+	case *DOidWrapper:
+		return AsDTimeTZ(t.Wrapped)
+	}
+	return DTimeTZ{}, false
+}
+
 // ParseDTimeTZ parses and returns the *DTime Datum value represented by the
 // provided string, or an error if parsing is unsuccessful.
 //
@@ -3069,12 +3125,16 @@ type DInterval struct {
 	duration.Duration
 }
 
-// AsDInterval attempts to retrieve a DInterval from an Expr, panicking if the
-// assertion fails.
+// AsDInterval attempts to retrieve a DInterval from an Expr, returning a
+// DInterval and a flag signifying whether the assertion was successful. The
+// function should be used instead of direct type assertions wherever a
+// *DInterval wrapped by a *DOidWrapper is possible.
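+//
+// For illustration, the intended call pattern looks like the following sketch
+// (expr is a hypothetical Expr that may be a *DOidWrapper-wrapped *DInterval):
+//
+//	if iv, ok := AsDInterval(expr); ok {
+//		_ = iv.Duration // safe even when expr wrapped the *DInterval
+//	}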
 func AsDInterval(e Expr) (*DInterval, bool) {
 	switch t := e.(type) {
 	case *DInterval:
 		return t, true
+	case *DOidWrapper:
+		return AsDInterval(t.Wrapped)
 	}
 	return nil, false
 }
@@ -5017,6 +5077,20 @@ func NewDEnum(e DEnum) *DEnum {
 	return &e
 }
 
+// AsDEnum attempts to retrieve a DEnum from an Expr, returning a DEnum and
+// a flag signifying whether the assertion was successful. The function should
+// be used instead of direct type assertions wherever a *DEnum wrapped by a
+// *DOidWrapper is possible.
+func AsDEnum(e Expr) (*DEnum, bool) {
+	switch t := e.(type) {
+	case *DEnum:
+		return t, true
+	case *DOidWrapper:
+		return AsDEnum(t.Wrapped)
+	}
+	return nil, false
+}
+
 // MakeDEnumFromPhysicalRepresentation creates a DEnum of the input type
 // and the input physical representation.
 func MakeDEnumFromPhysicalRepresentation(typ *types.T, rep []byte) (DEnum, error) {
diff --git a/pkg/util/parquet/BUILD.bazel b/pkg/util/parquet/BUILD.bazel
index 9457cf1c2c26..482c7f86c7fc 100644
--- a/pkg/util/parquet/BUILD.bazel
+++ b/pkg/util/parquet/BUILD.bazel
@@ -13,11 +13,18 @@ go_library(
     importpath = "github.com/cockroachdb/cockroach/pkg/util/parquet",
     visibility = ["//visibility:public"],
     deps = [
+        "//pkg/geo",
+        "//pkg/geo/geopb",
         "//pkg/sql/pgwire/pgcode",
         "//pkg/sql/pgwire/pgerror",
+        "//pkg/sql/sem/catid",
         "//pkg/sql/sem/tree",
         "//pkg/sql/types",
         "//pkg/util",
+        "//pkg/util/bitarray",
+        "//pkg/util/duration",
+        "//pkg/util/encoding",
+        "//pkg/util/timeofday",
         "//pkg/util/uuid",
         "@com_github_apache_arrow_go_v11//parquet",
         "@com_github_apache_arrow_go_v11//parquet/file",
@@ -38,10 +45,15 @@ go_test(
     args = ["-test.timeout=295s"],
     embed = [":parquet"],
     deps = [
+        "//pkg/geo",
         "//pkg/sql/randgen",
         "//pkg/sql/sem/tree",
         "//pkg/sql/types",
+        "//pkg/util/bitarray",
+        "//pkg/util/duration",
+        "//pkg/util/ipaddr",
         "//pkg/util/timeutil",
+        "//pkg/util/timeutil/pgdate",
         "//pkg/util/uuid",
         "@com_github_apache_arrow_go_v11//parquet/file",
         "@com_github_stretchr_testify//require",
diff --git a/pkg/util/parquet/decoders.go b/pkg/util/parquet/decoders.go
index bafd70c8bea9..e3b4e17e8e0e 100644
--- a/pkg/util/parquet/decoders.go
+++ b/pkg/util/parquet/decoders.go
@@ -14,9 +14,15 @@ import (
 	"time"
 
 	"github.com/apache/arrow/go/v11/parquet"
+	"github.com/cockroachdb/cockroach/pkg/geo"
+	"github.com/cockroachdb/cockroach/pkg/geo/geopb"
 	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
+	"github.com/cockroachdb/cockroach/pkg/util/bitarray"
+	"github.com/cockroachdb/cockroach/pkg/util/duration"
+	"github.com/cockroachdb/cockroach/pkg/util/timeofday"
 	"github.com/cockroachdb/cockroach/pkg/util/uuid"
 	"github.com/cockroachdb/errors"
+	"github.com/lib/pq/oid"
 )
 
 // decoder is used to store typedDecoders of various types in the same
@@ -72,7 +78,23 @@ func (timestampDecoder) decode(v parquet.ByteArray) (tree.Datum, error) {
 	dtStr := string(v)
 	d, dependsOnCtx, err := tree.ParseDTimestamp(nil, dtStr, time.Microsecond)
 	if dependsOnCtx {
-		return nil, errors.New("TimestampTZ depends on context")
+		return nil, errors.Newf("decoding timestamp %s depends on context", v)
 	}
 	if err != nil {
 		return nil, err
 	}
+	// Convert the timezone from "loc(+0000)" to "UTC"; the two are equivalent.
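+	// (Illustrative note: time.Time.UTC returns the same instant with the
+	// Location set to UTC, so only the location metadata changes here.)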
+	d.Time = d.Time.UTC()
+	return d, nil
+}
+
+type timestampTZDecoder struct{}
+
+func (timestampTZDecoder) decode(v parquet.ByteArray) (tree.Datum, error) {
+	dtStr := string(v)
+	d, dependsOnCtx, err := tree.ParseDTimestampTZ(nil, dtStr, time.Microsecond)
+	if dependsOnCtx {
+		return nil, errors.Newf("decoding timestampTZ %s depends on context", v)
 	}
 	if err != nil {
 		return nil, err
@@ -92,6 +114,128 @@ func (uUIDDecoder) decode(v parquet.FixedLenByteArray) (tree.Datum, error) {
 	return tree.NewDUuid(tree.DUuid{UUID: uid}), nil
 }
 
+type iNetDecoder struct{}
+
+func (iNetDecoder) decode(v parquet.ByteArray) (tree.Datum, error) {
+	return tree.ParseDIPAddrFromINetString(string(v))
+}
+
+type jsonDecoder struct{}
+
+func (jsonDecoder) decode(v parquet.ByteArray) (tree.Datum, error) {
+	return tree.ParseDJSON(string(v))
+}
+
+type bitDecoder struct{}
+
+func (bitDecoder) decode(v parquet.ByteArray) (tree.Datum, error) {
+	ba, err := bitarray.Parse(string(v))
+	if err != nil {
+		return nil, err
+	}
+	return &tree.DBitArray{BitArray: ba}, nil
+}
+
+type bytesDecoder struct{}
+
+func (bytesDecoder) decode(v parquet.ByteArray) (tree.Datum, error) {
+	return tree.NewDBytes(tree.DBytes(v)), nil
+}
+
+type enumDecoder struct{}
+
+func (enumDecoder) decode(v parquet.ByteArray) (tree.Datum, error) {
+	return &tree.DEnum{
+		LogicalRep: string(v),
+	}, nil
+}
+
+type dateDecoder struct{}
+
+func (dateDecoder) decode(v parquet.ByteArray) (tree.Datum, error) {
+	d, dependsOnCtx, err := tree.ParseDDate(nil, string(v))
+	if dependsOnCtx {
+		return nil, errors.Newf("decoding date %s depends on context", v)
+	}
+	return d, err
+}
+
+type box2DDecoder struct{}
+
+func (box2DDecoder) decode(v parquet.ByteArray) (tree.Datum, error) {
+	b, err := geo.ParseCartesianBoundingBox(string(v))
+	if err != nil {
+		return nil, err
+	}
+	return tree.NewDBox2D(b), nil
+}
+
+type geographyDecoder struct{}
+
+func (geographyDecoder) decode(v parquet.ByteArray) (tree.Datum, error) {
+	g, err := geo.ParseGeographyFromEWKB(geopb.EWKB(v))
+	if err != nil {
+		return nil, err
+	}
+	return &tree.DGeography{Geography: g}, nil
+}
+
+type geometryDecoder struct{}
+
+func (geometryDecoder) decode(v parquet.ByteArray) (tree.Datum, error) {
+	g, err := geo.ParseGeometryFromEWKB(geopb.EWKB(v))
+	if err != nil {
+		return nil, err
+	}
+	return &tree.DGeometry{Geometry: g}, nil
+}
+
+type intervalDecoder struct{}
+
+func (intervalDecoder) decode(v parquet.ByteArray) (tree.Datum, error) {
+	return tree.ParseDInterval(duration.IntervalStyle_ISO_8601, string(v))
+}
+
+type timeDecoder struct{}
+
+func (timeDecoder) decode(v int64) (tree.Datum, error) {
+	return tree.MakeDTime(timeofday.TimeOfDay(v)), nil
+}
+
+type timeTZDecoder struct{}
+
+func (timeTZDecoder) decode(v parquet.ByteArray) (tree.Datum, error) {
+	d, dependsOnCtx, err := tree.ParseDTimeTZ(nil, string(v), time.Microsecond)
+	if dependsOnCtx {
+		return nil, errors.Newf("decoding timeTZ %s depends on context", v)
+	}
+	return d, err
+}
+
+type float32Decoder struct{}
+
+func (float32Decoder) decode(v float32) (tree.Datum, error) {
+	return tree.NewDFloat(tree.DFloat(v)), nil
+}
+
+type float64Decoder struct{}
+
+func (float64Decoder) decode(v float64) (tree.Datum, error) {
+	return tree.NewDFloat(tree.DFloat(v)), nil
+}
+
+type oidDecoder struct{}
+
+func (oidDecoder) decode(v int32) (tree.Datum, error) {
+	return tree.NewDOid(oid.Oid(v)), nil
+}
+
+type collatedStringDecoder struct{}
+
+func (collatedStringDecoder) decode(v parquet.ByteArray) (tree.Datum, error) {
+	return &tree.DCollatedString{Contents: string(v)}, nil
+}
+
 // Defeat the linter's unused lint errors.
 func init() {
 	var _, _ = boolDecoder{}.decode(false)
@@ -100,5 +244,22 @@ func init() {
 	var _, _ = int64Decoder{}.decode(0)
 	var _, _ = decimalDecoder{}.decode(parquet.ByteArray{})
 	var _, _ = timestampDecoder{}.decode(parquet.ByteArray{})
+	var _, _ = timestampTZDecoder{}.decode(parquet.ByteArray{})
 	var _, _ = uUIDDecoder{}.decode(parquet.FixedLenByteArray{})
+	var _, _ = iNetDecoder{}.decode(parquet.ByteArray{})
+	var _, _ = jsonDecoder{}.decode(parquet.ByteArray{})
+	var _, _ = bitDecoder{}.decode(parquet.ByteArray{})
+	var _, _ = bytesDecoder{}.decode(parquet.ByteArray{})
+	var _, _ = enumDecoder{}.decode(parquet.ByteArray{})
+	var _, _ = dateDecoder{}.decode(parquet.ByteArray{})
+	var _, _ = box2DDecoder{}.decode(parquet.ByteArray{})
+	var _, _ = geographyDecoder{}.decode(parquet.ByteArray{})
+	var _, _ = geometryDecoder{}.decode(parquet.ByteArray{})
+	var _, _ = intervalDecoder{}.decode(parquet.ByteArray{})
+	var _, _ = timeDecoder{}.decode(0)
+	var _, _ = timeTZDecoder{}.decode(parquet.ByteArray{})
+	var _, _ = float64Decoder{}.decode(0.0)
+	var _, _ = float32Decoder{}.decode(0.0)
+	var _, _ = oidDecoder{}.decode(0)
+	var _, _ = collatedStringDecoder{}.decode(parquet.ByteArray{})
 }
diff --git a/pkg/util/parquet/schema.go b/pkg/util/parquet/schema.go
index 0e817a7bde42..e68c1560a25e 100644
--- a/pkg/util/parquet/schema.go
+++ b/pkg/util/parquet/schema.go
@@ -168,7 +168,8 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c
 		result.decoder = uUIDDecoder{}
 		return result, nil
 	case types.TimestampFamily:
-		// Note that all timestamp datums are in UTC: https://www.cockroachlabs.com/docs/stable/timestamp.html
+		// We do not use schema.TimestampLogicalType because the library will enforce
+		// a physical type of int64, which is not sufficient for CRDB timestamps.
 		result.node, err = schema.NewPrimitiveNodeLogical(colName,
 			repetitions, schema.StringLogicalType{}, parquet.Types.ByteArray,
 			defaultTypeLength, defaultSchemaFieldID)
@@ -178,6 +179,180 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c
 		result.colWriter = scalarWriter(writeTimestamp)
 		result.decoder = timestampDecoder{}
 		return result, nil
+	case types.TimestampTZFamily:
+		// We do not use schema.TimestampLogicalType because the library will enforce
+		// a physical type of int64, which is not sufficient for CRDB timestamps.
+		result.node, err = schema.NewPrimitiveNodeLogical(colName,
+			repetitions, schema.StringLogicalType{}, parquet.Types.ByteArray,
+			defaultTypeLength, defaultSchemaFieldID)
+		if err != nil {
+			return result, err
+		}
+		result.colWriter = scalarWriter(writeTimestampTZ)
+		result.decoder = timestampTZDecoder{}
+		return result, nil
+	case types.INetFamily:
+		result.node, err = schema.NewPrimitiveNodeLogical(colName,
+			repetitions, schema.StringLogicalType{}, parquet.Types.ByteArray,
+			defaultTypeLength, defaultSchemaFieldID)
+		if err != nil {
+			return result, err
+		}
+		result.colWriter = scalarWriter(writeINet)
+		result.decoder = iNetDecoder{}
+		return result, nil
+	case types.JsonFamily:
+		result.node, err = schema.NewPrimitiveNodeLogical(colName,
+			repetitions, schema.JSONLogicalType{}, parquet.Types.ByteArray,
+			defaultTypeLength, defaultSchemaFieldID)
+		if err != nil {
+			return result, err
+		}
+		result.colWriter = scalarWriter(writeJSON)
+		result.decoder = jsonDecoder{}
+		return result, nil
+	case types.BitFamily:
+		result.node, err = schema.NewPrimitiveNode(colName,
+			repetitions, parquet.Types.ByteArray,
+			defaultTypeLength, defaultSchemaFieldID)
+		if err != nil {
+			return result, err
+		}
+		result.colWriter = scalarWriter(writeBit)
+		result.decoder = bitDecoder{}
+		return result, nil
+	case types.BytesFamily:
+		result.node, err = schema.NewPrimitiveNode(colName,
+			repetitions, parquet.Types.ByteArray,
+			defaultTypeLength, defaultSchemaFieldID)
+		if err != nil {
+			return result, err
+		}
+		result.colWriter = scalarWriter(writeBytes)
+		result.decoder = bytesDecoder{}
+		return result, nil
+	case types.EnumFamily:
+		result.node, err = schema.NewPrimitiveNodeLogical(colName,
+			repetitions, schema.EnumLogicalType{}, parquet.Types.ByteArray,
+			defaultTypeLength, defaultSchemaFieldID)
+		if err != nil {
+			return result, err
+		}
+		result.colWriter = scalarWriter(writeEnum)
+		result.decoder = enumDecoder{}
+		return result, nil
+	case types.DateFamily:
+		// We do not use schema.DateLogicalType because the library will enforce
+		// a physical type of int32, which is not sufficient for CRDB dates.
+ result.node, err = schema.NewPrimitiveNodeLogical(colName, + repetitions, schema.StringLogicalType{}, parquet.Types.ByteArray, + defaultTypeLength, defaultSchemaFieldID) + if err != nil { + return result, err + } + result.colWriter = scalarWriter(writeDate) + result.decoder = dateDecoder{} + return result, nil + case types.Box2DFamily: + result.node, err = schema.NewPrimitiveNodeLogical(colName, + repetitions, schema.StringLogicalType{}, parquet.Types.ByteArray, + defaultTypeLength, defaultSchemaFieldID) + if err != nil { + return result, err + } + result.colWriter = scalarWriter(writeBox2D) + result.decoder = box2DDecoder{} + return result, nil + case types.GeographyFamily: + result.node, err = schema.NewPrimitiveNode(colName, + repetitions, parquet.Types.ByteArray, + defaultTypeLength, defaultSchemaFieldID) + if err != nil { + return result, err + } + result.colWriter = scalarWriter(writeGeography) + result.decoder = geographyDecoder{} + return result, nil + case types.GeometryFamily: + result.node, err = schema.NewPrimitiveNode(colName, + repetitions, parquet.Types.ByteArray, + defaultTypeLength, defaultSchemaFieldID) + if err != nil { + return result, err + } + result.colWriter = scalarWriter(writeGeometry) + result.decoder = geometryDecoder{} + return result, nil + case types.IntervalFamily: + result.node, err = schema.NewPrimitiveNodeLogical(colName, + repetitions, schema.StringLogicalType{}, parquet.Types.ByteArray, + defaultTypeLength, defaultSchemaFieldID) + if err != nil { + return result, err + } + result.colWriter = scalarWriter(writeInterval) + result.decoder = intervalDecoder{} + return result, nil + case types.TimeFamily: + // CRDB stores time datums in microseconds, adjusted to UTC. + // See https://www.cockroachlabs.com/docs/stable/time.html. + result.node, err = schema.NewPrimitiveNodeLogical(colName, + repetitions, schema.NewTimeLogicalType(true, schema.TimeUnitMicros), parquet.Types.Int64, + defaultTypeLength, defaultSchemaFieldID) + if err != nil { + return result, err + } + result.colWriter = scalarWriter(writeTime) + result.decoder = timeDecoder{} + return result, nil + case types.TimeTZFamily: + // We cannot use the schema.NewTimeLogicalType because it does not support + // timezones. 
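+		// (As an illustrative example, a timetz value such as
+		// '05:40:05.4+03:00:00' must round-trip with its offset, which a
+		// time-of-day logical type has no room to carry.)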
+ result.node, err = schema.NewPrimitiveNodeLogical(colName, + repetitions, schema.StringLogicalType{}, parquet.Types.ByteArray, + defaultTypeLength, defaultSchemaFieldID) + if err != nil { + return result, err + } + result.colWriter = scalarWriter(writeTimeTZ) + result.decoder = timeTZDecoder{} + return result, nil + case types.FloatFamily: + if typ.Oid() == oid.T_float4 { + result.node, err = schema.NewPrimitiveNode(colName, + repetitions, parquet.Types.Float, + defaultTypeLength, defaultSchemaFieldID) + if err != nil { + return result, err + } + result.colWriter = scalarWriter(writeFloat32) + result.decoder = float32Decoder{} + return result, nil + } + result.node, err = schema.NewPrimitiveNode(colName, + repetitions, parquet.Types.Double, + defaultTypeLength, defaultSchemaFieldID) + if err != nil { + return result, err + } + result.colWriter = scalarWriter(writeFloat64) + result.decoder = float64Decoder{} + return result, nil + case types.OidFamily: + result.node = schema.NewInt32Node(colName, repetitions, defaultSchemaFieldID) + result.colWriter = scalarWriter(writeOid) + result.decoder = oidDecoder{} + return result, nil + case types.CollatedStringFamily: + result.node, err = schema.NewPrimitiveNodeLogical(colName, + repetitions, schema.StringLogicalType{}, parquet.Types.ByteArray, + defaultTypeLength, defaultSchemaFieldID) + if err != nil { + return result, err + } + result.colWriter = scalarWriter(writeCollatedString) + result.decoder = collatedStringDecoder{} + return result, nil case types.ArrayFamily: // Arrays for type T are represented by the following: // message schema { -- toplevel schema @@ -191,17 +366,20 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c // and [] when encoding. // There is more info about encoding arrays here: // https://arrow.apache.org/blog/2022/10/08/arrow-parquet-encoding-part-2/ - elementCol, err := makeColumn("element", typ.ArrayContents(), parquet.Repetitions.Optional) + elementCol, err := makeColumn("element", typ.ArrayContents(), + parquet.Repetitions.Optional) if err != nil { return result, err } innerListFields := []schema.Node{elementCol.node} - innerListNode, err := schema.NewGroupNode("list", parquet.Repetitions.Repeated, innerListFields, defaultSchemaFieldID) + innerListNode, err := schema.NewGroupNode("list", parquet.Repetitions.Repeated, + innerListFields, defaultSchemaFieldID) if err != nil { return result, err } outerListFields := []schema.Node{innerListNode} - result.node, err = schema.NewGroupNodeLogical(colName, parquet.Repetitions.Optional, outerListFields, schema.ListLogicalType{}, defaultSchemaFieldID) + result.node, err = schema.NewGroupNodeLogical(colName, parquet.Repetitions.Optional, + outerListFields, schema.ListLogicalType{}, defaultSchemaFieldID) if err != nil { return result, err } @@ -213,23 +391,8 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c result.colWriter = arrayWriter(scalarColWriter) result.typ = elementCol.typ return result, nil - - // TODO(#99028): implement support for the remaining types. 
- // case types.INetFamily: - // case types.JsonFamily: - // case types.FloatFamily: - // case types.BytesFamily: - // case types.BitFamily: - // case types.EnumFamily: - // case types.Box2DFamily: - // case types.GeographyFamily: - // case types.GeometryFamily: - // case types.DateFamily: - // case types.TimeFamily: - // case types.TimeTZFamily: - // case types.IntervalFamily: - // case types.TimestampTZFamily: default: - return result, pgerror.Newf(pgcode.FeatureNotSupported, "parquet export does not support the %v type", typ.Family()) + return result, pgerror.Newf(pgcode.FeatureNotSupported, + "parquet writer does not support the type family %v", typ.Family()) } } diff --git a/pkg/util/parquet/testutils.go b/pkg/util/parquet/testutils.go index 44428bc27bec..e87e5652c19a 100644 --- a/pkg/util/parquet/testutils.go +++ b/pkg/util/parquet/testutils.go @@ -17,8 +17,10 @@ import ( "github.com/apache/arrow/go/v11/parquet" "github.com/apache/arrow/go/v11/parquet/file" + "github.com/cockroachdb/cockroach/pkg/sql/sem/catid" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/encoding" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -88,9 +90,15 @@ func ReadFileAndVerifyDatums( case parquet.Types.Int96: panic("unimplemented") case parquet.Types.Float: - panic("unimplemented") + arrs, read, err := readBatch(col, make([]float32, 1), dec, typ, isArray) + require.NoError(t, err) + require.Equal(t, rowsInRowGroup, read) + decodeValuesIntoDatumsHelper(arrs, readDatums, colIdx, startingRowIdx) case parquet.Types.Double: - panic("unimplemented") + arrs, read, err := readBatch(col, make([]float64, 1), dec, typ, isArray) + require.NoError(t, err) + require.Equal(t, rowsInRowGroup, read) + decodeValuesIntoDatumsHelper(arrs, readDatums, colIdx, startingRowIdx) case parquet.Types.ByteArray: colDatums, read, err := readBatch(col, make([]parquet.ByteArray, 1), dec, typ, isArray) require.NoError(t, err) @@ -192,6 +200,15 @@ func decodeValuesIntoDatumsHelper( } } +func unwrapDatum(d tree.Datum) tree.Datum { + switch t := d.(type) { + case *tree.DOidWrapper: + return unwrapDatum(t.Wrapped) + default: + return d + } +} + // validateDatum validates that the "contents" of the expected datum matches the // contents of the actual datum. For example, when validating two arrays, we // only compare the datums in the arrays. We do not compare CRDB-specific @@ -201,7 +218,29 @@ func decodeValuesIntoDatumsHelper( // used in an export use case, so we only need to test roundtripability with // data end users see. We do not need to check roundtripability of internal CRDB data. func validateDatum(t *testing.T, expected tree.Datum, actual tree.Datum) { + // The randgen library may generate datums wrapped in a *tree.DOidWrapper, so + // we should unwrap them. We unwrap at this stage as opposed to when + // generating datums to test that the writer can handle wrapped datums. 
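+	// (Illustrative sketch: randgen may emit, e.g., a *tree.DString wrapped in
+	// a *tree.DOidWrapper for an alias type such as NAME; unwrapDatum returns
+	// the inner *tree.DString so the comparisons below see the real value.)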
+	expected = unwrapDatum(expected)
+
 	switch expected.ResolvedType().Family() {
+	case types.JsonFamily:
+		require.Equal(t, expected.(*tree.DJSON).JSON.String(),
+			actual.(*tree.DJSON).JSON.String())
+	case types.DateFamily:
+		require.Equal(t, expected.(*tree.DDate).Date.UnixEpochDays(),
+			actual.(*tree.DDate).Date.UnixEpochDays())
+	case types.FloatFamily:
+		if expected.ResolvedType().Equal(types.Float4) && expected.(*tree.DFloat).String() != "NaN" {
+			// CRDB currently doesn't truncate non-NaN float4s correctly, so this
+			// test does it manually :(
+			// https://github.com/cockroachdb/cockroach/issues/73743
+			e := float32(*expected.(*tree.DFloat))
+			a := float32(*actual.(*tree.DFloat))
+			require.Equal(t, e, a)
+		} else {
+			require.Equal(t, expected.String(), actual.String())
+		}
 	case types.ArrayFamily:
 		arr1 := expected.(*tree.DArray).Array
 		arr2 := actual.(*tree.DArray).Array
@@ -209,7 +248,31 @@ func validateDatum(t *testing.T, expected tree.Datum, actual tree.Datum) {
 		for i := 0; i < len(arr1); i++ {
 			validateDatum(t, arr1[i], arr2[i])
 		}
+	case types.EnumFamily:
+		require.Equal(t, expected.(*tree.DEnum).LogicalRep, actual.(*tree.DEnum).LogicalRep)
+	case types.CollatedStringFamily:
+		require.Equal(t, expected.(*tree.DCollatedString).Contents, actual.(*tree.DCollatedString).Contents)
 	default:
 		require.Equal(t, expected, actual)
 	}
 }
+
+func makeTestingEnumType() *types.T {
+	enumMembers := []string{"hi", "hello"}
+	enumType := types.MakeEnum(catid.TypeIDToOID(500), catid.TypeIDToOID(100500))
+	enumType.TypeMeta = types.UserDefinedTypeMetadata{
+		Name: &types.UserDefinedTypeName{
+			Schema: "test",
+			Name:   "greeting",
+		},
+		EnumData: &types.EnumMetadata{
+			LogicalRepresentations: enumMembers,
+			PhysicalRepresentations: [][]byte{
+				encoding.EncodeUntaggedIntValue(nil, 0),
+				encoding.EncodeUntaggedIntValue(nil, 1),
+			},
+			IsMemberReadOnly: make([]bool, len(enumMembers)),
+		},
+	}
+	return enumType
+}
diff --git a/pkg/util/parquet/write_functions.go b/pkg/util/parquet/write_functions.go
index 874e177a6984..7801597199f4 100644
--- a/pkg/util/parquet/write_functions.go
+++ b/pkg/util/parquet/write_functions.go
@@ -40,32 +40,35 @@ type batchAlloc struct {
 	boolBatch              [1]bool
 	int32Batch             [1]int32
 	int64Batch             [1]int64
+	float32Batch           [1]float32
+	float64Batch           [1]float64
 	byteArrayBatch         [1]parquet.ByteArray
 	fixedLenByteArrayBatch [1]parquet.FixedLenByteArray
 }
 
 // The following variables are used when writing datums which are not in arrays.
 //
-// nonNilDefLevel represents a definition level of 1, meaning that the value is non-nil.
-// nilDefLevel represents a definition level of 0, meaning that the value is nil.
-// Any corresponding repetition level should be 0 as nonzero repetition levels are only valid for
-// arrays in this library.
+// nonNilDefLevel represents a definition level of 1, meaning that the value is
+// non-nil. nilDefLevel represents a definition level of 0, meaning that the
+// value is nil. Any corresponding repetition level should be 0 as nonzero
+// repetition levels are only valid for arrays in this library.
 //
 // For more info on definition levels, refer to
 // https://arrow.apache.org/blog/2022/10/05/arrow-parquet-encoding-part-1/
 var nonNilDefLevel = []int16{1}
 var nilDefLevel = []int16{0}
 
-// The following variables are used when writing datums which are in arrays. This explanation
-// is valid for the array schema constructed in makeColumn.
+// The following variables are used when writing datums which are in arrays.
+// This explanation is valid for the array schema constructed in makeColumn. // // In summary: -// - def level 0 means the array is null -// - def level 1 means the array is not null, but is empty. -// - def level 2 means the array is not null, and contains a null datum -// - def level 3 means the array is not null, and contains a non-null datum -// - rep level 0 indicates the start of a new array (which may be null or non-null depending on the def level) -// - rep level 1 indicates that we are writing to an existing array +// - def level 0 means the array is null +// - def level 1 means the array is not null, but is empty. +// - def level 2 means the array is not null, and contains a null datum +// - def level 3 means the array is not null, and contains a non-null datum +// - rep level 0 indicates the start of a new array (which may be null or +// non-null depending on the def level) +// - rep level 1 indicates that we are writing to an existing array // // Examples: // @@ -84,13 +87,17 @@ var nilDefLevel = []int16{0} // d.Array = tree.Datums{1, 2, NULL, 3, 4} // d2 := tree.NewDArray(types.Int) // d2.Array = tree.Datums{1, 1} -// writeFn(d.Array[0], ..., defLevels = [3], repLevels = [0]) -- repLevel 0 indicates the start of an array -// writeFn(d.Array[1], ..., defLevels = [3], repLevels = [1]) -- repLevel 1 writes the datum in the array -// writeFn(tree.DNull, ..., defLevels = [2], repLevels = [1]) -- defLevel 2 indicates a null datum +// writeFn(d.Array[0], ..., defLevels = [3], repLevels = [0]) +// (repLevel 0 indicates the start of an array) +// writeFn(d.Array[1], ..., defLevels = [3], repLevels = [1]) +// (repLevel 1 writes the datum in the array) +// writeFn(tree.DNull, ..., defLevels = [2], repLevels = [1]) +// (defLevel 2 indicates a null datum) // writeFn(d.Array[3], ..., defLevels = [3], repLevels = [1]) // writeFn(d.Array[4], ..., defLevels = [3], repLevels = [1]) // -// writeFn(d2.Array[0], ..., defLevels = [3], repLevels = [0]) -- repLevel 0 indicates the start of a new array +// writeFn(d2.Array[0], ..., defLevels = [3], repLevels = [0]) +// (repLevel 0 indicates the start of a new array) // writeFn(d2.Array[1], ..., defLevels = [3], repLevels = [1]) // // For more info on definition levels and repetition levels, refer to @@ -104,24 +111,28 @@ var arrayEntryNonNilDefLevel = []int16{3} // A colWriter is responsible for writing a datum to a file.ColumnChunkWriter. 
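+//
+// A minimal usage sketch (cw, alloc, and fmtCtx are illustrative names, not
+// values defined in this package):
+//
+//	var w colWriter = scalarWriter(writeInt64)
+//	err := w.Write(tree.NewDInt(1), cw, alloc, fmtCtx)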
type colWriter interface { - Write(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc) error + Write(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, fmtCtx *tree.FmtCtx) error } type scalarWriter writeFn -func (w scalarWriter) Write(d tree.Datum, cw file.ColumnChunkWriter, a *batchAlloc) error { - return writeScalar(d, cw, a, writeFn(w)) +func (w scalarWriter) Write( + d tree.Datum, cw file.ColumnChunkWriter, a *batchAlloc, fmtCtx *tree.FmtCtx, +) error { + return writeScalar(d, cw, a, fmtCtx, writeFn(w)) } -func writeScalar(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, wFn writeFn) error { +func writeScalar( + d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, fmtCtx *tree.FmtCtx, wFn writeFn, +) error { if d == tree.DNull { - if err := wFn(tree.DNull, w, a, nilDefLevel, newEntryRepLevel); err != nil { + if err := wFn(tree.DNull, w, a, fmtCtx, nilDefLevel, newEntryRepLevel); err != nil { return err } return nil } - if err := wFn(d, w, a, nonNilDefLevel, newEntryRepLevel); err != nil { + if err := wFn(d, w, a, fmtCtx, nonNilDefLevel, newEntryRepLevel); err != nil { return err } return nil @@ -129,20 +140,24 @@ func writeScalar(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, wFn writ type arrayWriter writeFn -func (w arrayWriter) Write(d tree.Datum, cw file.ColumnChunkWriter, a *batchAlloc) error { - return writeArray(d, cw, a, writeFn(w)) +func (w arrayWriter) Write( + d tree.Datum, cw file.ColumnChunkWriter, a *batchAlloc, fmtCtx *tree.FmtCtx, +) error { + return writeArray(d, cw, a, fmtCtx, writeFn(w)) } -func writeArray(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, wFn writeFn) error { +func writeArray( + d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, fmtCtx *tree.FmtCtx, wFn writeFn, +) error { if d == tree.DNull { - return wFn(tree.DNull, w, a, nilArrayDefLevel, newEntryRepLevel) + return wFn(tree.DNull, w, a, nil, nilArrayDefLevel, newEntryRepLevel) } di, ok := tree.AsDArray(d) if !ok { return pgerror.Newf(pgcode.DatatypeMismatch, "expected DArray, found %T", d) } if len(di.Array) == 0 { - return wFn(tree.DNull, w, a, zeroLengthArrayDefLevel, newEntryRepLevel) + return wFn(tree.DNull, w, a, nil, zeroLengthArrayDefLevel, newEntryRepLevel) } repLevel := newEntryRepLevel @@ -151,11 +166,11 @@ func writeArray(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, wFn write repLevel = arrayEntryRepLevel } if childDatum == tree.DNull { - if err := wFn(childDatum, w, a, arrayEntryNilDefLevel, repLevel); err != nil { + if err := wFn(childDatum, w, a, fmtCtx, arrayEntryNilDefLevel, repLevel); err != nil { return err } } else { - if err := wFn(childDatum, w, a, arrayEntryNonNilDefLevel, repLevel); err != nil { + if err := wFn(childDatum, w, a, fmtCtx, arrayEntryNonNilDefLevel, repLevel); err != nil { return err } } @@ -163,12 +178,28 @@ func writeArray(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, wFn write return nil } -// A writeFn encodes a datum and writes it using the provided column chunk writer. -// The caller is responsible for ensuring that the def levels and rep levels are correct. -type writeFn func(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16) error +// A writeFn encodes a datum and writes it using the provided column chunk +// writer. The caller is responsible for ensuring that the def levels and rep +// levels are correct. 
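+//
+// For example, a non-null scalar is written with def level 1 and rep level 0
+// (a sketch; writeInt64 ignores its fmtCtx argument, so nil is acceptable):
+//
+//	err := writeInt64(d, w, a, nil /* fmtCtx */, nonNilDefLevel, newEntryRepLevel)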
+type writeFn func(d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, fmtCtx *tree.FmtCtx, + defLevels, repLevels []int16) error + +// formatDatum writes the datum into the parquet.ByteArray batch alloc using the +// tree.NodeFormatter interface. It is important that the fmtCtx remains open +// until after the bytes have been read from the batchAlloc, otherwise the byte +// slice may point to invalid data. +func formatDatum(d tree.Datum, a *batchAlloc, fmtCtx *tree.FmtCtx) { + fmtCtx.Reset() + d.Format(fmtCtx) + a.byteArrayBatch[0] = fmtCtx.Bytes() +} func writeInt32( - d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16, + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + _ *tree.FmtCtx, + defLevels, repLevels []int16, ) error { if d == tree.DNull { return writeBatch[int32](w, a.int32Batch[:], defLevels, repLevels) @@ -182,7 +213,11 @@ func writeInt32( } func writeInt64( - d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16, + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + _ *tree.FmtCtx, + defLevels, repLevels []int16, ) error { if d == tree.DNull { return writeBatch[int64](w, a.int64Batch[:], defLevels, repLevels) @@ -196,7 +231,11 @@ func writeInt64( } func writeBool( - d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16, + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + _ *tree.FmtCtx, + defLevels, repLevels []int16, ) error { if d == tree.DNull { return writeBatch[bool](w, a.boolBatch[:], defLevels, repLevels) @@ -210,7 +249,11 @@ func writeBool( } func writeString( - d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16, + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + _ *tree.FmtCtx, + defLevels, repLevels []int16, ) error { if d == tree.DNull { return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) @@ -256,7 +299,11 @@ func unsafeGetBytes(s string) ([]byte, error) { } func writeTimestamp( - d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16, + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + fmtCtx *tree.FmtCtx, + defLevels, repLevels []int16, ) error { if d == tree.DNull { return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) @@ -267,15 +314,36 @@ func writeTimestamp( return pgerror.Newf(pgcode.DatatypeMismatch, "expected DTimestamp, found %T", d) } - fmtCtx := tree.NewFmtCtx(tree.FmtBareStrings) - d.Format(fmtCtx) + formatDatum(d, a, fmtCtx) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) +} - a.byteArrayBatch[0] = parquet.ByteArray(fmtCtx.CloseAndGetString()) +func writeTimestampTZ( + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + fmtCtx *tree.FmtCtx, + defLevels, repLevels []int16, +) error { + if d == tree.DNull { + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) + } + + _, ok := tree.AsDTimestampTZ(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DTimestampTZ, found %T", d) + } + + formatDatum(d, a, fmtCtx) return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) } func writeUUID( - d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16, + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + _ *tree.FmtCtx, + defLevels, repLevels []int16, ) error { if d == tree.DNull { return writeBatch[parquet.FixedLenByteArray](w, 
a.fixedLenByteArrayBatch[:], defLevels, repLevels) @@ -290,22 +358,334 @@ func writeUUID( } func writeDecimal( - d tree.Datum, w file.ColumnChunkWriter, a *batchAlloc, defLevels, repLevels []int16, + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + fmtCtx *tree.FmtCtx, + defLevels, repLevels []int16, ) error { if d == tree.DNull { return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) } - di, ok := tree.AsDDecimal(d) + _, ok := tree.AsDDecimal(d) if !ok { return pgerror.Newf(pgcode.DatatypeMismatch, "expected DDecimal, found %T", d) } - a.byteArrayBatch[0] = parquet.ByteArray(di.String()) + formatDatum(d, a, fmtCtx) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) +} + +func writeINet( + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + fmtCtx *tree.FmtCtx, + defLevels, repLevels []int16, +) error { + if d == tree.DNull { + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) + } + _, ok := tree.AsDIPAddr(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DIPAddr, found %T", d) + } + + formatDatum(d, a, fmtCtx) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) +} + +func writeJSON( + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + fmtCtx *tree.FmtCtx, + defLevels, repLevels []int16, +) error { + if d == tree.DNull { + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) + } + _, ok := tree.AsDJSON(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DJSON, found %T", d) + } + + formatDatum(d, a, fmtCtx) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) +} + +func writeBit( + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + fmtCtx *tree.FmtCtx, + defLevels, repLevels []int16, +) error { + if d == tree.DNull { + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) + } + _, ok := tree.AsDBitArray(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DBitArray, found %T", d) + } + + formatDatum(d, a, fmtCtx) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) +} + +func writeBytes( + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + _ *tree.FmtCtx, + defLevels, repLevels []int16, +) error { + if d == tree.DNull { + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) + } + di, ok := tree.AsDBytes(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DBytes, found %T", d) + } + b, err := unsafeGetBytes(string(di)) + if err != nil { + return err + } + + a.byteArrayBatch[0] = b + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) +} + +func writeEnum( + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + _ *tree.FmtCtx, + defLevels, repLevels []int16, +) error { + if d == tree.DNull { + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) + } + di, ok := tree.AsDEnum(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DEnum, found %T", d) + } + b, err := unsafeGetBytes(di.LogicalRep) + if err != nil { + return err + } + + a.byteArrayBatch[0] = b + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) +} + +func writeDate( + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + fmtCtx *tree.FmtCtx, + defLevels, repLevels []int16, +) error { + if d == tree.DNull { 
+ return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) + } + _, ok := tree.AsDDate(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DDate, found %T", d) + } + + formatDatum(d, a, fmtCtx) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) +} + +func writeBox2D( + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + fmtCtx *tree.FmtCtx, + defLevels, repLevels []int16, +) error { + if d == tree.DNull { + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) + } + _, ok := tree.AsDBox2D(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DBox2D, found %T", d) + } + formatDatum(d, a, fmtCtx) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) +} + +func writeGeography( + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + _ *tree.FmtCtx, + defLevels, repLevels []int16, +) error { + if d == tree.DNull { + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) + } + di, ok := tree.AsDGeography(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DGeography, found %T", d) + } + + a.byteArrayBatch[0] = parquet.ByteArray(di.EWKB()) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) +} + +func writeGeometry( + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + _ *tree.FmtCtx, + defLevels, repLevels []int16, +) error { + if d == tree.DNull { + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) + } + di, ok := tree.AsDGeometry(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DGeometry, found %T", d) + } + a.byteArrayBatch[0] = parquet.ByteArray(di.EWKB()) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) +} + +func writeInterval( + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + fmtCtx *tree.FmtCtx, + defLevels, repLevels []int16, +) error { + if d == tree.DNull { + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) + } + _, ok := tree.AsDInterval(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DInterval, found %T", d) + } + + formatDatum(d, a, fmtCtx) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) +} + +func writeTime( + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + _ *tree.FmtCtx, + defLevels, repLevels []int16, +) error { + if d == tree.DNull { + return writeBatch[int64](w, a.int64Batch[:], defLevels, repLevels) + } + di, ok := tree.AsDTime(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DTime, found %T", d) + } + a.int64Batch[0] = int64(di) + return writeBatch[int64](w, a.int64Batch[:], defLevels, repLevels) +} + +func writeTimeTZ( + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + fmtCtx *tree.FmtCtx, + defLevels, repLevels []int16, +) error { + if d == tree.DNull { + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) + } + _, ok := tree.AsDTimeTZ(d) + if !ok { + return pgerror.Newf(pgcode.DatatypeMismatch, "expected DTimeTZ, found %T", d) + } + formatDatum(d, a, fmtCtx) + return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels) +} + +func writeFloat32( + d tree.Datum, + w file.ColumnChunkWriter, + a *batchAlloc, + _ *tree.FmtCtx, + defLevels, repLevels []int16, +) error { + if d == tree.DNull { + return writeBatch[float32](w, 
a.float32Batch[:], defLevels, repLevels)
+	}
+	di, ok := tree.AsDFloat(d)
+	if !ok {
+		return pgerror.Newf(pgcode.DatatypeMismatch, "expected DFloat, found %T", d)
+	}
+	a.float32Batch[0] = float32(*di)
+	return writeBatch[float32](w, a.float32Batch[:], defLevels, repLevels)
+}
+
+func writeFloat64(
+	d tree.Datum,
+	w file.ColumnChunkWriter,
+	a *batchAlloc,
+	_ *tree.FmtCtx,
+	defLevels, repLevels []int16,
+) error {
+	if d == tree.DNull {
+		return writeBatch[float64](w, a.float64Batch[:], defLevels, repLevels)
+	}
+	di, ok := tree.AsDFloat(d)
+	if !ok {
+		return pgerror.Newf(pgcode.DatatypeMismatch, "expected DFloat, found %T", d)
+	}
+	a.float64Batch[0] = float64(*di)
+	return writeBatch[float64](w, a.float64Batch[:], defLevels, repLevels)
+}
+
+func writeOid(
+	d tree.Datum,
+	w file.ColumnChunkWriter,
+	a *batchAlloc,
+	_ *tree.FmtCtx,
+	defLevels, repLevels []int16,
+) error {
+	if d == tree.DNull {
+		return writeBatch[int32](w, a.int32Batch[:], defLevels, repLevels)
+	}
+	di, ok := tree.AsDOid(d)
+	if !ok {
+		return pgerror.Newf(pgcode.DatatypeMismatch, "expected DOid, found %T", d)
+	}
+	a.int32Batch[0] = int32(di.Oid)
+	return writeBatch[int32](w, a.int32Batch[:], defLevels, repLevels)
+}
+
+func writeCollatedString(
+	d tree.Datum,
+	w file.ColumnChunkWriter,
+	a *batchAlloc,
+	_ *tree.FmtCtx,
+	defLevels, repLevels []int16,
+) error {
+	if d == tree.DNull {
+		return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels)
+	}
+	di, ok := tree.AsDCollatedString(d)
+	if !ok {
+		return pgerror.Newf(pgcode.DatatypeMismatch, "expected DCollatedString, found %T", d)
+	}
+	b, err := unsafeGetBytes(di.Contents)
+	if err != nil {
+		return err
+	}
+	a.byteArrayBatch[0] = b
 	return writeBatch[parquet.ByteArray](w, a.byteArrayBatch[:], defLevels, repLevels)
 }
 
 // parquetDatatypes are the physical types used in the parquet library.
 type parquetDatatypes interface {
-	bool | int32 | int64 | parquet.ByteArray | parquet.FixedLenByteArray
+	bool | int32 | int64 | float32 | float64 | parquet.ByteArray | parquet.FixedLenByteArray
 }
 
 // batchWriter is an interface representing parquet column chunk writers such as
diff --git a/pkg/util/parquet/writer.go b/pkg/util/parquet/writer.go
index b4e5527e1cbb..e26aa5d8eb1f 100644
--- a/pkg/util/parquet/writer.go
+++ b/pkg/util/parquet/writer.go
@@ -113,9 +113,14 @@ func (w *Writer) writeDatumToColChunk(d tree.Datum, colIdx int) error {
 		return err
 	}
 
-	if err = w.sch.cols[colIdx].colWriter.Write(d, cw, w.ba); err != nil {
+	// tree.NewFmtCtx uses an underlying pool, so we can assume there is no
+	// allocation here.
+	fmtCtx := tree.NewFmtCtx(tree.FmtExport)
+	defer fmtCtx.Close()
+	if err = w.sch.cols[colIdx].colWriter.Write(d, cw, w.ba, fmtCtx); err != nil {
 		return err
 	}
+
 	return nil
 }
diff --git a/pkg/util/parquet/writer_bench_test.go b/pkg/util/parquet/writer_bench_test.go
index fe185a6dafaf..85de72e616f8 100644
--- a/pkg/util/parquet/writer_bench_test.go
+++ b/pkg/util/parquet/writer_bench_test.go
@@ -11,56 +11,65 @@ import (
-	"fmt"
 	"math/rand"
 	"os"
 	"testing"
 
-	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
 	"github.com/cockroachdb/cockroach/pkg/sql/types"
 	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
 	"github.com/stretchr/testify/require"
 )
 
 // BenchmarkParquetWriter benchmarks the Writer.AddRow operation.
-// TODO(jayant): add more types to this benchmark.
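+// The per-type sub-benchmarks below can be run with the standard Go tooling,
+// e.g. (illustrative invocation):
+//
+//	go test ./pkg/util/parquet -bench BenchmarkParquetWriter -benchmem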
func BenchmarkParquetWriter(b *testing.B) { rng := rand.New(rand.NewSource(timeutil.Now().UnixNano())) + numCols := 32 + benchmarkTypes := getBenchmarkTypes() - // Create a row size of 2KiB. - numCols := 16 - datumSizeBytes := 128 - sch := newColSchema(numCols) - for i := 0; i < numCols; i++ { - sch.columnTypes[i] = types.String - sch.columnNames[i] = fmt.Sprintf("col%d", i) - } - datums := make([]tree.Datum, numCols) - for i := 0; i < numCols; i++ { - p := make([]byte, datumSizeBytes) - _, _ = rng.Read(p) - tree.NewDBytes(tree.DBytes(p)) - datums[i] = tree.NewDString(string(p)) - } + for i, typ := range benchmarkTypes { + bench := func(b *testing.B) { + fileName := "BenchmarkParquetWriter.parquet" + f, err := os.CreateTemp("", fileName) + require.NoError(b, err) + + // Slice a single type out of supportedTypes. + sch := makeRandSchema(numCols, benchmarkTypes[i:i+1], rng) + datums := makeRandDatums(1, sch, rng) - fileName := "BenchmarkParquetWriter.parquet" - f, err := os.CreateTemp("", fileName) - require.NoError(b, err) + schemaDef, err := NewSchema(sch.columnNames, sch.columnTypes) + require.NoError(b, err) - schemaDef, err := NewSchema(sch.columnNames, sch.columnTypes) - require.NoError(b, err) + writer, err := NewWriter(schemaDef, f) + require.NoError(b, err) - writer, err := NewWriter(schemaDef, f) - require.NoError(b, err) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + err = writer.AddRow(datums[0]) + require.NoError(b, err) + } + b.StopTimer() - b.ResetTimer() - b.ReportAllocs() + err = writer.Close() + require.NoError(b, err) + } - for i := 0; i < b.N; i++ { - err = writer.AddRow(datums) - require.NoError(b, err) + b.Run(typ.Name(), bench) } +} - err = writer.Close() - require.NoError(b, err) +func getBenchmarkTypes() []*types.T { + var typs []*types.T + for _, typ := range supportedTypes { + switch typ.Family() { + case types.ArrayFamily: + // Pick out one array type to benchmark arrays. + if typ.ArrayContents() == types.Int { + typs = append(typs, typ) + } + default: + typs = append(typs, typ) + } + } + return typs } diff --git a/pkg/util/parquet/writer_test.go b/pkg/util/parquet/writer_test.go index a076eb08f5b6..e9a58576d38c 100644 --- a/pkg/util/parquet/writer_test.go +++ b/pkg/util/parquet/writer_test.go @@ -13,16 +13,22 @@ package parquet import ( "bytes" "fmt" + "math" "math/rand" "os" "testing" "time" "github.com/apache/arrow/go/v11/parquet/file" + "github.com/cockroachdb/cockroach/pkg/geo" "github.com/cockroachdb/cockroach/pkg/sql/randgen" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/bitarray" + "github.com/cockroachdb/cockroach/pkg/util/duration" + "github.com/cockroachdb/cockroach/pkg/util/ipaddr" "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/cockroach/pkg/util/timeutil/pgdate" "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/stretchr/testify/require" ) @@ -39,18 +45,30 @@ func newColSchema(numCols int) *colSchema { } } -// TODO (jayant): once all types are supported, we can use randgen.SeedTypes -// instead of this array. -var supportedTypes = []*types.T{ - types.Int, - types.Bool, - types.String, - types.Decimal, - types.Uuid, - types.Timestamp, -} +// supportedTypes contains all types supported by the writer, +// which is all types that pass randomized testing below. 
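+// (For example, types.Int, types.Decimal, types.Uuid, and types.Timestamp
+// from the previously hardcoded list remain members of this set.)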
+var supportedTypes []*types.T func init() { + for _, typ := range randgen.SeedTypes { + switch typ.Family() { + // The types below are unsupported. They will fail randomized tests. + case types.AnyFamily: + case types.TSQueryFamily, types.TSVectorFamily: + case types.VoidFamily: + case types.TupleFamily: + case types.ArrayFamily: + // We will manually add array types which are supported below. + // Excluding types.TupleFamily and types.ArrayFamily leaves us with only + // scalar types so far. + default: + supportedTypes = append(supportedTypes, typ) + } + } + + // randgen.SeedTypes does not include types.Json, so we add it manually here. + supportedTypes = append(supportedTypes, types.Json) + // Include all array types which are arrays of the scalar types above. var arrayTypes []*types.T for oid := range types.ArrayOids { @@ -75,10 +93,10 @@ func makeRandDatums(numRows int, sch *colSchema, rng *rand.Rand) [][]tree.Datum return datums } -func makeRandSchema(numCols int, rng *rand.Rand) *colSchema { +func makeRandSchema(numCols int, allowedTypes []*types.T, rng *rand.Rand) *colSchema { sch := newColSchema(numCols) for i := 0; i < numCols; i++ { - sch.columnTypes[i] = supportedTypes[rng.Intn(len(supportedTypes))] + sch.columnTypes[i] = allowedTypes[rng.Intn(len(allowedTypes))] sch.columnNames[i] = fmt.Sprintf("%s%d", sch.columnTypes[i].Name(), i) } return sch @@ -93,7 +111,7 @@ func TestRandomDatums(t *testing.T) { numCols := 128 maxRowGroupSize := int64(8) - sch := makeRandSchema(numCols, rng) + sch := makeRandSchema(numCols, supportedTypes, rng) datums := makeRandDatums(numRows, sch, rng) fileName := "TestRandomDatums.parquet" @@ -162,6 +180,22 @@ func TestBasicDatums(t *testing.T) { }, nil }, }, + { + name: "timestamptz", + sch: &colSchema{ + columnTypes: []*types.T{types.TimestampTZ, types.TimestampTZ, types.TimestampTZ}, + columnNames: []string{"a", "b", "c"}, + }, + datums: func() ([][]tree.Datum, error) { + return [][]tree.Datum{ + { + tree.MustMakeDTimestampTZ(timeutil.Now(), time.Microsecond), + tree.MustMakeDTimestampTZ(timeutil.Now(), time.Microsecond), + tree.DNull, + }, + }, nil + }, + }, { name: "int", sch: &colSchema{ @@ -233,6 +267,192 @@ func TestBasicDatums(t *testing.T) { }, nil }, }, + { + name: "inet", + sch: &colSchema{ + columnTypes: []*types.T{types.INet, types.INet}, + columnNames: []string{"a", "b"}, + }, + datums: func() ([][]tree.Datum, error) { + var ipa ipaddr.IPAddr + err := ipaddr.ParseINet("192.168.2.1", &ipa) + require.NoError(t, err) + + return [][]tree.Datum{ + {&tree.DIPAddr{IPAddr: ipa}, tree.DNull}, + }, nil + }, + }, + { + name: "json", + sch: &colSchema{ + columnTypes: []*types.T{types.Json, types.Jsonb}, + columnNames: []string{"a", "b"}, + }, + datums: func() ([][]tree.Datum, error) { + j, err := tree.ParseDJSON("[{\"a\": 1}]") + require.NoError(t, err) + + return [][]tree.Datum{ + {j, tree.DNull}, + }, nil + }, + }, + { + name: "bitarray", + sch: &colSchema{ + columnTypes: []*types.T{types.VarBit, types.VarBit}, + columnNames: []string{"a", "b"}, + }, + datums: func() ([][]tree.Datum, error) { + ba, err := bitarray.Parse(string("101001")) + if err != nil { + return nil, err + } + + return [][]tree.Datum{ + {&tree.DBitArray{BitArray: ba}, tree.DNull}, + }, nil + }, + }, + { + name: "bytes", + sch: &colSchema{ + columnTypes: []*types.T{types.Bytes, types.Bytes}, + columnNames: []string{"a", "b"}, + }, + datums: func() ([][]tree.Datum, error) { + return [][]tree.Datum{ + {tree.NewDBytes("bytes"), tree.DNull}, + }, nil + }, + }, + { + name: "enum", + 
sch: &colSchema{ + columnTypes: []*types.T{makeTestingEnumType(), makeTestingEnumType()}, + columnNames: []string{"a", "b"}, + }, + datums: func() ([][]tree.Datum, error) { + d, err := tree.MakeDEnumFromLogicalRepresentation(makeTestingEnumType(), "hi") + require.NoError(t, err) + return [][]tree.Datum{ + {&d, tree.DNull}, + }, nil + }, + }, + { + name: "date", + sch: &colSchema{ + columnTypes: []*types.T{types.Date, types.Date}, + columnNames: []string{"a", "b"}, + }, + datums: func() ([][]tree.Datum, error) { + d, err := pgdate.MakeDateFromTime(timeutil.Now()) + require.NoError(t, err) + date := tree.MakeDDate(d) + return [][]tree.Datum{ + {&date, tree.DNull}, + }, nil + }, + }, + { + name: "box2d", + sch: &colSchema{ + columnTypes: []*types.T{types.Box2D, types.Box2D}, + columnNames: []string{"a", "b"}, + }, + datums: func() ([][]tree.Datum, error) { + b, err := geo.ParseCartesianBoundingBox("BOX(-0.4088850532348978 -0.19224841029808887,0.9334155753101069 0.7180433951296195)") + require.NoError(t, err) + return [][]tree.Datum{ + {tree.NewDBox2D(b), tree.DNull}, + }, nil + }, + }, + { + name: "geography", + sch: &colSchema{ + columnTypes: []*types.T{types.Geography, types.Geography}, + columnNames: []string{"a", "b"}, + }, + datums: func() ([][]tree.Datum, error) { + g, err := geo.ParseGeographyFromEWKB([]byte("\x01\x01\x00\x00\x20\xe6\x10\x00\x00\x00\x00\x00\x00\x00\x00\xf0\x3f\x00\x00\x00\x00\x00\x00\xf0\x3f")) + require.NoError(t, err) + return [][]tree.Datum{ + {&tree.DGeography{Geography: g}, tree.DNull}, + }, nil + }, + }, + { + name: "geometry", + sch: &colSchema{ + columnTypes: []*types.T{types.Geometry, types.Geometry}, + columnNames: []string{"a", "b"}, + }, + datums: func() ([][]tree.Datum, error) { + g, err := geo.ParseGeometryFromEWKB([]byte("\x01\x01\x00\x00\x20\xe6\x10\x00\x00\x00\x00\x00\x00\x00\x00\xf0\x3f\x00\x00\x00\x00\x00\x00\xf0\x3f")) + require.NoError(t, err) + return [][]tree.Datum{ + {&tree.DGeometry{Geometry: g}, tree.DNull}, + }, nil + }, + }, + { + name: "interval", + sch: &colSchema{ + columnTypes: []*types.T{types.Interval, types.Interval}, + columnNames: []string{"a", "b"}, + }, + datums: func() ([][]tree.Datum, error) { + return [][]tree.Datum{ + {&tree.DInterval{Duration: duration.MakeDuration(0, 10, 15)}, tree.DNull}, + }, nil + }, + }, + { + name: "time", + sch: &colSchema{ + columnTypes: []*types.T{types.Time, types.Time}, + columnNames: []string{"a", "b"}, + }, + datums: func() ([][]tree.Datum, error) { + dt := tree.DTime(12345) + return [][]tree.Datum{ + {&dt, tree.DNull}, + }, nil + }, + }, + { + name: "timetz", + sch: &colSchema{ + columnTypes: []*types.T{types.TimeTZ, types.TimeTZ}, + columnNames: []string{"a", "b"}, + }, + datums: func() ([][]tree.Datum, error) { + dt := tree.NewDTimeTZFromTime(timeutil.Now()) + return [][]tree.Datum{ + {dt, tree.DNull}, + }, nil + }, + }, + { + name: "float", + sch: &colSchema{ + columnTypes: []*types.T{types.Float, types.Float, types.Float, types.Float, types.Float4, types.Float4, types.Float4}, + columnNames: []string{"a", "b", "c", "d", "e", "f", "g"}, + }, + datums: func() ([][]tree.Datum, error) { + d1 := tree.DFloat(math.MaxFloat64) + d2 := tree.DFloat(math.SmallestNonzeroFloat64) + d3 := tree.DFloat(math.NaN()) + d4 := tree.DFloat(math.MaxFloat32) + d5 := tree.DFloat(math.SmallestNonzeroFloat32) + return [][]tree.Datum{ + {&d1, &d2, &d3, tree.DNull, &d4, &d5, tree.DNull}, + }, nil + }, + }, } { t.Run(tc.name, func(t *testing.T) { datums, err := tc.datums()