From 6c64c600eef83bdc8904527abaa311ab2a611c0b Mon Sep 17 00:00:00 2001 From: Chris Seto Date: Thu, 7 Jan 2021 18:30:56 -0500 Subject: [PATCH] sql: Add support for numeric JSON scalar casts Previously, casting from a JSON numeric value to numeric values was unspported. This commit adds support for this feature allowing queries such as: `SELECT '1'::jsonb::int` `SELECT '1'::jsonb::float` `SELECT '3.14'::jsonb::decimal` to work as expected. Fixes #41333 Release note (sql change): Casting JSON numeric scalars to numeric types now works as expected. --- pkg/sql/sem/tree/casts.go | 24 ++++++++++ pkg/sql/sem/tree/testdata/eval/cast | 71 +++++++++++++++++++++++++++++ pkg/util/json/encoded.go | 13 ++++++ pkg/util/json/json.go | 15 ++++++ pkg/util/json/json_test.go | 59 ++++++++++++++++++++++++ 5 files changed, 182 insertions(+) diff --git a/pkg/sql/sem/tree/casts.go b/pkg/sql/sem/tree/casts.go index 9a7070c73c63..113e2c6799a4 100644 --- a/pkg/sql/sem/tree/casts.go +++ b/pkg/sql/sem/tree/casts.go @@ -98,6 +98,7 @@ var validCasts = []castInfo{ {from: types.IntervalFamily, to: types.IntFamily, volatility: VolatilityImmutable}, {from: types.OidFamily, to: types.IntFamily, volatility: VolatilityImmutable}, {from: types.BitFamily, to: types.IntFamily, volatility: VolatilityImmutable}, + {from: types.JsonFamily, to: types.IntFamily, volatility: VolatilityImmutable}, // Casts to FloatFamily. {from: types.UnknownFamily, to: types.FloatFamily, volatility: VolatilityImmutable}, @@ -111,6 +112,7 @@ var validCasts = []castInfo{ {from: types.TimestampTZFamily, to: types.FloatFamily, volatility: VolatilityImmutable}, {from: types.DateFamily, to: types.FloatFamily, volatility: VolatilityImmutable}, {from: types.IntervalFamily, to: types.FloatFamily, volatility: VolatilityImmutable}, + {from: types.JsonFamily, to: types.FloatFamily, volatility: VolatilityImmutable}, // Casts to Box2D Family. {from: types.UnknownFamily, to: types.Box2DFamily, volatility: VolatilityImmutable}, @@ -150,6 +152,7 @@ var validCasts = []castInfo{ {from: types.TimestampTZFamily, to: types.DecimalFamily, volatility: VolatilityImmutable}, {from: types.DateFamily, to: types.DecimalFamily, volatility: VolatilityImmutable}, {from: types.IntervalFamily, to: types.DecimalFamily, volatility: VolatilityImmutable}, + {from: types.JsonFamily, to: types.DecimalFamily, volatility: VolatilityImmutable}, // Casts to StringFamily. {from: types.UnknownFamily, to: types.StringFamily, volatility: VolatilityImmutable}, @@ -697,6 +700,13 @@ func performCastWithoutPrecisionTruncation(ctx *EvalContext, d Datum, t *types.T res = NewDInt(DInt(iv)) case *DOid: res = &v.DInt + case *DJSON: + if dec, ok := v.AsDecimal(); ok { + asInt, err := dec.Int64() + if err == nil { + res = NewDInt(DInt(asInt)) + } + } } if res != nil { return res, nil @@ -747,6 +757,14 @@ func performCastWithoutPrecisionTruncation(ctx *EvalContext, d Datum, t *types.T return NewDFloat(DFloat(float64(v.UnixEpochDays()))), nil case *DInterval: return NewDFloat(DFloat(v.AsFloat64())), nil + case *DJSON: + if dec, ok := v.AsDecimal(); ok { + fl, err := dec.Float64() + if err != nil { + return nil, ErrFloatOutOfRange + } + return NewDFloat(DFloat(fl)), nil + } } case types.DecimalFamily: @@ -795,6 +813,12 @@ func performCastWithoutPrecisionTruncation(ctx *EvalContext, d Datum, t *types.T case *DInterval: v.AsBigInt(&dd.Coeff) dd.Exponent = -9 + case *DJSON: + if dec, ok := v.AsDecimal(); ok { + dd.Set(dec) + } else { + unset = false + } default: unset = true } diff --git a/pkg/sql/sem/tree/testdata/eval/cast b/pkg/sql/sem/tree/testdata/eval/cast index f15b610e7a19..424cf707d6ec 100644 --- a/pkg/sql/sem/tree/testdata/eval/cast +++ b/pkg/sql/sem/tree/testdata/eval/cast @@ -1132,3 +1132,74 @@ eval 'hello t'::string::char(100) ---- 'hello t' + +# Test that numeric jsonb values can be cast to a numeric data type +eval +'1'::jsonb::int +---- +1 + +eval +'1'::jsonb::float +---- +1.0 + +eval +'1'::jsonb::decimal +---- +1 + +eval +'1'::jsonb::string +---- +'1' + +eval +'2.0'::jsonb::int +---- +2 + +eval +'2.0'::jsonb::float +---- +2.0 + +eval +'2.0'::jsonb::decimal +---- +2.0 + +eval +'2.0'::jsonb::string +---- +'2.0' + +eval +'3.14'::jsonb::float +---- +3.14 + +eval +'3.14'::jsonb::decimal +---- +3.14 + +eval +'true'::jsonb::float +---- +invalid cast: jsonb -> float + +eval +'null'::jsonb::float +---- +invalid cast: jsonb -> float + +eval +'{}'::jsonb::float +---- +invalid cast: jsonb -> float + +eval +'[]'::jsonb::float +---- +invalid cast: jsonb -> float diff --git a/pkg/util/json/encoded.go b/pkg/util/json/encoded.go index 26a1e7faf1b0..645b358c461a 100644 --- a/pkg/util/json/encoded.go +++ b/pkg/util/json/encoded.go @@ -17,6 +17,7 @@ import ( "strconv" "unsafe" + "github.com/cockroachdb/apd/v2" "github.com/cockroachdb/cockroach/pkg/sql/inverted" "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/errors" @@ -558,6 +559,18 @@ func (j *jsonEncoded) AsText() (*string, error) { return decoded.AsText() } +func (j *jsonEncoded) AsDecimal() (*apd.Decimal, bool) { + if dec := j.alreadyDecoded(); dec != nil { + return dec.AsDecimal() + } + + decoded, err := j.decode() + if err != nil { + return nil, false + } + return decoded.AsDecimal() +} + func (j *jsonEncoded) Compare(other JSON) (int, error) { if cmp := cmpJSONTypes(j.Type(), other.Type()); cmp != 0 { return cmp, nil diff --git a/pkg/util/json/json.go b/pkg/util/json/json.go index ae87c2dade9d..ee78de15d43b 100644 --- a/pkg/util/json/json.go +++ b/pkg/util/json/json.go @@ -136,6 +136,10 @@ type JSON interface { // AsText returns the JSON document as a string, with quotes around strings removed, and null as nil. AsText() (*string, error) + // AsDecimal returns the JSON document as a apd.Decimal if it is a numeric + // type, and a boolean inidicating if this JSON document is a numeric type. + AsDecimal() (*apd.Decimal, bool) + // Exists implements the `?` operator. Exists(string) (bool, error) @@ -412,6 +416,17 @@ func (j jsonString) MaybeDecode() JSON { return j } func (j jsonArray) MaybeDecode() JSON { return j } func (j jsonObject) MaybeDecode() JSON { return j } +func (j jsonNull) AsDecimal() (*apd.Decimal, bool) { return nil, false } +func (j jsonFalse) AsDecimal() (*apd.Decimal, bool) { return nil, false } +func (j jsonTrue) AsDecimal() (*apd.Decimal, bool) { return nil, false } +func (j jsonString) AsDecimal() (*apd.Decimal, bool) { return nil, false } +func (j jsonArray) AsDecimal() (*apd.Decimal, bool) { return nil, false } +func (j jsonObject) AsDecimal() (*apd.Decimal, bool) { return nil, false } +func (j jsonNumber) AsDecimal() (*apd.Decimal, bool) { + d := apd.Decimal(j) + return &d, true +} + func (j jsonNull) tryDecode() (JSON, error) { return j, nil } func (j jsonFalse) tryDecode() (JSON, error) { return j, nil } func (j jsonTrue) tryDecode() (JSON, error) { return j, nil } diff --git a/pkg/util/json/json_test.go b/pkg/util/json/json_test.go index 50d079becc9f..5f8db68e119f 100644 --- a/pkg/util/json/json_test.go +++ b/pkg/util/json/json_test.go @@ -2252,3 +2252,62 @@ func TestJSONRemovePath(t *testing.T) { } } } + +func TestToDecimal(t *testing.T) { + numericCases := []string{ + "1", + "1.0", + "3.14", + "-3.14", + "1.000", + "-0.0", + "-0.09", + "0.08", + } + + nonNumericCases := []string{ + "\"1\"", + "{}", + "[]", + "true", + "false", + "null", + } + + for _, tc := range numericCases { + t.Run(fmt.Sprintf("numeric - %s", tc), func(t *testing.T) { + dec1, _, err := apd.NewFromString(tc) + if err != nil { + t.Fatal(err) + } + + json, err := ParseJSON(tc) + if err != nil { + t.Fatal(err) + } + + dec2, ok := json.AsDecimal() + if !ok { + t.Fatalf("could not cast %v to decmial", json) + } + + if dec1.Cmp(dec2) != 0 { + t.Fatalf("expected %s == %s", dec1.String(), dec2.String()) + } + }) + } + + for _, tc := range nonNumericCases { + t.Run(fmt.Sprintf("nonNumeric - %s", tc), func(t *testing.T) { + json, err := ParseJSON(tc) + if err != nil { + t.Fatalf("expected no error") + } + + dec, ok := json.AsDecimal() + if dec != nil || ok { + t.Fatalf("%v should not be a valid decimal", json) + } + }) + } +}