diff --git a/claims.go b/claims.go index b2dc6b2d..be2ab45e 100644 --- a/claims.go +++ b/claims.go @@ -73,6 +73,11 @@ func (c RegisteredClaims) Valid(opts ...validationOption) error { vErr.Errors |= ValidationErrorNotValidYet } + if !c.validateAudience(false, opts...) { + vErr.Inner = ErrTokenInvalidAudience + vErr.Errors |= ValidationErrorAudience + } + if vErr.valid() { return nil } @@ -89,10 +94,7 @@ func (c *RegisteredClaims) VerifyAudience(cmp string, req bool) bool { // VerifyExpiresAt compares the exp claim against cmp (cmp < exp). // If req is false, it will return true, if exp is unset. func (c *RegisteredClaims) VerifyExpiresAt(cmp time.Time, req bool, opts ...validationOption) bool { - validator := validator{} - for _, o := range opts { - o(&validator) - } + validator := getValidator(opts...) if c.ExpiresAt == nil { return verifyExp(nil, cmp, req, validator.leeway) } @@ -113,10 +115,7 @@ func (c *RegisteredClaims) VerifyIssuedAt(cmp time.Time, req bool) bool { // VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). // If req is false, it will return true, if nbf is unset. func (c *RegisteredClaims) VerifyNotBefore(cmp time.Time, req bool, opts ...validationOption) bool { - validator := validator{} - for _, o := range opts { - o(&validator) - } + validator := getValidator(opts...) if c.NotBefore == nil { return verifyNbf(nil, cmp, req, validator.leeway) } @@ -130,6 +129,27 @@ func (c *RegisteredClaims) VerifyIssuer(cmp string, req bool) bool { return verifyIss(c.Issuer, cmp, req) } +func (c *RegisteredClaims) validateAudience(req bool, opts ...validationOption) bool { + if len(c.Audience) == 0 { + return !req + } + + validator := getValidator(opts...) + + if validator.skipAudience { + return true + } + + // Based on my reading of https://datatracker.ietf.org/doc/html/rfc7519/#section-4.1.3 + // this should technically fail. This is left as a decision for the maintainers to alter + // the behavior as it would be a breaking change. + if validator.audience != nil { + return c.VerifyAudience(*validator.audience, true) + } + + return !req +} + // StandardClaims are a structured version of the JWT Claims Set, as referenced at // https://datatracker.ietf.org/doc/html/rfc7519#section-4. They do not follow the // specification exactly, since they were based on an earlier draft of the @@ -174,6 +194,11 @@ func (c StandardClaims) Valid(opts ...validationOption) error { vErr.Errors |= ValidationErrorNotValidYet } + if !c.validateAudience(false, opts...) { + vErr.Inner = ErrTokenInvalidAudience + vErr.Errors |= ValidationErrorAudience + } + if vErr.valid() { return nil } @@ -190,10 +215,7 @@ func (c *StandardClaims) VerifyAudience(cmp string, req bool) bool { // VerifyExpiresAt compares the exp claim against cmp (cmp < exp). // If req is false, it will return true, if exp is unset. func (c *StandardClaims) VerifyExpiresAt(cmp int64, req bool, opts ...validationOption) bool { - validator := validator{} - for _, o := range opts { - o(&validator) - } + validator := getValidator(opts...) if c.ExpiresAt == 0 { return verifyExp(nil, time.Unix(cmp, 0), req, validator.leeway) } @@ -216,10 +238,7 @@ func (c *StandardClaims) VerifyIssuedAt(cmp int64, req bool) bool { // VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). // If req is false, it will return true, if nbf is unset. func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool, opts ...validationOption) bool { - validator := validator{} - for _, o := range opts { - o(&validator) - } + validator := getValidator(opts...) if c.NotBefore == 0 { return verifyNbf(nil, time.Unix(cmp, 0), req, validator.leeway) } @@ -234,6 +253,27 @@ func (c *StandardClaims) VerifyIssuer(cmp string, req bool) bool { return verifyIss(c.Issuer, cmp, req) } +func (c *StandardClaims) validateAudience(req bool, opts ...validationOption) bool { + if c.Audience == "" { + return !req + } + + validator := getValidator(opts...) + + if validator.skipAudience { + return true + } + + // Based on my reading of https://datatracker.ietf.org/doc/html/rfc7519/#section-4.1.3 + // this should technically fail. This is left as a decision for the maintainers to alter + // the behavior as it would be a breaking change. + if validator.audience != nil { + return c.VerifyAudience(*validator.audience, true) + } + + return !req +} + // ----- helpers func verifyAud(aud []string, cmp string, required bool) bool { diff --git a/map_claims.go b/map_claims.go index e4a08079..c9ba700e 100644 --- a/map_claims.go +++ b/map_claims.go @@ -42,10 +42,7 @@ func (m MapClaims) VerifyExpiresAt(cmp int64, req bool, opts ...validationOption return !req } - validator := validator{} - for _, o := range opts { - o(&validator) - } + validator := getValidator(opts...) switch exp := v.(type) { case float64: @@ -99,10 +96,7 @@ func (m MapClaims) VerifyNotBefore(cmp int64, req bool, opts ...validationOption return !req } - validator := validator{} - for _, o := range opts { - o(&validator) - } + validator := getValidator(opts...) switch nbf := v.(type) { case float64: @@ -127,6 +121,28 @@ func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { return verifyIss(iss, cmp, req) } +func (m MapClaims) validateAudience(req bool, opts ...validationOption) bool { + _, ok := m["aud"] + if !ok { + return !req + } + + validator := getValidator(opts...) + + if validator.skipAudience { + return true + } + + // Based on my reading of https://datatracker.ietf.org/doc/html/rfc7519/#section-4.1.3 + // this should technically fail. This is left as a decision for the maintainers to alter + // the behavior as it would be a breaking change. + if validator.audience != nil { + return m.VerifyAudience(*validator.audience, true) + } + + return !req +} + // Valid validates time based claims "exp, iat, nbf". // There is no accounting for clock skew. // As well, if any of the above claims are not in the token, it will still @@ -153,6 +169,11 @@ func (m MapClaims) Valid(opts ...validationOption) error { vErr.Errors |= ValidationErrorNotValidYet } + if !m.validateAudience(false, opts...) { + vErr.Inner = ErrTokenInvalidAudience + vErr.Errors |= ValidationErrorAudience + } + if vErr.valid() { return nil } diff --git a/parser_option.go b/parser_option.go index a7976645..95748161 100644 --- a/parser_option.go +++ b/parser_option.go @@ -36,3 +36,17 @@ func WithLeeway(d time.Duration) ParserOption { p.validationOptions = append(p.validationOptions, withLeeway(d)) } } + +// WithAudience returns the ParserOption for specifying an expected aud member value +func WithAudience(aud string) ParserOption { + return func(p *Parser) { + p.validationOptions = append(p.validationOptions, withAudience(aud)) + } +} + +// WithoutAudienceValidation returns the ParserOption that specifies audience check should be skipped +func WithoutAudienceValidation() ParserOption { + return func(p *Parser) { + p.validationOptions = append(p.validationOptions, withoutAudienceValidation()) + } +} diff --git a/parser_test.go b/parser_test.go index e25ff0b2..d35ebe22 100644 --- a/parser_test.go +++ b/parser_test.go @@ -339,6 +339,84 @@ var jwtTestData = []struct { &jwt.Parser{UseJSONNumber: true}, jwt.SigningMethodRS256, }, + { + "RFC7519 Claims - single aud without validation", + "", + defaultKeyFunc, + &jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{"test"}, + }, + true, + 0, + nil, + jwt.NewParser(jwt.WithoutAudienceValidation()), + jwt.SigningMethodRS256, + }, + { + "RFC7519 Claims - multiple aud without validation", + "", + defaultKeyFunc, + &jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{"test", "test2"}, + }, + true, + 0, + nil, + jwt.NewParser(jwt.WithoutAudienceValidation()), + jwt.SigningMethodRS256, + }, + { + "RFC7519 Claims - single aud with valid audience", + "", + defaultKeyFunc, + &jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{"test"}, + }, + true, + 0, + nil, + jwt.NewParser(jwt.WithAudience("test")), + jwt.SigningMethodRS256, + }, + { + "RFC7519 Claims - multiple aud with valid audience", + "", + defaultKeyFunc, + &jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{"test", "test2"}, + }, + true, + 0, + nil, + jwt.NewParser(jwt.WithAudience("test")), + jwt.SigningMethodRS256, + }, + { + "RFC7519 Claims - single aud with invalid audience", + "", + defaultKeyFunc, + &jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{"test"}, + }, + false, + jwt.ValidationErrorAudience, + []error{jwt.ErrTokenInvalidAudience}, + jwt.NewParser(jwt.WithAudience("bad")), + jwt.SigningMethodRS256, + }, + { + "RFC7519 Claims - multiple aud with invalid audience", + "", + defaultKeyFunc, + &jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{"test", "test2"}, + }, + false, + jwt.ValidationErrorAudience, + []error{jwt.ErrTokenInvalidAudience}, + jwt.NewParser(jwt.WithAudience("bad")), + jwt.SigningMethodRS256, + }, { "RFC7519 Claims - single aud with wrong type", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOjF9.8mAIDUfZNQT3TGm1QFIQp91OCpJpQpbB1-m9pA2mkHc", // { "aud": 1 } diff --git a/validator_option.go b/validator_option.go index eb29dc30..a2b48154 100644 --- a/validator_option.go +++ b/validator_option.go @@ -15,7 +15,9 @@ type validationOption func(*validator) // Note that this struct is (currently) un-exported, its naming is subject to change and will only be exported once // the API is more stable. type validator struct { - leeway time.Duration // Leeway to provide when validating time values + audience *string // Expected audience value + skipAudience bool // Ignore aud check + leeway time.Duration // Leeway to provide when validating time values } // withLeeway is an option to set the clock skew (leeway) window @@ -27,3 +29,32 @@ func withLeeway(d time.Duration) validationOption { v.leeway = d } } + +// withAudience returns the ParserOption for specifying an expected aud member value +// +// Note that this function is (currently) un-exported, its naming is subject to change and will only be exported once +// the API is more stable. +func withAudience(aud string) validationOption { + return func(v *validator) { + v.audience = &aud + } +} + +// withoutAudienceValidation returns the ParserOption that specifies audience check should be skipped +// +// Note that this function is (currently) un-exported, its naming is subject to change and will only be exported once +// the API is more stable. +func withoutAudienceValidation() validationOption { + return func(v *validator) { + v.skipAudience = true + } +} + +// getValidator return the validation given the options +func getValidator(opts ...validationOption) validator { + v := validator{} + for _, o := range opts { + o(&v) + } + return v +}