diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index c02d8f9c3..67ad19fd5 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -1291,6 +1291,13 @@ func TestGH430(t *testing.T) { } } +func TestGH706(t *testing.T) { + tok := jwt.New() + if !assert.ErrorIs(t, jwt.Validate(tok, jwt.WithRequiredClaim("foo")), &jwt.RequiredClaimValidationError{}, `jwt.Validate should fail`) { + return + } +} + func TestBenHigginsByPassRegression(t *testing.T) { key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { diff --git a/jwt/validate.go b/jwt/validate.go index d8148d5e8..8dc9828a3 100644 --- a/jwt/validate.go +++ b/jwt/validate.go @@ -159,6 +159,19 @@ func (err *validationError) Unwrap() error { return err.error } +type RequiredClaimValidationError struct { + claim string +} + +func (err *RequiredClaimValidationError) Error() string { + return fmt.Sprintf("%q not satisfied: required claim not found", err.claim) +} + +func (err *RequiredClaimValidationError) Is(target error) bool { + _, ok := target.(*RequiredClaimValidationError) + return ok +} + var errTokenExpired = NewValidationError(fmt.Errorf(`"exp" not satisfied`)) var errInvalidIssuedAt = NewValidationError(fmt.Errorf(`"iat" not satisfied`)) var errTokenNotYetValid = NewValidationError(fmt.Errorf(`"nbf" not satisfied`)) @@ -379,7 +392,7 @@ type isRequired string func (ir isRequired) Validate(_ context.Context, t Token) ValidationError { _, ok := t.Get(string(ir)) if !ok { - return NewValidationError(fmt.Errorf(`%q not satisfied: required claim not found`, string(ir))) + return NewValidationError(&RequiredClaimValidationError{claim: string(ir)}) } return nil }