diff --git a/wrphttp/errors.go b/wrphttp/errors.go index 5b52f22..ecbd9ff 100644 --- a/wrphttp/errors.go +++ b/wrphttp/errors.go @@ -16,6 +16,8 @@ package wrphttp +import "errors" + type httpError struct { err error code int @@ -28,3 +30,8 @@ func (e httpError) Error() string { func (e httpError) StatusCode() int { return e.code } + +// Is reports whether any error in e.err's chain matches target. +func (e httpError) Is(target error) bool { + return errors.Is(e.err, target) +} diff --git a/wrphttp/handler_test.go b/wrphttp/handler_test.go index c79a44d..a5256aa 100644 --- a/wrphttp/handler_test.go +++ b/wrphttp/handler_test.go @@ -20,6 +20,7 @@ package wrphttp import ( "context" "errors" + "fmt" "net/http" "net/http/httptest" "strconv" @@ -201,8 +202,9 @@ func testWRPHandlerDecodeError(t *testing.T) { assert = assert.New(t) require = require.New(t) - expectedCtx = context.WithValue(context.Background(), foo, "bar") - expectedErr = errors.New("expected") + expectedCtx = context.WithValue(context.Background(), foo, "bar") + expectedErr = errors.New("expected") + expectedHTTPStatusCode = http.StatusBadRequest decoder = func(actualCtx context.Context, httpRequest *http.Request) (*Entity, error) { assert.Equal(expectedCtx, actualCtx) @@ -213,7 +215,16 @@ func testWRPHandlerDecodeError(t *testing.T) { errorEncoder = func(actualCtx context.Context, actualErr error, _ http.ResponseWriter) { errorEncoderCalled = true assert.Equal(expectedCtx, actualCtx) - assert.Equal(expectedErr, actualErr) + assert.ErrorIs(actualErr, expectedErr, + fmt.Errorf("error [%v] doesn't contain error [%v] in its err chain", + actualErr, expectedErr)) + + var actualErrorHTTP httpError + if assert.ErrorAs(actualErr, &actualErrorHTTP, + fmt.Errorf("error [%v] doesn't contain error [%v] in its err chain", + actualErr, actualErrorHTTP)) { + assert.Equal(expectedHTTPStatusCode, actualErrorHTTP.StatusCode()) + } } wrpHandler = new(MockHandler)