diff --git a/base.go b/base.go index ebda4c8..3ea16b7 100644 --- a/base.go +++ b/base.go @@ -70,26 +70,30 @@ func (r *Request) HasParams() bool { return len(r.params) != 0 } // parameters, it returns nil without modifying v. If r is invalid it returns // an InvalidParams error. // -// By default, unknown keys are disallowed when unmarshaling into a v of struct -// type. The caller may override this using jrpc2.NonStrict, by implementing -// json.Unmarshaler on the concrete type of v, or by unmarshaling into a -// json.RawMessage and separately decoding the result. The examples demonstrate -// how to do this. +// By default, unknown object keys are ignored when unmarshaling into a v of +// struct type. This can be overridden either by giving the type of v a custom +// implementation of json.Unmarshaler, or implementing a DisallowUnknownFields +// method. The jrpc2.StrictFields helper function adapts existing values to +// this interface. // // If v has type *json.RawMessage, decoding cannot fail. func (r *Request) UnmarshalParams(v interface{}) error { if len(r.params) == 0 { return nil - } else if raw, ok := v.(*json.RawMessage); ok { - *raw = json.RawMessage(string(r.params)) // copy - return nil } - dec := json.NewDecoder(bytes.NewReader(r.params)) - dec.DisallowUnknownFields() - if err := dec.Decode(v); err != nil { - return Errorf(code.InvalidParams, "invalid parameters: %v", err.Error()) + switch t := v.(type) { + case *json.RawMessage: + *t = json.RawMessage(string(r.params)) // copy + return nil + case strictFielder: + dec := json.NewDecoder(bytes.NewReader(r.params)) + dec.DisallowUnknownFields() + if err := dec.Decode(v); err != nil { + return Errorf(code.InvalidParams, "invalid parameters: %v", err.Error()) + } + return nil } - return nil + return json.Unmarshal(r.params, v) } // ParamString returns the encoded request parameters of r as a string. @@ -152,10 +156,25 @@ func (r *Response) Error() *Error { return r.err } // UnmarshalResult decodes the result message into v. If the request failed, // UnmarshalResult returns the *Error value that would also be returned by // r.Error(), and v is unmodified. +// +// By default, unknown object keys are ignored when unmarshaling into a v of +// struct type. This can be overridden either by giving the type of v a custom +// implementation of json.Unmarshaler, or implementing a DisallowUnknownFields +// method. The jrpc2.StrictFields helper function adapts existing values to +// this interface. func (r *Response) UnmarshalResult(v interface{}) error { if r.err != nil { return r.err } + switch t := v.(type) { + case *json.RawMessage: + *t = json.RawMessage(string(r.result)) // copy + return nil + case strictFielder: + dec := json.NewDecoder(bytes.NewReader(r.result)) + dec.DisallowUnknownFields() + return dec.Decode(v) + } return json.Unmarshal(r.result, v) } @@ -410,18 +429,23 @@ func filterError(e *Error) error { return e } -// NonStrict wraps a value v so that it can be unmarshaled from JSON without -// checking for unknown fields. The v provided must itself be a valid argument -// to json.Unmarshal. +// strictFielder is an optional interface that can be implemented by a type to +// reject unknown fields when unmarshaling from JSON. If a type does not +// implement this interface, unknown fields are ignored. +type strictFielder interface { + DisallowUnknownFields() +} + +// StrictFields wraps a value v to implement the DisallowUnknownFields method, +// requiring unknown fields to be rejected when unmarshaling from JSON. // -// This can be used to unmarshal request parameters with unknown fields, for -// example: +// For example: // // var obj RequestType -// err := req.UnmarshalParams(jrpc2.NonStrict(&obj)) +// err := req.UnmarshalParams(jrpc2.StrictFields(&obj))` // -func NonStrict(v interface{}) json.Unmarshaler { return &nonStrict{v: v} } +func StrictFields(v interface{}) interface{} { return &strict{v: v} } -type nonStrict struct{ v interface{} } +type strict struct{ v interface{} } -func (n nonStrict) UnmarshalJSON(data []byte) error { return json.Unmarshal(data, n.v) } +func (strict) DisallowUnknownFields() {} diff --git a/examples_test.go b/examples_test.go index 08d207f..ee1419e 100644 --- a/examples_test.go +++ b/examples_test.go @@ -114,15 +114,15 @@ func ExampleRequest_UnmarshalParams() { B int `json:"b"` } - // By default, unmarshaling prohibits unknown fields (here, "c"). - err = reqs[0].UnmarshalParams(&t) - if code.FromError(err) != code.InvalidParams { - log.Fatalf("Expected invalid parameters, got: %v", err) + // By default, unmarshaling ignores unknown fields (here, "c"). + if err := reqs[0].UnmarshalParams(&t); err != nil { + log.Fatalf("UnmarshalParams: %v", err) } - // Solution 1: Decode with jrpc2.NonStrict. - if err := reqs[0].UnmarshalParams(jrpc2.NonStrict(&t)); err != nil { - log.Fatalf("UnmarshalParams: %v", err) + // Solution 1: Use the jrpc2.StrictFields helper. + err = reqs[0].UnmarshalParams(jrpc2.StrictFields(&t)) + if code.FromError(err) != code.InvalidParams { + log.Fatalf("UnmarshalParams strict: %v", err) } fmt.Printf("t.A=%d, t.B=%d\n", t.A, t.B) diff --git a/internal_test.go b/internal_test.go index 4d0516d..04c4fe0 100644 --- a/internal_test.go +++ b/internal_test.go @@ -107,7 +107,7 @@ func TestUnmarshalParams(t *testing.T) { {`{"jsonrpc":"2.0", "id":5, "method":"Z", "params":{"x":23, "y":true}}`, xy{X: 23, Y: true}, `{"x":23, "y":true}`, code.NoError}, {`{"jsonrpc":"2.0", "id":6, "method":"Z", "params":{"x":23, "z":"wat"}}`, - xy{}, `{"x":23, "z":"wat"}`, code.InvalidParams}, + xy{X: 23}, `{"x":23, "z":"wat"}`, code.NoError}, } for _, test := range tests { req, err := ParseRequests([]byte(test.input)) diff --git a/jrpc2_test.go b/jrpc2_test.go index 907e534..31c478d 100644 --- a/jrpc2_test.go +++ b/jrpc2_test.go @@ -1065,7 +1065,7 @@ func (buggyChannel) Send([]byte) error { panic("should not be called") } func (b buggyChannel) Recv() ([]byte, error) { return []byte(b.data), b.err } func (buggyChannel) Close() error { return nil } -func TestNonStrict(t *testing.T) { +func TestStrictFields(t *testing.T) { type other struct { C bool `json:"charlie"` } @@ -1074,30 +1074,60 @@ func TestNonStrict(t *testing.T) { B int `json:"bravo"` other } + type result struct { + X string `json:"xray"` + } loc := server.NewLocal(handler.Map{ - "Test": handler.New(func(ctx context.Context, req *jrpc2.Request) error { + "Test": handler.New(func(ctx context.Context, req *jrpc2.Request) (interface{}, error) { var ps, qs params - if err := req.UnmarshalParams(&ps); err == nil { + if err := req.UnmarshalParams(jrpc2.StrictFields(&ps)); err == nil { t.Errorf("Unmarshal strict: got %+v, want error", ps) } - if err := req.UnmarshalParams(jrpc2.NonStrict(&qs)); err != nil { - t.Errorf("Unmarshal non-strict: unexpected error: %v", err) + if err := req.UnmarshalParams(&qs); err != nil { + t.Errorf("Unmarshal non-strict (default): unexpected error: %v", err) } else { t.Logf("Parameters OK: %+v", qs) } - return nil + return map[string]string{ + "xray": "ok", + "gamma": "not ok", + }, nil }), }, nil) defer loc.Close() ctx := context.Background() - loc.Client.Call(ctx, "Test", handler.Obj{ + req := handler.Obj{ "alpha": "foo", "bravo": 25, "charlie": true, // exercise embedding "delta": 31.5, // unknown field + } + t.Run("NonStrictResult", func(t *testing.T) { + rsp, err := loc.Client.Call(ctx, "Test", req) + if err != nil { + t.Fatalf("Call failed: %v", err) + } + var res result + if err := rsp.UnmarshalResult(&res); err != nil { + t.Errorf("UnmarshalResult: %v", err) + } + t.Logf("Result: %+v", res) + }) + + t.Run("StrictResult", func(t *testing.T) { + rsp, err := loc.Client.Call(ctx, "Test", req) + if err != nil { + t.Fatalf("Call failed: %v", err) + } + var res result + if err := rsp.UnmarshalResult(jrpc2.StrictFields(&res)); err == nil { + t.Errorf("UnmarshalResult: got %+v, want error", res) + } else { + t.Logf("UnmarshalResult: got expected error: %v", err) + } }) }