Skip to content

Commit

Permalink
feat: add the ability to set jwt header type (#737)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgyongyosi authored Mar 17, 2023
1 parent daf3b15 commit 45a6785
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 25 deletions.
2 changes: 1 addition & 1 deletion token/jwt/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func NewHeaders() *Headers {

// ToMap will transform the headers to a map structure
func (h *Headers) ToMap() map[string]interface{} {
var filter = map[string]bool{"alg": true, "typ": true}
var filter = map[string]bool{"alg": true}
var extra = map[string]interface{}{}

// filter known values from extra.
Expand Down
4 changes: 3 additions & 1 deletion token/jwt/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ func (t *Token) SignedString(k interface{}) (rawToken string, err error) {

func unsignedToken(t *Token) (string, error) {
t.Header["alg"] = "none"
t.Header[string(JWTHeaderType)] = JWTHeaderTypeValue
if _, ok := t.Header[string(JWTHeaderType)]; !ok {
t.Header[string(JWTHeaderType)] = JWTHeaderTypeValue
}
hbytes, err := json.Marshal(&t.Header)
if err != nil {
return "", errorsx.WithStack(err)
Expand Down
99 changes: 76 additions & 23 deletions token/jwt/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,73 @@ import (
)

func TestUnsignedToken(t *testing.T) {
key := UnsafeAllowNoneSignatureType
token := NewWithClaims(SigningMethodNone, MapClaims{
"aud": "foo",
"exp": time.Now().UTC().Add(time.Hour).Unix(),
"iat": time.Now().UTC().Unix(),
"sub": "nestor",
})
rawToken, err := token.SignedString(key)
require.NoError(t, err)
require.NotEmpty(t, rawToken)
parts := strings.Split(rawToken, ".")
require.Len(t, parts, 3)
require.Empty(t, parts[2])
tk, err := jwt.ParseSigned(rawToken)
require.NoError(t, err)
require.Len(t, tk.Headers, 1)
require.Equal(t, "JWT", tk.Headers[0].ExtraHeaders[jose.HeaderKey("typ")])
var testCases = []struct {
name string
jwtHeaders map[string]interface{}
expectedType string
}{
{
name: "set JWT as 'typ' when the the type is not specified in the headers",
jwtHeaders: map[string]interface{}{},
expectedType: "JWT",
},
{
name: "'typ' set explicitly",
jwtHeaders: map[string]interface{}{"typ": "at+jwt"},
expectedType: "at+jwt",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
key := UnsafeAllowNoneSignatureType
token := NewWithClaims(SigningMethodNone, MapClaims{
"aud": "foo",
"exp": time.Now().UTC().Add(time.Hour).Unix(),
"iat": time.Now().UTC().Unix(),
"sub": "nestor",
})
token.Header = tc.jwtHeaders
rawToken, err := token.SignedString(key)
require.NoError(t, err)
require.NotEmpty(t, rawToken)
parts := strings.Split(rawToken, ".")
require.Len(t, parts, 3)
require.Empty(t, parts[2])
tk, err := jwt.ParseSigned(rawToken)
require.NoError(t, err)
require.Len(t, tk.Headers, 1)
require.Equal(t, tc.expectedType, tk.Headers[0].ExtraHeaders[jose.HeaderKey("typ")])
})
}
}

func TestJWTHeaders(t *testing.T) {
rawToken := makeSampleToken(nil, jose.RS256, gen.MustRSAKey())
tk, err := jwt.ParseSigned(rawToken)
require.NoError(t, err)
require.Len(t, tk.Headers, 1)
require.Equal(t, tk.Headers[0].Algorithm, "RS256")
require.Equal(t, "JWT", tk.Headers[0].ExtraHeaders[jose.HeaderKey("typ")])
var testCases = []struct {
name string
jwtHeaders map[string]interface{}
expectedType string
}{
{
name: "set JWT as 'typ' when the the type is not specified in the headers",
jwtHeaders: map[string]interface{}{},
expectedType: "JWT",
},
{
name: "'typ' set explicitly",
jwtHeaders: map[string]interface{}{"typ": "at+jwt"},
expectedType: "at+jwt",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rawToken := makeSampleTokenWithCustomHeaders(nil, jose.RS256, tc.jwtHeaders, gen.MustRSAKey())
tk, err := jwt.ParseSigned(rawToken)
require.NoError(t, err)
require.Len(t, tk.Headers, 1)
require.Equal(t, tk.Headers[0].Algorithm, "RS256")
require.Equal(t, tc.expectedType, tk.Headers[0].ExtraHeaders[jose.HeaderKey("typ")])
})
}
}

var keyFuncError error = fmt.Errorf("error loading key")
Expand Down Expand Up @@ -418,6 +459,18 @@ func makeSampleToken(c MapClaims, m jose.SignatureAlgorithm, key interface{}) st
return s
}

func makeSampleTokenWithCustomHeaders(c MapClaims, m jose.SignatureAlgorithm, headers map[string]interface{}, key interface{}) string {
token := NewWithClaims(m, c)
token.Header = headers
s, e := token.SignedString(key)

if e != nil {
panic(e.Error())
}

return s
}

func parseRSAPublicKeyFromPEM(key []byte) *rsa.PublicKey {
var err error

Expand Down

0 comments on commit 45a6785

Please sign in to comment.