Skip to content

Commit

Permalink
Golang jwt (#612)
Browse files Browse the repository at this point in the history
* replace library for JWT creation (new lib does not pass specific error message in case of invalid EC key)

* don't serialize claims into log message; they will be in a different format with the new library

* don't serialize claims into log message with wrong iat, only log wrong value

* OIDC: new test cases for expired/not yet issued token

* replace library for JWT verification

* JWT AC: create parser once for each JWT AC

* use MapClaims to simplify code; rename claims to expectedClaims

* OIDC: JWT parser is not dependent on the OIDC config any longer, so refreshing/locking is no longer necessary

* OIDC: restrict JWT parser to supported algorithms

* changelog entry
  • Loading branch information
johakoch authored Nov 3, 2022
1 parent 0a14d7b commit 7887420
Show file tree
Hide file tree
Showing 53 changed files with 1,753 additions and 1,753 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ Unreleased changes are available as `avenga/couper:edge` container.
* **Added**
* [`trim()` function](https://docs.couper.io/configuration/functions) ([#605](https://github.com/avenga/couper/pull/605))

* **Changed**
* Replaced the JWT library because the former library was no longer maintained ([#612](https://github.com/avenga/couper/pull/612))

* **Fixed**
* Aligned the evaluation of [`beta_oauth2`](https://docs.couper.io/configuration/block/oauth2_ac)/[`oidc`](https://docs.couper.io/configuration/block/oidc) `redirect_uri` to `saml` `sp_acs_url` ([#589](https://github.com/avenga/couper/pull/589))
* Proper handling of empty [`beta_oauth2`](https://docs.couper.io/configuration/block/oauth2_ac)/[`oidc`](https://docs.couper.io/configuration/block/oidc) `scope` ([#593](https://github.com/avenga/couper/pull/593))
Expand Down
2 changes: 1 addition & 1 deletion accesscontrol/jwk/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"strings"
"time"

"github.com/dgrijalva/jwt-go/v4"
"github.com/golang-jwt/jwt/v4"

"github.com/avenga/couper/config"
jsn "github.com/avenga/couper/json"
Expand Down
2 changes: 1 addition & 1 deletion accesscontrol/jwk/jwks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"sync"
"testing"

"github.com/dgrijalva/jwt-go/v4"
"github.com/golang-jwt/jwt/v4"

"github.com/avenga/couper/accesscontrol/jwk"
"github.com/avenga/couper/config/body"
Expand Down
113 changes: 51 additions & 62 deletions accesscontrol/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
goerrors "errors"
"fmt"
"net/http"
"strings"
"time"

"github.com/dgrijalva/jwt-go/v4"
"github.com/golang-jwt/jwt/v4"
"github.com/hashicorp/hcl/v2"
"github.com/sirupsen/logrus"

Expand Down Expand Up @@ -45,13 +45,14 @@ type (
)

type JWT struct {
algorithms []acjwt.Algorithm
algorithm acjwt.Algorithm
claims hcl.Expression
claimsRequired []string
disablePrivateCaching bool
source JWTSource
hmacSecret []byte
name string
parser *jwt.Parser
pubKey interface{}
rolesClaim string
rolesMap map[string][]string
Expand Down Expand Up @@ -120,14 +121,14 @@ func NewJWT(options *JWTOptions) (*JWT, error) {
return nil, err
}

algorithm := acjwt.NewAlgorithm(options.Algorithm)
if algorithm == acjwt.AlgorithmUnknown {
jwtAC.algorithm = acjwt.NewAlgorithm(options.Algorithm)
if jwtAC.algorithm == acjwt.AlgorithmUnknown {
return nil, fmt.Errorf("algorithm %q is not supported", options.Algorithm)
}

jwtAC.algorithms = []acjwt.Algorithm{algorithm}
jwtAC.parser = newParser([]acjwt.Algorithm{jwtAC.algorithm})

if algorithm.IsHMAC() {
if jwtAC.algorithm.IsHMAC() {
jwtAC.hmacSecret = options.Key
return jwtAC, nil
}
Expand All @@ -151,7 +152,8 @@ func NewJWTFromJWKS(options *JWTOptions) (*JWT, error) {
return nil, fmt.Errorf("invalid JWKS")
}

jwtAC.algorithms = append(acjwt.RSAAlgorithms, acjwt.ECDSAlgorithms...)
algorithms := append(acjwt.RSAAlgorithms, acjwt.ECDSAlgorithms...)
jwtAC.parser = newParser(algorithms)
jwtAC.jwks = options.JWKS

return jwtAC, nil
Expand Down Expand Up @@ -219,7 +221,7 @@ func (j *JWT) Validate(req *http.Request) error {
return errors.JwtTokenMissing.Message("token required")
}

claims, iss, aud, err := j.getConfiguredClaims(req)
expectedClaims, err := j.getConfiguredClaims(req)
if err != nil {
return err
}
Expand All @@ -229,21 +231,16 @@ func (j *JWT) Validate(req *http.Request) error {
j.jwks.Data()
}

parser := newParser(j.algorithms, iss, aud)
token, err := parser.Parse(tokenValue, j.getValidationKey)
tokenClaims := jwt.MapClaims{}
_, err = j.parser.ParseWithClaims(tokenValue, tokenClaims, j.getValidationKey)
if err != nil {
switch err := err.(type) {
case *jwt.TokenExpiredError:
if goerrors.Is(err, jwt.ErrTokenExpired) {
return errors.JwtTokenExpired.With(err)
case *jwt.UnverfiableTokenError:
if unwrappedError := err.ErrorWrapper.Unwrap(); unwrappedError != nil {
return errors.JwtTokenInvalid.With(unwrappedError)
}
}
return errors.JwtTokenInvalid.With(err)
}

tokenClaims, err := j.validateClaims(token, claims)
err = j.validateClaims(tokenClaims, expectedClaims)
if err != nil {
return errors.JwtTokenInvalid.With(err)
}
Expand All @@ -253,7 +250,8 @@ func (j *JWT) Validate(req *http.Request) error {
if !ok {
acMap = make(map[string]interface{})
}
acMap[j.name] = tokenClaims
// treat token claims as map for context
acMap[j.name] = map[string]interface{}(tokenClaims)
ctx = context.WithValue(ctx, request.AccessControls, acMap)

log := req.Context().Value(request.LogEntry).(*logrus.Entry).WithContext(req.Context())
Expand All @@ -275,7 +273,7 @@ func (j *JWT) getValidationKey(token *jwt.Token) (interface{}, error) {
return j.jwks.GetSigKeyForToken(token)
}

switch j.algorithms[0] {
switch j.algorithm {
case acjwt.AlgorithmRSA256, acjwt.AlgorithmRSA384, acjwt.AlgorithmRSA512:
return j.pubKey, nil
case acjwt.AlgorithmECDSA256, acjwt.AlgorithmECDSA384, acjwt.AlgorithmECDSA512:
Expand All @@ -288,70 +286,69 @@ func (j *JWT) getValidationKey(token *jwt.Token) (interface{}, error) {
}

// getConfiguredClaims evaluates the expected claim values from the configuration, and especially iss and aud
func (j *JWT) getConfiguredClaims(req *http.Request) (map[string]interface{}, string, string, error) {
func (j *JWT) getConfiguredClaims(req *http.Request) (map[string]interface{}, error) {
claims := make(map[string]interface{})
var iss, aud string
if j.claims != nil {
val, verr := eval.Value(eval.ContextFromRequest(req).HCLContext(), j.claims)
if verr != nil {
return nil, "", "", verr
return nil, verr
}
claims = seetie.ValueToMap(val)

var ok bool
if issVal, exists := claims["iss"]; exists {
iss, ok = issVal.(string)
_, ok = issVal.(string)
if !ok {
return nil, "", "", errors.Configuration.Message("invalid value type, string expected (claims / iss)")
return nil, errors.Configuration.Message("invalid value type, string expected (claims / iss)")
}
}

if audVal, exists := claims["aud"]; exists {
aud, ok = audVal.(string)
_, ok = audVal.(string)
if !ok {
return nil, "", "", errors.Configuration.Message("invalid value type, string expected (claims / aud)")
return nil, errors.Configuration.Message("invalid value type, string expected (claims / aud)")
}
}
}

return claims, iss, aud, nil
return claims, nil
}

// validateClaims validates the token claims against the list of required claims and the expected claims values
func (j *JWT) validateClaims(token *jwt.Token, claims map[string]interface{}) (map[string]interface{}, error) {
var tokenClaims jwt.MapClaims
if tc, ok := token.Claims.(jwt.MapClaims); ok {
tokenClaims = tc
}

if tokenClaims == nil {
return nil, fmt.Errorf("token has no claims")
}

func (j *JWT) validateClaims(tokenClaims jwt.MapClaims, expectedClaims map[string]interface{}) error {
for _, key := range j.claimsRequired {
if _, ok := tokenClaims[key]; !ok {
return nil, fmt.Errorf("required claim is missing: " + key)
return fmt.Errorf("required claim is missing: " + key)
}
}

for k, v := range claims {
if k == "iss" || k == "aud" { // gets validated during parsing
continue
}

for k, v := range expectedClaims {
val, exist := tokenClaims[k]
if !exist {
return nil, fmt.Errorf("required claim is missing: " + k)
return fmt.Errorf("required claim is missing: " + k)
}

if k == "iss" {
if !tokenClaims.VerifyIssuer(v.(string), true) {
return errors.JwtTokenInvalid.Message("invalid issuer")
}
continue
}
if k == "aud" {
if !tokenClaims.VerifyAudience(v.(string), true) {
return errors.JwtTokenInvalid.Message("invalid audience")
}
continue
}

if val != v {
return nil, fmt.Errorf("unexpected value for claim %s, got %q, expected %q", k, val, v)
return fmt.Errorf("unexpected value for claim %s, got %q, expected %q", k, val, v)
}
}
return tokenClaims, nil
return nil
}

func (j *JWT) getGrantedPermissions(tokenClaims map[string]interface{}, log *logrus.Entry) []string {
func (j *JWT) getGrantedPermissions(tokenClaims jwt.MapClaims, log *logrus.Entry) []string {
var grantedPermissions []string

grantedPermissions = j.addPermissionsFromPermissionsClaim(tokenClaims, grantedPermissions, log)
Expand All @@ -365,7 +362,7 @@ func (j *JWT) getGrantedPermissions(tokenClaims map[string]interface{}, log *log

const warnInvalidValueMsg = "invalid %s claim value type, ignoring claim, value %#v"

func (j *JWT) addPermissionsFromPermissionsClaim(tokenClaims map[string]interface{}, permissions []string, log *logrus.Entry) []string {
func (j *JWT) addPermissionsFromPermissionsClaim(tokenClaims jwt.MapClaims, permissions []string, log *logrus.Entry) []string {
if j.permissionsClaim == "" {
return permissions
}
Expand Down Expand Up @@ -428,7 +425,7 @@ func (j *JWT) getRoleValues(rolesClaimValue interface{}, log *logrus.Entry) []st
return strings.Split(rolesString, " ")
}

func (j *JWT) addPermissionsFromRoles(tokenClaims map[string]interface{}, permissions []string, log *logrus.Entry) []string {
func (j *JWT) addPermissionsFromRoles(tokenClaims jwt.MapClaims, permissions []string, log *logrus.Entry) []string {
if j.rolesClaim == "" || j.rolesMap == nil {
return permissions
}
Expand Down Expand Up @@ -504,24 +501,16 @@ func getBearer(val string) (string, error) {
return "", fmt.Errorf("bearer required with authorization header")
}

// newParser creates a new parser with issuer/audience validation if configured via iss/aud in expected claims
func newParser(algos []acjwt.Algorithm, iss, aud string) *jwt.Parser {
// newParser creates a new parser
func newParser(algos []acjwt.Algorithm) *jwt.Parser {
var algorithms []string
for _, a := range algos {
algorithms = append(algorithms, a.String())
}
options := []jwt.ParserOption{
jwt.WithValidMethods(algorithms),
jwt.WithLeeway(time.Second),
}

if aud != "" {
options = append(options, jwt.WithAudience(aud))
} else {
options = append(options, jwt.WithoutAudienceValidation())
}
if iss != "" {
options = append(options, jwt.WithIssuer(iss))
// no equivalent in new lib
// jwt.WithLeeway(time.Second),
}

return jwt.NewParser(options...)
Expand Down
2 changes: 1 addition & 1 deletion accesscontrol/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"testing"
"time"

"github.com/dgrijalva/jwt-go/v4"
"github.com/golang-jwt/jwt/v4"
"github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/hcltest"
"github.com/sirupsen/logrus"
Expand Down
2 changes: 1 addition & 1 deletion eval/lib/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"strings"
"time"

"github.com/dgrijalva/jwt-go/v4"
"github.com/golang-jwt/jwt/v4"
"github.com/hashicorp/hcl/v2"
"github.com/zclconf/go-cty/cty"
"github.com/zclconf/go-cty/cty/function"
Expand Down
6 changes: 3 additions & 3 deletions eval/lib/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ func TestJwtSignConfigError(t *testing.T) {
`,
"MyToken",
`{"sub": "12345"}`,
"configuration error: MyToken: invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key",
"configuration error: MyToken: invalid key: Key must be a PEM encoded PKCS1 or PKCS8 key",
},
{
"jwt / missing signing key or key_file",
Expand Down Expand Up @@ -812,7 +812,7 @@ func TestJwtSignConfigError(t *testing.T) {
`,
"MySelfSignedToken",
`{"sub": "12345"}`,
"configuration error: MySelfSignedToken: invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key",
"configuration error: MySelfSignedToken: invalid key: Key must be a PEM encoded PKCS1 or PKCS8 key",
},
{
"user-defined alg header",
Expand Down Expand Up @@ -984,7 +984,7 @@ func TestJwtSignError(t *testing.T) {
`,
"MyToken",
`{"sub":"12345"}`,
"key is invalid: CurveBits in public key don't match those in signing method",
"key is invalid",
},
}

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ go 1.19
require (
github.com/agext/levenshtein v1.2.3 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1
github.com/docker/go-units v0.4.0
github.com/fatih/color v1.13.0
github.com/getkin/kin-openapi v0.92.0
Expand Down Expand Up @@ -42,6 +41,7 @@ require (

require (
github.com/algolia/algoliasearch-client-go/v3 v3.26.0
github.com/golang-jwt/jwt/v4 v4.4.2
github.com/google/go-cmp v0.5.7
)

Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 h1:CaO/zOnF8VvUfEbhRatPcwKVWamvbYd8tQGRWacE9kU=
github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1/go.mod h1:+hnT3ywWDTAFrW5aE+u2Sa/wT555ZqwoCS+pk3p6ry4=
github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
Expand Down Expand Up @@ -121,6 +119,8 @@ github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/me
github.com/go-test/deep v1.0.3 h1:ZrJSEWsXzPOxaZnFteGEfooLba+ju3FYIbOrS+rQd68=
github.com/go-test/deep v1.0.3/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang-jwt/jwt/v4 v4.4.2 h1:rcc4lwaZgFMCZ5jxF9ABolDcIHdBytAFgqFPbSJQAYs=
github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
Expand Down
Loading

0 comments on commit 7887420

Please sign in to comment.