diff --git a/pgtype/range.go b/pgtype/range.go index 8f408f9f3..53967c4b6 100644 --- a/pgtype/range.go +++ b/pgtype/range.go @@ -3,7 +3,10 @@ package pgtype import ( "bytes" "encoding/binary" + "encoding/json" + "errors" "fmt" + "strings" ) type BoundType byte @@ -320,3 +323,102 @@ func (r *Range[T]) SetBoundTypes(lower, upper BoundType) error { r.Valid = true return nil } + +func (src Range[T]) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + buf := bytes.NewBufferString("") + + switch src.LowerType { + case Inclusive: + buf.WriteRune('[') + case Exclusive: + buf.WriteRune('(') + } + + lower, err := json.Marshal(src.Lower) + if err != nil { + return nil, err + } + + buf.WriteString(strings.TrimSuffix(strings.TrimPrefix(string(lower), `"`), `"`)) + buf.WriteRune(',') + + upper, err := json.Marshal(src.Upper) + if err != nil { + return nil, err + } + + buf.WriteString(strings.TrimSuffix(strings.TrimPrefix(string(upper), `"`), `"`)) + + switch src.UpperType { + case Inclusive: + buf.WriteRune(']') + case Exclusive: + buf.WriteRune(')') + } + + return json.Marshal(buf.String()) +} + +func (dst *Range[T]) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *dst = Range[T]{} + return nil + } + + parts := strings.SplitN(*s, ",", 2) + if len(parts) != 2 { + return errors.New("bad range format") + } + + lowerBound := string(parts[0][0]) + lower := parts[0][1:] + + var lowerDst T + if err = json.Unmarshal([]byte(`"`+lower+`"`), &lowerDst); err != nil { + return fmt.Errorf("bad lower: %w", err) + } + + upper := parts[1][0 : len(parts[1])-1] + upperBound := parts[1][len(parts[1])-1:] + + var upperDst T + if err = json.Unmarshal([]byte(`"`+upper+`"`), &upperDst); err != nil { + return fmt.Errorf("bad upper: %w", err) + } + + *dst = Range[T]{ + Lower: lowerDst, + Upper: upperDst, + Valid: true, + } + + switch lowerBound { + case "(": + dst.LowerType = Exclusive + case "[": + dst.LowerType = Inclusive + default: + return fmt.Errorf("lower bound %q not implemented", lowerBound) + } + + switch upperBound { + case ")": + dst.UpperType = Exclusive + case "]": + dst.UpperType = Inclusive + default: + return fmt.Errorf("upper bound %q not implemented", upperBound) + } + + return nil +} diff --git a/pgtype/range_test.go b/pgtype/range_test.go index 1ee8d5533..702c0dd7f 100644 --- a/pgtype/range_test.go +++ b/pgtype/range_test.go @@ -3,6 +3,7 @@ package pgtype import ( "bytes" "testing" + "time" ) func TestParseUntypedTextRange(t *testing.T) { @@ -175,3 +176,142 @@ func TestParseUntypedBinaryRange(t *testing.T) { } } } + +func TestRangeDateMarshalJSON(t *testing.T) { + t.Parallel() + + tests := []struct { + src Range[Date] + result string + }{ + {src: Range[Date]{}, result: "null"}, + {src: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Exclusive, + UpperType: Exclusive, + Valid: true, + }, result: `"(2022-12-01,2022-12-31)"`}, + {src: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Exclusive, + UpperType: Inclusive, + Valid: true, + }, result: `"(2022-12-01,2022-12-31]"`}, + {src: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Inclusive, + UpperType: Exclusive, + Valid: true, + }, result: `"[2022-12-01,2022-12-31)"`}, + {src: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Inclusive, + UpperType: Inclusive, + Valid: true, + }, result: `"[2022-12-01,2022-12-31]"`}, + {src: Range[Date]{ + Lower: Date{InfinityModifier: NegativeInfinity, Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Exclusive, + UpperType: Exclusive, + Valid: true, + }, result: "\"(-infinity,2022-12-31)\""}, + {src: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{InfinityModifier: Infinity, Valid: true}, + LowerType: Inclusive, + UpperType: Exclusive, + Valid: true, + }, result: "\"[2022-12-31,infinity)\""}, + } + + for i, tt := range tests { + r, err := tt.src.MarshalJSON() + if err != nil { + t.Fatalf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to encode to %v, got %v", i, tt.src, tt.result, string(r)) + } + } +} + +func TestRangeDateUnmarshalJSON(t *testing.T) { + t.Parallel() + + tests := []struct { + src string + result Range[Date] + }{ + {src: "null", result: Range[Date]{}}, + {src: `"(2022-12-01,2022-12-31)"`, result: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Exclusive, + UpperType: Exclusive, + Valid: true, + }}, + {src: `"(2022-12-01,2022-12-31]"`, result: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Exclusive, + UpperType: Inclusive, + Valid: true, + }}, + {src: `"[2022-12-01,2022-12-31)"`, result: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Inclusive, + UpperType: Exclusive, + Valid: true, + }}, + {src: `"[2022-12-01,2022-12-31]"`, result: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 1, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Inclusive, + UpperType: Inclusive, + Valid: true, + }}, + {src: `"(-infinity,2022-12-31)"`, result: Range[Date]{ + Lower: Date{InfinityModifier: NegativeInfinity, Valid: true}, + Upper: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + LowerType: Exclusive, + UpperType: Exclusive, + Valid: true, + }}, + {src: `"[2022-12-31,infinity)"`, result: Range[Date]{ + Lower: Date{Time: time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC), Valid: true}, + Upper: Date{InfinityModifier: Infinity, Valid: true}, + LowerType: Inclusive, + UpperType: Exclusive, + Valid: true, + }}, + } + + for i, tt := range tests { + var r Range[Date] + err := r.UnmarshalJSON([]byte(tt.src)) + if err != nil { + t.Fatalf("%d: %v", i, err) + } + + if r.Lower.Time.Year() != tt.result.Lower.Time.Year() || + r.Lower.Time.Month() != tt.result.Lower.Time.Month() || + r.Lower.Time.Day() != tt.result.Lower.Time.Day() || + r.Lower.InfinityModifier != tt.result.Lower.InfinityModifier || + r.LowerType != tt.result.LowerType || + r.Upper.Time.Year() != tt.result.Upper.Time.Year() || + r.Upper.Time.Month() != tt.result.Upper.Time.Month() || + r.Upper.Time.Day() != tt.result.Upper.Time.Day() || + r.Upper.InfinityModifier != tt.result.Upper.InfinityModifier || + r.UpperType != tt.result.UpperType || + r.Valid != tt.result.Valid { + t.Errorf("%d: expected %v to decode to %v, got %v", i, tt.src, tt.result, r) + } + } +}