Skip to content

Commit

Permalink
sql: Add support for numeric JSON scalar casts
Browse files Browse the repository at this point in the history
* Fixes cockroachdb#41333
* Add support for queries such as
  - `SELECT '1'::jsonb::int`
  - `SELECT '1'::jsonb::float`
  - `SELECT '3.14'::jsonb::decimal`

Release note: None
  • Loading branch information
chrisseto committed Oct 6, 2019
1 parent 1c34e52 commit dcf9ee3
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 3 deletions.
33 changes: 33 additions & 0 deletions pkg/sql/sem/tree/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -3213,6 +3213,18 @@ func PerformCast(ctx *EvalContext, d Datum, t *types.T) (Datum, error) {
res = NewDInt(DInt(iv))
case *DOid:
res = &v.DInt
case *DJSON:
if v.Type() == json.NumberJSONType {
// err is ignored as a more appropriate error
// will be generated later
dec, err := json.ToDecimal(v.JSON)
if err == nil {
asInt, err := dec.Int64()
if err == nil {
res = NewDInt(DInt(asInt))
}
}
}
}
if res != nil {
return res, nil
Expand Down Expand Up @@ -3253,6 +3265,18 @@ func PerformCast(ctx *EvalContext, d Datum, t *types.T) (Datum, error) {
return NewDFloat(DFloat(float64(v.UnixEpochDays()))), nil
case *DInterval:
return NewDFloat(DFloat(v.AsFloat64())), nil
case *DJSON:
if v.Type() == json.NumberJSONType {
// err is ignored as a more appropriate error
// will be generated later
dec, err := json.ToDecimal(v.JSON)
if err == nil {
asFloat, err := dec.Int64()
if err == nil {
return NewDFloat(DFloat(asFloat)), nil
}
}
}
}

case types.DecimalFamily:
Expand Down Expand Up @@ -3301,6 +3325,15 @@ func PerformCast(ctx *EvalContext, d Datum, t *types.T) (Datum, error) {
case *DInterval:
v.AsBigInt(&dd.Coeff)
dd.Exponent = -9
case *DJSON:
if v.Type() == json.NumberJSONType {
var dec apd.Decimal
// err is ignored as a more appropriate error
// will be generated later
dec, err = json.ToDecimal(v.JSON)
unset = err != nil
dd = DDecimal{dec}
}
default:
unset = true
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/sem/tree/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -1520,11 +1520,11 @@ var (
bitArrayCastTypes = annotateCast(types.VarBit, []*types.T{types.Unknown, types.VarBit, types.Int, types.String, types.AnyCollatedString})
boolCastTypes = annotateCast(types.Bool, []*types.T{types.Unknown, types.Bool, types.Int, types.Float, types.Decimal, types.String, types.AnyCollatedString})
intCastTypes = annotateCast(types.Int, []*types.T{types.Unknown, types.Bool, types.Int, types.Float, types.Decimal, types.String, types.AnyCollatedString,
types.Timestamp, types.TimestampTZ, types.Date, types.Interval, types.Oid, types.VarBit})
types.Timestamp, types.TimestampTZ, types.Date, types.Interval, types.Oid, types.VarBit, types.Jsonb})
floatCastTypes = annotateCast(types.Float, []*types.T{types.Unknown, types.Bool, types.Int, types.Float, types.Decimal, types.String, types.AnyCollatedString,
types.Timestamp, types.TimestampTZ, types.Date, types.Interval})
types.Timestamp, types.TimestampTZ, types.Date, types.Interval, types.Jsonb})
decimalCastTypes = annotateCast(types.Decimal, []*types.T{types.Unknown, types.Bool, types.Int, types.Float, types.Decimal, types.String, types.AnyCollatedString,
types.Timestamp, types.TimestampTZ, types.Date, types.Interval})
types.Timestamp, types.TimestampTZ, types.Date, types.Interval, types.Jsonb})
stringCastTypes = annotateCast(types.String, []*types.T{types.Unknown, types.Bool, types.Int, types.Float, types.Decimal, types.String, types.AnyCollatedString,
types.VarBit,
types.AnyArray, types.AnyTuple,
Expand Down
55 changes: 55 additions & 0 deletions pkg/sql/sem/tree/testdata/eval/cast
Original file line number Diff line number Diff line change
Expand Up @@ -977,3 +977,58 @@ eval
ARRAY['hello','world']::char(2)[]
----
ARRAY['he','wo']

eval
'1'::jsonb::int
----
1

eval
'1'::jsonb::float
----
1.0

eval
'1'::jsonb::decimal
----
1

eval
'1'::jsonb::string
----
'1'

eval
'2.0'::jsonb::int
----
2

eval
'2.0'::jsonb::float
----
2.0

eval
'2.0'::jsonb::decimal
----
2.0

eval
'2.0'::jsonb::string
----
'2.0'

eval
'3.14'::jsonb::float
----
3.14

eval
'3.14'::jsonb::decimal
----
3.14

eval
'3.14'::jsonb::string
----
'3.14'
13 changes: 13 additions & 0 deletions pkg/util/json/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,19 @@ func FromFloat64(v float64) (JSON, error) {
return jsonNumber(dec), nil
}

// ToDecimal returns a apd.Decimal given a JSON value
func ToDecimal(j JSON) (apd.Decimal, error) {
j, err := decodeIfNeeded(j)
if err != nil {
return apd.Decimal{}, err
}
num, ok := j.(jsonNumber)
if !ok {
return apd.Decimal{}, errors.AssertionFailedf("cannot convert JSON of type %T to decimal", j)
}
return apd.Decimal(num), nil
}

// MakeJSON returns a JSON value given a Go-style representation of JSON.
// * JSON null is Go `nil`,
// * JSON true is Go `true`,
Expand Down
62 changes: 62 additions & 0 deletions pkg/util/json/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2016,3 +2016,65 @@ func TestJSONRemovePath(t *testing.T) {
}
}
}

func TestToDecimal(t *testing.T) {
numericCases := []string{
"1",
"1.0",
"3.14",
"-3.14",
"1.000",
"-0.0",
"-0.09",
"0.08",
}

nonNumericCases := []struct {
input string
errMsg string
}{
{"\"1\"", "cannot convert JSON of type json.jsonString to decimal"},
{"{}", "cannot convert JSON of type json.jsonObject to decimal"},
{"[]", "cannot convert JSON of type json.jsonArray to decimal"},
{"true", "cannot convert JSON of type json.jsonTrue to decimal"},
{"false", "cannot convert JSON of type json.jsonFalse to decimal"},
{"null", "cannot convert JSON of type json.jsonNull to decimal"},
}

for _, tc := range numericCases {
t.Run(fmt.Sprintf("numeric - %s", tc), func(t *testing.T) {
dec1, _, err := apd.NewFromString(tc)
if err != nil {
t.Fatal(err)
}

json, err := ParseJSON(tc)
if err != nil {
t.Fatal(err)
}

dec2, err := ToDecimal(json)
if err != nil {
t.Fatal(err)
}

if dec1.Cmp(&dec2) != 0 {
t.Fatalf("expected %s == %s", dec1.String(), dec2.String())
}
})
}

for _, tc := range nonNumericCases {
t.Run(fmt.Sprintf("nonNumeric - %s", tc), func(t *testing.T) {
json, err := ParseJSON(tc.input)
if err != nil {
t.Fatalf("expected no error")
}

_, err = ToDecimal(json)
if err.Error() != tc.errMsg {
t.Fatalf("expected %s, got %s", tc.errMsg, err.Error())
}
})
}
}

0 comments on commit dcf9ee3

Please sign in to comment.