Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: add support for type parameter #272

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/jwt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func verifyToken() error {
}

// Parse the token. Load the key from command line option
token, err := jwt.Parse(string(tokData), func(t *jwt.Token) (interface{}, error) {
token, err := jwt.Parse(string(tokData), func(t *jwt.TokenFor[jwt.MapClaims]) (interface{}, error) {
if isNone() {
return jwt.UnsafeAllowNoneSignatureType, nil
}
Expand Down
26 changes: 13 additions & 13 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func ExampleNewWithClaims_registeredClaims() {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
ss, err := token.SignedString(mySigningKey)
fmt.Printf("%v %v", ss, err)
//Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.0XN_1Tpp9FszFOonIBpwha0c_SfnNI22DhTnjMshPg8 <nil>
// Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.0XN_1Tpp9FszFOonIBpwha0c_SfnNI22DhTnjMshPg8 <nil>
}

// Example creating a token using a custom claims type. The RegisteredClaims is embedded
Expand Down Expand Up @@ -67,7 +67,7 @@ func ExampleNewWithClaims_customClaimsType() {
ss, err := token.SignedString(mySigningKey)
fmt.Printf("%v %v", ss, err)

//Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.xVuY2FZ_MRXMIEgVQ7J-TFtaucVFRXUzHm9LmV41goM <nil>
// Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.xVuY2FZ_MRXMIEgVQ7J-TFtaucVFRXUzHm9LmV41goM <nil>
}

// Example creating a token using a custom claims type. The RegisteredClaims is embedded
Expand All @@ -80,12 +80,12 @@ func ExampleParseWithClaims_customClaimsType() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
})

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
if token.Valid {
fmt.Printf("%v %v", token.Claims.Foo, token.Claims.Issuer)
} else {
fmt.Println(err)
}
Expand All @@ -103,12 +103,12 @@ func ExampleParseWithClaims_validationOptions() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
if token.Valid {
fmt.Printf("%v %v", token.Claims.Foo, token.Claims.Issuer)
} else {
fmt.Println(err)
}
Expand Down Expand Up @@ -138,12 +138,12 @@ func (m MyCustomClaims) Validate() error {
func ExampleParseWithClaims_customValidation() {
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA"

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
if token.Valid {
fmt.Printf("%v %v", token.Claims.Foo, token.Claims.Issuer)
} else {
fmt.Println(err)
}
Expand All @@ -154,9 +154,9 @@ func ExampleParseWithClaims_customValidation() {
// An example of parsing the error types using errors.Is.
func ExampleParse_errorChecking() {
// Token from another example. This token is expired
var tokenString = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c"
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c"

token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.TokenFor[jwt.MapClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
})

Expand Down
11 changes: 5 additions & 6 deletions hmac_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func ExampleParse_hmac() {
// useful if you use multiple keys for your application. The standard is to use 'kid' in the
// head of the token to identify which key to use, but the parsed token (head and claims) is provided
// to the callback, providing flexibility.
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.TokenFor[jwt.MapClaims]) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
Expand All @@ -56,12 +56,11 @@ func ExampleParse_hmac() {
// hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key")
return hmacSampleSecret, nil
})

if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
fmt.Println(claims["foo"], claims["nbf"])
} else {
fmt.Println(err)
if err != nil {
panic(err)
}

fmt.Println(token.Claims["foo"], token.Claims["nbf"])

// Output: bar 1.4444784e+09
}
36 changes: 17 additions & 19 deletions http_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,20 @@ func Example_getTokenViaHTTP() {
tokenString := strings.TrimSpace(buf.String())

// Parse the token
token, err := jwt.ParseWithClaims(tokenString, &CustomClaimsExample{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.TokenFor[*CustomClaimsExample]) (interface{}, error) {
// since we only use the one private key to sign the tokens,
// we also only use its public counter part to verify
return verifyKey, nil
})
fatal(err)

claims := token.Claims.(*CustomClaimsExample)
claims := token.Claims
fmt.Println(claims.CustomerInfo.Name)

//Output: test
// Output: test
}

func Example_useTokenViaHTTP() {

// Make a sample token
// In a real world situation, this token will have been acquired from
// some other API call (see Example_getTokenViaHTTP)
Expand All @@ -138,18 +137,18 @@ func Example_useTokenViaHTTP() {

func createToken(user string) (string, error) {
// create a signer for rsa 256
t := jwt.New(jwt.GetSigningMethod("RS256"))

// set our claims
t.Claims = &CustomClaimsExample{
jwt.RegisteredClaims{
// set the expire time
// see https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 1)),
t := jwt.NewWithClaims(
jwt.GetSigningMethod("RS256"),
&CustomClaimsExample{
jwt.RegisteredClaims{
// set the expire time
// see https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 1)),
},
"level1",
CustomerInfo{user, "human"},
},
"level1",
CustomerInfo{user, "human"},
}
)

// Creat token string
return t.SignedString(signKey)
Expand Down Expand Up @@ -192,12 +191,11 @@ func authHandler(w http.ResponseWriter, r *http.Request) {
// only accessible with a valid token
func restrictedHandler(w http.ResponseWriter, r *http.Request) {
// Get token from request
token, err := request.ParseFromRequest(r, request.OAuth2Extractor, func(token *jwt.Token) (interface{}, error) {
token, err := request.ParseFromRequestWithClaims(r, request.OAuth2Extractor, func(token *jwt.TokenFor[*CustomClaimsExample]) (interface{}, error) {
// since we only use the one private key to sign the tokens,
// we also only use its public counter part to verify
return verifyKey, nil
}, request.WithClaims(&CustomClaimsExample{}))

})
// If the token is missing or invalid, return error
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
Expand All @@ -206,5 +204,5 @@ func restrictedHandler(w http.ResponseWriter, r *http.Request) {
}

// Token is valid
fmt.Fprintln(w, "Welcome,", token.Claims.(*CustomClaimsExample).Name)
fmt.Fprintln(w, "Welcome,", token.Claims.Name)
}
72 changes: 38 additions & 34 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"strings"
)

type Parser struct {
type parserOpts struct {
// If populated, only these methods will be considered valid.
validMethods []string

Expand All @@ -20,44 +20,54 @@ type Parser struct {
validator *validator
}

type Parser[T Claims] struct {
opts parserOpts
}

// NewParser creates a new Parser with the specified options
func NewParser(options ...ParserOption) *Parser {
p := &Parser{
validator: &validator{},
func NewParser(options ...ParserOption) *Parser[MapClaims] {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
func NewParser(options ...ParserOption) *Parser[MapClaims] {
func NewParser(options ...ParserOption) *Parser[Claims] {

p := &Parser[MapClaims]{
opts: parserOpts{validator: &validator{}},
}

// Loop through our parsing options and apply them
for _, option := range options {
option(p)
option(&p.opts)
}

return p
}

// Parse parses, validates, verifies the signature and returns the parsed token.
// keyFunc will receive the parsed token and should return the key for validating.
func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
return p.ParseWithClaims(tokenString, MapClaims{}, keyFunc)
func NewParserFor[T Claims](options ...ParserOption) *Parser[T] {
p := &Parser[T]{
opts: parserOpts{validator: &validator{}},
}

// Loop through our parsing options and apply them
for _, option := range options {
option(&p.opts)
}

return p
}

// ParseWithClaims parses, validates, and verifies like Parse, but supplies a default object implementing the Claims
// interface. This provides default values which can be overridden and allows a caller to use their own type, rather
// than the default MapClaims implementation of Claims.
// Parse parses, validates, verifies the signature and returns the parsed token.
// keyFunc will receive the parsed token and should return the key for validating.
//
// Note: If you provide a custom claim implementation that embeds one of the standard claims (such as RegisteredClaims),
// make sure that a) you either embed a non-pointer version of the claims or b) if you are using a pointer, allocate the
// proper memory for it before passing in the overall claims, otherwise you might run into a panic.
func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
token, parts, err := p.ParseUnverified(tokenString, claims)
func (p *Parser[T]) Parse(tokenString string, keyFunc KeyfuncFor[T]) (*TokenFor[T], error) {
token, parts, err := p.ParseUnverified(tokenString)
if err != nil {
return token, err
}

// Verify signing method is in the required set
if p.validMethods != nil {
var signingMethodValid = false
var alg = token.Method.Alg()
for _, m := range p.validMethods {
if p.opts.validMethods != nil {
signingMethodValid := false
alg := token.Method.Alg()
for _, m := range p.opts.validMethods {
if m == alg {
signingMethodValid = true
break
Expand Down Expand Up @@ -86,13 +96,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
}

// Validate Claims
if !p.skipClaimsValidation {
if !p.opts.skipClaimsValidation {
// Make sure we have at least a default validator
if p.validator == nil {
p.validator = newValidator()
if p.opts.validator == nil {
p.opts.validator = newValidator()
}

if err := p.validator.Validate(claims); err != nil {
if err := p.opts.validator.Validate(token.Claims); err != nil {
return token, newError("", ErrTokenInvalidClaims, err)
}
}
Expand All @@ -109,13 +119,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
//
// It's only ever useful in cases where you know the signature is valid (because it has
// been checked previously in the stack) and you want to extract values from it.
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
func (p *Parser[T]) ParseUnverified(tokenString string) (token *TokenFor[T], parts []string, err error) {
parts = strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, parts, newError("token contains an invalid number of segments", ErrTokenMalformed)
}

token = &Token{Raw: tokenString}
token = &TokenFor[T]{Raw: tokenString}

// parse Header
var headerBytes []byte
Expand All @@ -131,23 +141,17 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke

// parse Claims
var claimBytes []byte
token.Claims = claims

if claimBytes, err = DecodeSegment(parts[1]); err != nil {
return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err)
}

dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
if p.useJSONNumber {
if p.opts.useJSONNumber {
dec.UseNumber()
}
// JSON Decode. Special case for map type to avoid weird pointer behavior
if c, ok := token.Claims.(MapClaims); ok {
err = dec.Decode(&c)
} else {
err = dec.Decode(&claims)
}

// Handle decode error
if err != nil {
if err = dec.Decode(&token.Claims); err != nil {
return token, parts, newError("could not JSON decode claim", ErrTokenMalformed, err)
}

Expand Down
Loading