Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve handling of unknown fields in parameter structs #32

Merged
merged 8 commits into from
Nov 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions base.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,30 @@ func (r *Request) HasParams() bool { return len(r.params) != 0 }
// parameters, it returns nil without modifying v. If r is invalid it returns
// an InvalidParams error.
//
// By default, unknown keys are disallowed when unmarshaling into a v of struct
// type. The caller may override this using jrpc2.NonStrict, by implementing
// json.Unmarshaler on the concrete type of v, or by unmarshaling into a
// json.RawMessage and separately decoding the result. The examples demonstrate
// how to do this.
// By default, unknown object keys are ignored when unmarshaling into a v of
// struct type. This can be overridden either by giving the type of v a custom
// implementation of json.Unmarshaler, or implementing a DisallowUnknownFields
// method. The jrpc2.StrictFields helper function adapts existing values to
// this interface.
//
// If v has type *json.RawMessage, decoding cannot fail.
func (r *Request) UnmarshalParams(v interface{}) error {
if len(r.params) == 0 {
return nil
} else if raw, ok := v.(*json.RawMessage); ok {
*raw = json.RawMessage(string(r.params)) // copy
return nil
}
dec := json.NewDecoder(bytes.NewReader(r.params))
dec.DisallowUnknownFields()
if err := dec.Decode(v); err != nil {
return Errorf(code.InvalidParams, "invalid parameters: %v", err.Error())
switch t := v.(type) {
case *json.RawMessage:
*t = json.RawMessage(string(r.params)) // copy
return nil
case strictFielder:
dec := json.NewDecoder(bytes.NewReader(r.params))
dec.DisallowUnknownFields()
if err := dec.Decode(v); err != nil {
return Errorf(code.InvalidParams, "invalid parameters: %v", err.Error())
}
return nil
}
return nil
return json.Unmarshal(r.params, v)
}

// ParamString returns the encoded request parameters of r as a string.
Expand Down Expand Up @@ -152,10 +156,25 @@ func (r *Response) Error() *Error { return r.err }
// UnmarshalResult decodes the result message into v. If the request failed,
// UnmarshalResult returns the *Error value that would also be returned by
// r.Error(), and v is unmodified.
//
// By default, unknown object keys are ignored when unmarshaling into a v of
// struct type. This can be overridden either by giving the type of v a custom
// implementation of json.Unmarshaler, or implementing a DisallowUnknownFields
// method. The jrpc2.StrictFields helper function adapts existing values to
// this interface.
func (r *Response) UnmarshalResult(v interface{}) error {
if r.err != nil {
return r.err
}
switch t := v.(type) {
case *json.RawMessage:
*t = json.RawMessage(string(r.result)) // copy
return nil
case strictFielder:
dec := json.NewDecoder(bytes.NewReader(r.result))
dec.DisallowUnknownFields()
return dec.Decode(v)
}
return json.Unmarshal(r.result, v)
}

Expand Down Expand Up @@ -410,18 +429,23 @@ func filterError(e *Error) error {
return e
}

// NonStrict wraps a value v so that it can be unmarshaled from JSON without
// checking for unknown fields. The v provided must itself be a valid argument
// to json.Unmarshal.
// strictFielder is an optional interface that can be implemented by a type to
// reject unknown fields when unmarshaling from JSON. If a type does not
// implement this interface, unknown fields are ignored.
type strictFielder interface {
DisallowUnknownFields()
}

// StrictFields wraps a value v to implement the DisallowUnknownFields method,
// requiring unknown fields to be rejected when unmarshaling from JSON.
//
// This can be used to unmarshal request parameters with unknown fields, for
// example:
// For example:
//
// var obj RequestType
// err := req.UnmarshalParams(jrpc2.NonStrict(&obj))
// err := req.UnmarshalParams(jrpc2.StrictFields(&obj))`
//
func NonStrict(v interface{}) json.Unmarshaler { return &nonStrict{v: v} }
func StrictFields(v interface{}) interface{} { return &strict{v: v} }

type nonStrict struct{ v interface{} }
type strict struct{ v interface{} }

func (n nonStrict) UnmarshalJSON(data []byte) error { return json.Unmarshal(data, n.v) }
func (strict) DisallowUnknownFields() {}
14 changes: 7 additions & 7 deletions examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,15 @@ func ExampleRequest_UnmarshalParams() {
B int `json:"b"`
}

// By default, unmarshaling prohibits unknown fields (here, "c").
err = reqs[0].UnmarshalParams(&t)
if code.FromError(err) != code.InvalidParams {
log.Fatalf("Expected invalid parameters, got: %v", err)
// By default, unmarshaling ignores unknown fields (here, "c").
if err := reqs[0].UnmarshalParams(&t); err != nil {
log.Fatalf("UnmarshalParams: %v", err)
}

// Solution 1: Decode with jrpc2.NonStrict.
if err := reqs[0].UnmarshalParams(jrpc2.NonStrict(&t)); err != nil {
log.Fatalf("UnmarshalParams: %v", err)
// Solution 1: Use the jrpc2.StrictFields helper.
err = reqs[0].UnmarshalParams(jrpc2.StrictFields(&t))
if code.FromError(err) != code.InvalidParams {
log.Fatalf("UnmarshalParams strict: %v", err)
}
fmt.Printf("t.A=%d, t.B=%d\n", t.A, t.B)

Expand Down
2 changes: 1 addition & 1 deletion internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func TestUnmarshalParams(t *testing.T) {
{`{"jsonrpc":"2.0", "id":5, "method":"Z", "params":{"x":23, "y":true}}`,
xy{X: 23, Y: true}, `{"x":23, "y":true}`, code.NoError},
{`{"jsonrpc":"2.0", "id":6, "method":"Z", "params":{"x":23, "z":"wat"}}`,
xy{}, `{"x":23, "z":"wat"}`, code.InvalidParams},
xy{X: 23}, `{"x":23, "z":"wat"}`, code.NoError},
}
for _, test := range tests {
req, err := ParseRequests([]byte(test.input))
Expand Down
44 changes: 37 additions & 7 deletions jrpc2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,7 @@ func (buggyChannel) Send([]byte) error { panic("should not be called") }
func (b buggyChannel) Recv() ([]byte, error) { return []byte(b.data), b.err }
func (buggyChannel) Close() error { return nil }

func TestNonStrict(t *testing.T) {
func TestStrictFields(t *testing.T) {
type other struct {
C bool `json:"charlie"`
}
Expand All @@ -1074,30 +1074,60 @@ func TestNonStrict(t *testing.T) {
B int `json:"bravo"`
other
}
type result struct {
X string `json:"xray"`
}
loc := server.NewLocal(handler.Map{
"Test": handler.New(func(ctx context.Context, req *jrpc2.Request) error {
"Test": handler.New(func(ctx context.Context, req *jrpc2.Request) (interface{}, error) {
var ps, qs params

if err := req.UnmarshalParams(&ps); err == nil {
if err := req.UnmarshalParams(jrpc2.StrictFields(&ps)); err == nil {
t.Errorf("Unmarshal strict: got %+v, want error", ps)
}

if err := req.UnmarshalParams(jrpc2.NonStrict(&qs)); err != nil {
t.Errorf("Unmarshal non-strict: unexpected error: %v", err)
if err := req.UnmarshalParams(&qs); err != nil {
t.Errorf("Unmarshal non-strict (default): unexpected error: %v", err)
} else {
t.Logf("Parameters OK: %+v", qs)
}

return nil
return map[string]string{
"xray": "ok",
"gamma": "not ok",
}, nil
}),
}, nil)
defer loc.Close()

ctx := context.Background()
loc.Client.Call(ctx, "Test", handler.Obj{
req := handler.Obj{
"alpha": "foo",
"bravo": 25,
"charlie": true, // exercise embedding
"delta": 31.5, // unknown field
}
t.Run("NonStrictResult", func(t *testing.T) {
rsp, err := loc.Client.Call(ctx, "Test", req)
if err != nil {
t.Fatalf("Call failed: %v", err)
}
var res result
if err := rsp.UnmarshalResult(&res); err != nil {
t.Errorf("UnmarshalResult: %v", err)
}
t.Logf("Result: %+v", res)
})

t.Run("StrictResult", func(t *testing.T) {
rsp, err := loc.Client.Call(ctx, "Test", req)
if err != nil {
t.Fatalf("Call failed: %v", err)
}
var res result
if err := rsp.UnmarshalResult(jrpc2.StrictFields(&res)); err == nil {
t.Errorf("UnmarshalResult: got %+v, want error", res)
} else {
t.Logf("UnmarshalResult: got expected error: %v", err)
}
})
}