Skip to content

Commit

Permalink
Improve handling of unknown fields in parameter structs (#32)
Browse files Browse the repository at this point in the history
Commit 78545e2 disallowed unknown struct fields when unmarshaling
parameters. The check could be bypassed by implementing the json.Unmarshaler
interface, but the need to do so causes friction for implementations of LSP,
which has a loose and rapidly-changing schema (see #5, #31).

This changes the default back to ignoring unknown fields, and adds a new
optional interface to re-enable strict checking without a custom unmarshaler.

N.B.: This is a breaking API change.

Relevant changes:

- Update UnmarshalParams and UnmarshalResult to check for the target having a
  DisallowUnknownFields method, and to enable the check for such targets.

- Rework the decoding to make the default path do less allocation.

- Update documentation and tests.
  • Loading branch information
creachadair committed Nov 18, 2020
1 parent ae12d65 commit 1858c65
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 37 deletions.
68 changes: 46 additions & 22 deletions base.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,30 @@ func (r *Request) HasParams() bool { return len(r.params) != 0 }
// parameters, it returns nil without modifying v. If r is invalid it returns
// an InvalidParams error.
//
// By default, unknown keys are disallowed when unmarshaling into a v of struct
// type. The caller may override this using jrpc2.NonStrict, by implementing
// json.Unmarshaler on the concrete type of v, or by unmarshaling into a
// json.RawMessage and separately decoding the result. The examples demonstrate
// how to do this.
// By default, unknown object keys are ignored when unmarshaling into a v of
// struct type. This can be overridden either by giving the type of v a custom
// implementation of json.Unmarshaler, or implementing a DisallowUnknownFields
// method. The jrpc2.StrictFields helper function adapts existing values to
// this interface.
//
// If v has type *json.RawMessage, decoding cannot fail.
func (r *Request) UnmarshalParams(v interface{}) error {
if len(r.params) == 0 {
return nil
} else if raw, ok := v.(*json.RawMessage); ok {
*raw = json.RawMessage(string(r.params)) // copy
return nil
}
dec := json.NewDecoder(bytes.NewReader(r.params))
dec.DisallowUnknownFields()
if err := dec.Decode(v); err != nil {
return Errorf(code.InvalidParams, "invalid parameters: %v", err.Error())
switch t := v.(type) {
case *json.RawMessage:
*t = json.RawMessage(string(r.params)) // copy
return nil
case strictFielder:
dec := json.NewDecoder(bytes.NewReader(r.params))
dec.DisallowUnknownFields()
if err := dec.Decode(v); err != nil {
return Errorf(code.InvalidParams, "invalid parameters: %v", err.Error())
}
return nil
}
return nil
return json.Unmarshal(r.params, v)
}

// ParamString returns the encoded request parameters of r as a string.
Expand Down Expand Up @@ -152,10 +156,25 @@ func (r *Response) Error() *Error { return r.err }
// UnmarshalResult decodes the result message into v. If the request failed,
// UnmarshalResult returns the *Error value that would also be returned by
// r.Error(), and v is unmodified.
//
// By default, unknown object keys are ignored when unmarshaling into a v of
// struct type. This can be overridden either by giving the type of v a custom
// implementation of json.Unmarshaler, or implementing a DisallowUnknownFields
// method. The jrpc2.StrictFields helper function adapts existing values to
// this interface.
func (r *Response) UnmarshalResult(v interface{}) error {
if r.err != nil {
return r.err
}
switch t := v.(type) {
case *json.RawMessage:
*t = json.RawMessage(string(r.result)) // copy
return nil
case strictFielder:
dec := json.NewDecoder(bytes.NewReader(r.result))
dec.DisallowUnknownFields()
return dec.Decode(v)
}
return json.Unmarshal(r.result, v)
}

Expand Down Expand Up @@ -410,18 +429,23 @@ func filterError(e *Error) error {
return e
}

// NonStrict wraps a value v so that it can be unmarshaled from JSON without
// checking for unknown fields. The v provided must itself be a valid argument
// to json.Unmarshal.
// strictFielder is an optional interface that can be implemented by a type to
// reject unknown fields when unmarshaling from JSON. If a type does not
// implement this interface, unknown fields are ignored.
type strictFielder interface {
DisallowUnknownFields()
}

// StrictFields wraps a value v to implement the DisallowUnknownFields method,
// requiring unknown fields to be rejected when unmarshaling from JSON.
//
// This can be used to unmarshal request parameters with unknown fields, for
// example:
// For example:
//
// var obj RequestType
// err := req.UnmarshalParams(jrpc2.NonStrict(&obj))
// err := req.UnmarshalParams(jrpc2.StrictFields(&obj))`
//
func NonStrict(v interface{}) json.Unmarshaler { return &nonStrict{v: v} }
func StrictFields(v interface{}) interface{} { return &strict{v: v} }

type nonStrict struct{ v interface{} }
type strict struct{ v interface{} }

func (n nonStrict) UnmarshalJSON(data []byte) error { return json.Unmarshal(data, n.v) }
func (strict) DisallowUnknownFields() {}
14 changes: 7 additions & 7 deletions examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,15 @@ func ExampleRequest_UnmarshalParams() {
B int `json:"b"`
}

// By default, unmarshaling prohibits unknown fields (here, "c").
err = reqs[0].UnmarshalParams(&t)
if code.FromError(err) != code.InvalidParams {
log.Fatalf("Expected invalid parameters, got: %v", err)
// By default, unmarshaling ignores unknown fields (here, "c").
if err := reqs[0].UnmarshalParams(&t); err != nil {
log.Fatalf("UnmarshalParams: %v", err)
}

// Solution 1: Decode with jrpc2.NonStrict.
if err := reqs[0].UnmarshalParams(jrpc2.NonStrict(&t)); err != nil {
log.Fatalf("UnmarshalParams: %v", err)
// Solution 1: Use the jrpc2.StrictFields helper.
err = reqs[0].UnmarshalParams(jrpc2.StrictFields(&t))
if code.FromError(err) != code.InvalidParams {
log.Fatalf("UnmarshalParams strict: %v", err)
}
fmt.Printf("t.A=%d, t.B=%d\n", t.A, t.B)

Expand Down
2 changes: 1 addition & 1 deletion internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func TestUnmarshalParams(t *testing.T) {
{`{"jsonrpc":"2.0", "id":5, "method":"Z", "params":{"x":23, "y":true}}`,
xy{X: 23, Y: true}, `{"x":23, "y":true}`, code.NoError},
{`{"jsonrpc":"2.0", "id":6, "method":"Z", "params":{"x":23, "z":"wat"}}`,
xy{}, `{"x":23, "z":"wat"}`, code.InvalidParams},
xy{X: 23}, `{"x":23, "z":"wat"}`, code.NoError},
}
for _, test := range tests {
req, err := ParseRequests([]byte(test.input))
Expand Down
44 changes: 37 additions & 7 deletions jrpc2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,7 @@ func (buggyChannel) Send([]byte) error { panic("should not be called") }
func (b buggyChannel) Recv() ([]byte, error) { return []byte(b.data), b.err }
func (buggyChannel) Close() error { return nil }

func TestNonStrict(t *testing.T) {
func TestStrictFields(t *testing.T) {
type other struct {
C bool `json:"charlie"`
}
Expand All @@ -1074,30 +1074,60 @@ func TestNonStrict(t *testing.T) {
B int `json:"bravo"`
other
}
type result struct {
X string `json:"xray"`
}
loc := server.NewLocal(handler.Map{
"Test": handler.New(func(ctx context.Context, req *jrpc2.Request) error {
"Test": handler.New(func(ctx context.Context, req *jrpc2.Request) (interface{}, error) {
var ps, qs params

if err := req.UnmarshalParams(&ps); err == nil {
if err := req.UnmarshalParams(jrpc2.StrictFields(&ps)); err == nil {
t.Errorf("Unmarshal strict: got %+v, want error", ps)
}

if err := req.UnmarshalParams(jrpc2.NonStrict(&qs)); err != nil {
t.Errorf("Unmarshal non-strict: unexpected error: %v", err)
if err := req.UnmarshalParams(&qs); err != nil {
t.Errorf("Unmarshal non-strict (default): unexpected error: %v", err)
} else {
t.Logf("Parameters OK: %+v", qs)
}

return nil
return map[string]string{
"xray": "ok",
"gamma": "not ok",
}, nil
}),
}, nil)
defer loc.Close()

ctx := context.Background()
loc.Client.Call(ctx, "Test", handler.Obj{
req := handler.Obj{
"alpha": "foo",
"bravo": 25,
"charlie": true, // exercise embedding
"delta": 31.5, // unknown field
}
t.Run("NonStrictResult", func(t *testing.T) {
rsp, err := loc.Client.Call(ctx, "Test", req)
if err != nil {
t.Fatalf("Call failed: %v", err)
}
var res result
if err := rsp.UnmarshalResult(&res); err != nil {
t.Errorf("UnmarshalResult: %v", err)
}
t.Logf("Result: %+v", res)
})

t.Run("StrictResult", func(t *testing.T) {
rsp, err := loc.Client.Call(ctx, "Test", req)
if err != nil {
t.Fatalf("Call failed: %v", err)
}
var res result
if err := rsp.UnmarshalResult(jrpc2.StrictFields(&res)); err == nil {
t.Errorf("UnmarshalResult: got %+v, want error", res)
} else {
t.Logf("UnmarshalResult: got expected error: %v", err)
}
})
}

0 comments on commit 1858c65

Please sign in to comment.