diff --git a/pgtype/numeric.go b/pgtype/numeric.go index a5f4ed3ac..e33597039 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -4,6 +4,7 @@ import ( "bytes" "database/sql/driver" "encoding/binary" + "encoding/json" "fmt" "math" "math/big" @@ -233,13 +234,51 @@ func (n Numeric) MarshalJSON() ([]byte, error) { return []byte("null"), nil } - if n.NaN { + switch { + case n.InfinityModifier == Infinity: + return []byte(`"infinity"`), nil + case n.InfinityModifier == NegativeInfinity: + return []byte(`"-infinity"`), nil + case n.NaN: return []byte(`"NaN"`), nil } return n.numberTextBytes(), nil } +func (n *Numeric) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *n = Numeric{} + return nil + } + + switch *s { + case "infinity": + *n = Numeric{NaN: true, InfinityModifier: Infinity, Valid: true} + case "-infinity": + *n = Numeric{NaN: true, InfinityModifier: -Infinity, Valid: true} + default: + num, exp, err := parseNumericString(*s) + if err != nil { + return fmt.Errorf("failed to decode %s to numeric: %w", *s, err) + } + + *n = Numeric{ + Int: num, + Exp: exp, + Valid: true, + } + } + + return nil +} + // numberString returns a string of the number. undefined if NaN, infinite, or NULL func (n Numeric) numberTextBytes() []byte { intStr := n.Int.String() diff --git a/pgtype/range.go b/pgtype/range.go index 8f408f9f3..ff6ba4288 100644 --- a/pgtype/range.go +++ b/pgtype/range.go @@ -3,6 +3,7 @@ package pgtype import ( "bytes" "encoding/binary" + "encoding/json" "fmt" ) @@ -320,3 +321,73 @@ func (r *Range[T]) SetBoundTypes(lower, upper BoundType) error { r.Valid = true return nil } + +func (r Range[T]) MarshalJSON() ([]byte, error) { + if !r.Valid { + return []byte("null"), nil + } + + enc := encodePlanRangeCodecRangeValuerToText{ + m: &encodePlanRangeCodecJson{}, + } + + buf, err := enc.Encode(r, []byte(`"`)) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as range: %w", r, err) + } + + buf = append(buf, `"`...) + + return buf, nil +} + +func (r *Range[T]) UnmarshalJSON(b []byte) error { + if b[0] == byte('"') && b[len(b)-1] == byte('"') { + b = b[1 : len(b)-1] + } + + s := string(b) + + if s == "null" { + *r = Range[T]{} + return nil + } + + utr, err := parseUntypedTextRange(s) + if err != nil { + return fmt.Errorf("failed to decode %s to range: %w", s, err) + } + + *r = Range[T]{ + LowerType: utr.LowerType, + UpperType: utr.UpperType, + Valid: true, + } + + if r.LowerType == Empty && r.UpperType == Empty { + return nil + } + + if r.LowerType != Unbounded { + if err = r.unmarshalJSON(utr.Lower, &r.Lower); err != nil { + return fmt.Errorf("failed to decode %s to range lower: %w", utr.Lower, err) + } + } + + if r.UpperType != Unbounded { + if err = r.unmarshalJSON(utr.Upper, &r.Upper); err != nil { + return fmt.Errorf("failed to decode %s to range upper: %w", utr.Upper, err) + } + } + + return nil +} + +func (_ *Range[T]) unmarshalJSON(data string, v *T) error { + buf := make([]byte, 0, len(data)+2) + buf = append(buf, `"`...) + buf = append(buf, data...) + buf = append(buf, `"`...) + + return json.Unmarshal(buf, v) +} diff --git a/pgtype/range_codec.go b/pgtype/range_codec.go index 8cfb3a630..2cd19283c 100644 --- a/pgtype/range_codec.go +++ b/pgtype/range_codec.go @@ -2,6 +2,7 @@ package pgtype import ( "database/sql/driver" + "encoding/json" "fmt" "github.com/jackc/pgx/v5/internal/pgio" @@ -157,8 +158,10 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt } type encodePlanRangeCodecRangeValuerToText struct { - rc *RangeCodec - m *Map + rc Codec + m interface { + PlanEncode(oid uint32, format int16, value any) EncodePlan + } } func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) (newBuf []byte, err error) { @@ -182,12 +185,18 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) return nil, fmt.Errorf("unknown lower bound type %v", lowerType) } + var oid uint32 + + if rc, ok := plan.rc.(*RangeCodec); ok { + oid = rc.ElementType.OID + } + if lowerType != Unbounded { if lower == nil { return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded") } - lowerPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, lower) + lowerPlan := plan.m.PlanEncode(oid, TextFormatCode, lower) if lowerPlan == nil { return nil, fmt.Errorf("cannot encode %v as element of range", lower) } @@ -208,7 +217,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded") } - upperPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, upper) + upperPlan := plan.m.PlanEncode(oid, TextFormatCode, upper) if upperPlan == nil { return nil, fmt.Errorf("cannot encode %v as element of range", upper) } @@ -377,3 +386,24 @@ func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) ( err := c.PlanScan(m, oid, format, &r).Scan(src, &r) return r, err } + +type encodePlanRangeCodecJson struct{} + +func (s *encodePlanRangeCodecJson) PlanEncode(_ uint32, _ int16, _ any) EncodePlan { + return s +} + +func (s *encodePlanRangeCodecJson) Encode(value any, buf []byte) (newBuf []byte, err error) { + b, err := json.Marshal(value) + if err != nil { + return nil, fmt.Errorf("failed to encode %v as element of range: %v", value, err) + } + + if b[0] == byte('"') && b[len(b)-1] == byte('"') { + buf = append(buf, b[1:len(b)-1]...) + } else { + buf = append(buf, b...) + } + + return buf, nil +} diff --git a/pgtype/range_test.go b/pgtype/range_test.go index 1ee8d5533..99fe03e7c 100644 --- a/pgtype/range_test.go +++ b/pgtype/range_test.go @@ -2,7 +2,9 @@ package pgtype import ( "bytes" + "math/big" "testing" + "time" ) func TestParseUntypedTextRange(t *testing.T) { @@ -175,3 +177,345 @@ func TestParseUntypedBinaryRange(t *testing.T) { } } } + +func TestRangeDateMarshalJSON(t *testing.T) { + t.Parallel() + + tests := []struct { + src Range[Date] + result string + }{ + {src: Range[Date]{}, result: "null"}, + {src: Range[Date]{ + LowerType: Empty, + UpperType: Empty, + Valid: true, + }, result: `"empty"`}, + {src: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Exclusive, + UpperType: Exclusive, + Valid: true, + }, result: `"(2022-12-01,2022-12-31)"`}, + {src: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Exclusive, + UpperType: Inclusive, + Valid: true, + }, result: `"(2022-12-01,2022-12-31]"`}, + {src: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Inclusive, + UpperType: Exclusive, + Valid: true, + }, result: `"[2022-12-01,2022-12-31)"`}, + {src: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Inclusive, + UpperType: Inclusive, + Valid: true, + }, result: `"[2022-12-01,2022-12-31]"`}, + {src: Range[Date]{ + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Unbounded, + UpperType: Exclusive, + Valid: true, + }, result: `"(,2022-12-31)"`}, + {src: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Inclusive, + UpperType: Unbounded, + Valid: true, + }, result: `"[2022-12-01,)"`}, + {src: Range[Date]{ + Lower: Date{InfinityModifier: NegativeInfinity, Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Exclusive, + UpperType: Exclusive, + Valid: true, + }, result: `"(-infinity,2022-12-31)"`}, + {src: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{InfinityModifier: Infinity, Valid: true}, + LowerType: Inclusive, + UpperType: Exclusive, + Valid: true, + }, result: `"[2022-12-31,infinity)"`}, + } + + for i, tt := range tests { + r, err := tt.src.MarshalJSON() + if err != nil { + t.Fatalf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to encode to %v, got %v", i, tt.src, tt.result, string(r)) + } + } +} + +func TestRangeNumericMarshalJSON(t *testing.T) { + t.Parallel() + + tests := []struct { + src Range[Numeric] + result string + }{ + {src: Range[Numeric]{}, result: "null"}, + {src: Range[Numeric]{ + LowerType: Empty, + UpperType: Empty, + Valid: true, + }, result: `"empty"`}, + {src: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Exclusive, + UpperType: Exclusive, + Valid: true, + }, result: `"(-16,16)"`}, + {src: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Exclusive, + UpperType: Inclusive, + Valid: true, + }, result: `"(-16,16]"`}, + {src: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Inclusive, + UpperType: Exclusive, + Valid: true, + }, result: `"[-16,16)"`}, + {src: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Inclusive, + UpperType: Inclusive, + Valid: true, + }, result: `"[-16,16]"`}, + {src: Range[Numeric]{ + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Unbounded, + UpperType: Exclusive, + Valid: true, + }, result: `"(,16)"`}, + {src: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + LowerType: Inclusive, + UpperType: Unbounded, + Valid: true, + }, result: `"[-16,)"`}, + {src: Range[Numeric]{ + Lower: Numeric{InfinityModifier: NegativeInfinity, NaN: true, Valid: true}, + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Exclusive, + UpperType: Exclusive, + Valid: true, + }, result: `"(-infinity,16)"`}, + {src: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + Upper: Numeric{InfinityModifier: Infinity, NaN: true, Valid: true}, + LowerType: Inclusive, + UpperType: Exclusive, + Valid: true, + }, result: `"[-16,infinity)"`}, + } + + for i, tt := range tests { + r, err := tt.src.MarshalJSON() + if err != nil { + t.Fatalf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to encode to %v, got %v", i, tt.src, tt.result, string(r)) + } + } +} + +func TestRangeDateUnmarshalJSON(t *testing.T) { + t.Parallel() + + tests := []struct { + src string + result Range[Date] + }{ + {src: "null", result: Range[Date]{}}, + {src: `"empty"`, result: Range[Date]{ + LowerType: Empty, + UpperType: Empty, + Valid: true, + }}, + {src: `"(2022-12-01,2022-12-31)"`, result: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Exclusive, + UpperType: Exclusive, + Valid: true, + }}, + {src: `"(2022-12-01,2022-12-31]"`, result: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Exclusive, + UpperType: Inclusive, + Valid: true, + }}, + {src: `"[2022-12-01,2022-12-31)"`, result: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Inclusive, + UpperType: Exclusive, + Valid: true, + }}, + {src: `"[2022-12-01,2022-12-31]"`, result: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Inclusive, + UpperType: Inclusive, + Valid: true, + }}, + {src: `"(,2022-12-31)"`, result: Range[Date]{ + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Unbounded, + UpperType: Exclusive, + Valid: true, + }}, + {src: `"[2022-12-01,)"`, result: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Inclusive, + UpperType: Unbounded, + Valid: true, + }}, + {src: `"(-infinity,2022-12-31)"`, result: Range[Date]{ + Lower: Date{InfinityModifier: NegativeInfinity, Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Exclusive, + UpperType: Exclusive, + Valid: true, + }}, + {src: `"[2022-12-31,infinity)"`, result: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{InfinityModifier: Infinity, Valid: true}, + LowerType: Inclusive, + UpperType: Exclusive, + Valid: true, + }}, + } + + for i, tt := range tests { + var r Range[Date] + err := r.UnmarshalJSON([]byte(tt.src)) + if err != nil { + t.Fatalf("%d: %v", i, err) + } + + if r.Lower.Time.Year() != tt.result.Lower.Time.Year() || + r.Lower.Time.Month() != tt.result.Lower.Time.Month() || + r.Lower.Time.Day() != tt.result.Lower.Time.Day() || + r.Lower.InfinityModifier != tt.result.Lower.InfinityModifier || + r.LowerType != tt.result.LowerType || + r.Upper.Time.Year() != tt.result.Upper.Time.Year() || + r.Upper.Time.Month() != tt.result.Upper.Time.Month() || + r.Upper.Time.Day() != tt.result.Upper.Time.Day() || + r.Upper.InfinityModifier != tt.result.Upper.InfinityModifier || + r.UpperType != tt.result.UpperType || + r.Valid != tt.result.Valid { + t.Errorf("%d: expected %v to decode to %v, got %v", i, tt.src, tt.result, r) + } + } +} + +func TestRangeNumericUnmarshalJSON(t *testing.T) { + t.Parallel() + + tests := []struct { + src string + result Range[Numeric] + }{ + {src: "null", result: Range[Numeric]{}}, + {src: `"empty"`, result: Range[Numeric]{ + LowerType: Empty, + UpperType: Empty, + Valid: true, + }}, + {src: `"(-16,16)"`, result: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Exclusive, + UpperType: Exclusive, + Valid: true, + }}, + {src: `"(-16,16]"`, result: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Exclusive, + UpperType: Inclusive, + Valid: true, + }}, + {src: `"[-16,16)"`, result: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Inclusive, + UpperType: Exclusive, + Valid: true, + }}, + {src: `"[-16,16]"`, result: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Inclusive, + UpperType: Inclusive, + Valid: true, + }}, + {src: `"(,16)"`, result: Range[Numeric]{ + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Unbounded, + UpperType: Exclusive, + Valid: true, + }}, + {src: `"[-16,)"`, result: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + LowerType: Inclusive, + UpperType: Unbounded, + Valid: true, + }}, + {src: `"(-infinity,16)"`, result: Range[Numeric]{ + Lower: Numeric{InfinityModifier: NegativeInfinity, NaN: true, Valid: true}, + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Exclusive, + UpperType: Exclusive, + Valid: true, + }}, + {src: `"[-16,infinity)"`, result: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + Upper: Numeric{InfinityModifier: Infinity, NaN: true, Valid: true}, + LowerType: Inclusive, + UpperType: Exclusive, + Valid: true, + }}, + } + + for i, tt := range tests { + var r Range[Numeric] + err := r.UnmarshalJSON([]byte(tt.src)) + if err != nil { + t.Fatalf("%d: %v", i, err) + } + + if r.Lower.Int.Cmp(tt.result.Lower.Int) != 0 || + r.Lower.InfinityModifier != tt.result.Lower.InfinityModifier || + r.LowerType != tt.result.LowerType || + r.Upper.Int.Cmp(tt.result.Upper.Int) != 0 || + r.Upper.InfinityModifier != tt.result.Upper.InfinityModifier || + r.UpperType != tt.result.UpperType || + r.Valid != r.Valid { + t.Errorf("%d: expected %s to decode to %v, got %v", i, tt.src, tt.result, r) + } + } +}