diff --git a/README.md b/README.md index 32734c2..6b19dab 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,13 @@ jwtware.New(config ...jwtware.Config) func(*fiber.Ctx) error | Claims | `jwt.Claim` | Claims are extendable claims data defining token content. | `jwt.MapClaims{}` | | TokenLookup | `string` | TokenLookup is a string in the form of `:` that is used | `"header:Authorization"` | | AuthScheme | `string` |AuthScheme to be used in the Authorization header. | `"Bearer"` | +| KeySetURL | `string` |KeySetURL location of JSON file with signing keys. | `""` | +| KeyRefreshSuccessHandler | `func(j *KeySet)` |KeyRefreshSuccessHandler defines a function which is executed for a valid refresh of signing keys.| `nil` | +| KeyRefreshErrorHandler | `func(j *KeySet, err error)` |KeyRefreshErrorHandler defines a function which is executed for an invalid refresh of signing keys. | `nil` | +| KeyRefreshInterval | `*time.Duration` |KeyRefreshInterval is the duration to refresh the JWKs in the background via a new HTTP request. | `nil` | +| KeyRefreshRateLimit | `*time.Duration` |KeyRefreshRateLimit limits the rate at which refresh requests are granted. | `nil` | +| KeyRefreshTimeout | `*time.Duration` |KeyRefreshTimeout is the duration for the context used to create the HTTP request for a refresh of the JWKs. | `1min` | +| KeyRefreshUnknownKID | `bool` |KeyRefreshUnknownKID indicates that the JWKs refresh request will occur every time a kid that isn't cached is seen. | `false` | ### HS256 Example @@ -233,3 +240,6 @@ func restricted(c *fiber.Ctx) error { ### RS256 Test The RS256 is actually identical to the HS256 test above. + +### JWKs Test +The tests are identical to basic `JWT` tests above, with exception that `KeySetURL` to valid public keys collection in JSON format should be supplied. diff --git a/config.go b/config.go new file mode 100644 index 0000000..bcc01b2 --- /dev/null +++ b/config.go @@ -0,0 +1,179 @@ +package jwtware + +import ( + "strings" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/golang-jwt/jwt/v4" +) + +// KeyRefreshSuccessHandler is a function signature that consumes a set of signing key set. +// Presence of original signing key set allows to update configuration or stop background refresh. +type KeyRefreshSuccessHandler func(j *KeySet) + +// KeyRefreshErrorHandler is a function signature that consumes a set of signing key set and an error. +// Presence of original signing key set allows to update configuration or stop background refresh. +type KeyRefreshErrorHandler func(j *KeySet, err error) + +// Config defines the config for JWT middleware +type Config struct { + // Filter defines a function to skip middleware. + // Optional. Default: nil + Filter func(*fiber.Ctx) bool + + // SuccessHandler defines a function which is executed for a valid token. + // Optional. Default: nil + SuccessHandler fiber.Handler + + // ErrorHandler defines a function which is executed for an invalid token. + // It may be used to define a custom JWT error. + // Optional. Default: 401 Invalid or expired JWT + ErrorHandler fiber.ErrorHandler + + // Signing key to validate token. Used as fallback if SigningKeys has length 0. + // Required. This, SigningKeys or KeySetUrl. + SigningKey interface{} + + // Map of signing keys to validate token with kid field usage. + // Required. This, SigningKey or KeySetUrl. + SigningKeys map[string]interface{} + + // URL where set of private keys could be downloaded. + // Required. This, SigningKey or SigningKeys. + KeySetURL string + + // KeyRefreshSuccessHandler defines a function which is executed on successful refresh of key set. + // Optional. Default: nil + KeyRefreshSuccessHandler KeyRefreshSuccessHandler + + // KeyRefreshErrorHandler defines a function which is executed for refresh key set failure. + // Optional. Default: nil + KeyRefreshErrorHandler KeyRefreshErrorHandler + + // KeyRefreshInterval is the duration to refresh the JWKs in the background via a new HTTP request. If this is not nil, + // then a background refresh will be requested in a separate goroutine at this interval until the JWKs method + // EndBackground is called. + // Optional. If set, the value will be used only if `KeySetUrl` is also present + KeyRefreshInterval *time.Duration + + // KeyRefreshRateLimit limits the rate at which refresh requests are granted. Only one refresh request can be queued + // at a time any refresh requests received while there is already a queue are ignored. It does not make sense to + // have RefreshInterval's value shorter than this. + // Optional. If set, the value will be used only if `KeySetUrl` is also present + KeyRefreshRateLimit *time.Duration + + // KeyRefreshTimeout is the duration for the context used to create the HTTP request for a refresh of the JWKs. This + // defaults to one minute. This is only effectual if RefreshInterval is not nil. + // Optional. If set, the value will be used only if `KeySetUrl` is also present + KeyRefreshTimeout *time.Duration + + // KeyRefreshUnknownKID indicates that the JWKs refresh request will occur every time a kid that isn't cached is seen. + // Without specifying a RefreshInterval a malicious client could self-sign X JWTs, send them to this service, + // then cause potentially high network usage proportional to X. + // Optional. If set, the value will be used only if `KeySetUrl` is also present + KeyRefreshUnknownKID *bool + + // Signing method, used to check token signing method. + // Optional. Default: "HS256". + // Possible values: "HS256", "HS384", "HS512", "ES256", "ES384", "ES512", "RS256", "RS384", "RS512" + SigningMethod string + + // Context key to store user information from the token into context. + // Optional. Default: "user". + ContextKey string + + // Claims are extendable claims data defining token content. + // Optional. Default value jwt.MapClaims + Claims jwt.Claims + + // TokenLookup is a string in the form of ":" that is used + // to extract token from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" + // - "query:" + // - "param:" + // - "cookie:" + TokenLookup string + + // AuthScheme to be used in the Authorization header. + // Optional. Default: "Bearer". + AuthScheme string + + keyFunc jwt.Keyfunc +} + +// makeCfg function will check correctness of supplied configuration +// and will complement it with default values instead of missing ones +func makeCfg(config []Config) (cfg Config) { + if len(config) > 0 { + cfg = config[0] + } + if cfg.SuccessHandler == nil { + cfg.SuccessHandler = func(c *fiber.Ctx) error { + return c.Next() + } + } + if cfg.ErrorHandler == nil { + cfg.ErrorHandler = func(c *fiber.Ctx, err error) error { + if err.Error() == "Missing or malformed JWT" { + return c.Status(fiber.StatusBadRequest).SendString("Missing or malformed JWT") + } + return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired JWT") + } + } + if cfg.SigningKey == nil && len(cfg.SigningKeys) == 0 && cfg.KeySetURL == "" { + panic("Fiber: JWT middleware requires signing key or url where to download one") + } + if cfg.SigningMethod == "" && cfg.KeySetURL == "" { + cfg.SigningMethod = "HS256" + } + if cfg.ContextKey == "" { + cfg.ContextKey = "user" + } + if cfg.Claims == nil { + cfg.Claims = jwt.MapClaims{} + } + if cfg.TokenLookup == "" { + cfg.TokenLookup = defaultTokenLookup + } + if cfg.AuthScheme == "" { + cfg.AuthScheme = "Bearer" + } + if cfg.KeyRefreshTimeout == nil { + cfg.KeyRefreshTimeout = &defaultKeyRefreshTimeout + } + if cfg.KeySetURL != "" { + jwks := &KeySet{ + Config: &cfg, + } + cfg.keyFunc = jwks.keyFunc() + } else { + cfg.keyFunc = jwtKeyFunc(cfg) + } + return cfg +} + +// getExtractors function will create a slice of functions which will be used +// for token sarch and will perform extraction of the value +func (cfg *Config) getExtractors() []jwtExtractor { + // Initialize + extractors := make([]jwtExtractor, 0) + rootParts := strings.Split(cfg.TokenLookup, ",") + for _, rootPart := range rootParts { + parts := strings.Split(strings.TrimSpace(rootPart), ":") + + switch parts[0] { + case "header": + extractors = append(extractors, jwtFromHeader(parts[1], cfg.AuthScheme)) + case "query": + extractors = append(extractors, jwtFromQuery(parts[1])) + case "param": + extractors = append(extractors, jwtFromParam(parts[1])) + case "cookie": + extractors = append(extractors, jwtFromCookie(parts[1])) + } + } + return extractors +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..740da39 --- /dev/null +++ b/config_test.go @@ -0,0 +1,84 @@ +package jwtware + +import ( + "testing" +) + +func TestPanicOnMissingConfiguration(t *testing.T) { + t.Parallel() + + defer func() { + // Assert + if err := recover(); err == nil { + t.Fatalf("Middleware should panic on missing configuration") + } + }() + + // Arrange + config := make([]Config, 0) + + // Act + makeCfg(config) +} + +func TestDefaultConfiguration(t *testing.T) { + t.Parallel() + + defer func() { + // Assert + if err := recover(); err != nil { + t.Fatalf("Middleware should not panic") + } + }() + + // Arrange + config := append(make([]Config, 0), Config{ + SigningKey: "", + }) + + // Act + cfg := makeCfg(config) + + // Assert + if cfg.SigningMethod != HS256 { + t.Fatalf("Default signing method should be 'HS256'") + } + if cfg.ContextKey != "user" { + t.Fatalf("Default context key should be 'user'") + } + if cfg.Claims == nil { + t.Fatalf("Default claims should not be 'nil'") + } + + if cfg.TokenLookup != defaultTokenLookup { + t.Fatalf("Default token lookup should be '%v'", defaultTokenLookup) + } + if cfg.AuthScheme != "Bearer" { + t.Fatalf("Default auth scheme should be 'Bearer'") + } +} + +func TestExtractorsInitialization(t *testing.T) { + t.Parallel() + + defer func() { + // Assert + if err := recover(); err != nil { + t.Fatalf("Middleware should not panic") + } + }() + + // Arrange + cfg := Config{ + SigningKey: "", + TokenLookup: defaultTokenLookup + ",query:token,param:token,cookie:token,something:something", + } + + // Act + extractors := cfg.getExtractors() + + // Assert + if len(extractors) != 4 { + t.Fatalf("Extractors should not be created for invalid lookups") + } +} diff --git a/crypto.go b/crypto.go new file mode 100644 index 0000000..0f58d0c --- /dev/null +++ b/crypto.go @@ -0,0 +1,165 @@ +package jwtware + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "fmt" + "math/big" +) + +const ( + // HS256 represents a public cryptography key generated by a 256 bit HMAC algorithm. + HS256 = "HS256" + + // HS384 represents a public cryptography key generated by a 384 bit HMAC algorithm. + HS384 = "HS384" + + // HS512 represents a public cryptography key generated by a 512 bit HMAC algorithm. + HS512 = "HS512" + + // ES256 represents a public cryptography key generated by a 256 bit ECDSA algorithm. + ES256 = "ES256" + + // ES384 represents a public cryptography key generated by a 384 bit ECDSA algorithm. + ES384 = "ES384" + + // ES512 represents a public cryptography key generated by a 512 bit ECDSA algorithm. + ES512 = "ES512" + + // P256 represents a cryptographic elliptical curve type. + P256 = "P-256" + + // P384 represents a cryptographic elliptical curve type. + P384 = "P-384" + + // P521 represents a cryptographic elliptical curve type. + P521 = "P-521" + + // RS256 represents a public cryptography key generated by a 256 bit RSA algorithm. + RS256 = "RS256" + + // RS384 represents a public cryptography key generated by a 384 bit RSA algorithm. + RS384 = "RS384" + + // RS512 represents a public cryptography key generated by a 512 bit RSA algorithm. + RS512 = "RS512" + + // PS256 represents a public cryptography key generated by a 256 bit RSA algorithm. + PS256 = "PS256" + + // PS384 represents a public cryptography key generated by a 384 bit RSA algorithm. + PS384 = "PS384" + + // PS512 represents a public cryptography key generated by a 512 bit RSA algorithm. + PS512 = "PS512" +) + +// getECDSA parses a JSONKey and turns it into an ECDSA public key. +func (j *rawJWK) getECDSA() (publicKey *ecdsa.PublicKey, err error) { + // Check if the key has already been computed. + if j.precomputed != nil { + var ok bool + if publicKey, ok = j.precomputed.(*ecdsa.PublicKey); ok { + return publicKey, nil + } + } + + // Confirm everything needed is present. + if j.X == "" || j.Y == "" || j.Curve == "" { + return nil, fmt.Errorf("%w: ecdsa", errMissingAssets) + } + + // Decode the X coordinate from Base64. + // + // According to RFC 7518, this is a Base64 URL unsigned integer. + // https://tools.ietf.org/html/rfc7518#section-6.3 + var xCoordinate []byte + if xCoordinate, err = base64.RawURLEncoding.DecodeString(j.X); err != nil { + return nil, err + } + + // Decode the Y coordinate from Base64. + var yCoordinate []byte + if yCoordinate, err = base64.RawURLEncoding.DecodeString(j.Y); err != nil { + return nil, err + } + + // Create the ECDSA public key. + publicKey = &ecdsa.PublicKey{} + + // Set the curve type. + var curve elliptic.Curve + switch j.Curve { + case P256: + curve = elliptic.P256() + case P384: + curve = elliptic.P384() + case P521: + curve = elliptic.P521() + } + publicKey.Curve = curve + + // Turn the X coordinate into *big.Int. + // + // According to RFC 7517, these numbers are in big-endian format. + // https://tools.ietf.org/html/rfc7517#appendix-A.1 + publicKey.X = big.NewInt(0).SetBytes(xCoordinate) + + // Turn the Y coordinate into a *big.Int. + publicKey.Y = big.NewInt(0).SetBytes(yCoordinate) + + // Keep the public key so it won't have to be computed every time. + j.precomputed = publicKey + + return publicKey, nil +} + +// getRSA parses a JSONKey and turns it into an RSA public key. +func (j *rawJWK) getRSA() (publicKey *rsa.PublicKey, err error) { + // Check if the key has already been computed. + if j.precomputed != nil { + var ok bool + if publicKey, ok = j.precomputed.(*rsa.PublicKey); ok { + return publicKey, nil + } + } + + // Confirm everything needed is present. + if j.Exponent == "" || j.Modulus == "" { + return nil, fmt.Errorf("%w: rsa", errMissingAssets) + } + + // Decode the exponent from Base64. + // + // According to RFC 7518, this is a Base64 URL unsigned integer. + // https://tools.ietf.org/html/rfc7518#section-6.3 + var exponent []byte + if exponent, err = base64.RawURLEncoding.DecodeString(j.Exponent); err != nil { + return nil, err + } + + // Decode the modulus from Base64. + var modulus []byte + if modulus, err = base64.RawURLEncoding.DecodeString(j.Modulus); err != nil { + return nil, err + } + + // Create the RSA public key. + publicKey = &rsa.PublicKey{} + + // Turn the exponent into an integer. + // + // According to RFC 7517, these numbers are in big-endian format. + // https://tools.ietf.org/html/rfc7517#appendix-A.1 + publicKey.E = int(big.NewInt(0).SetBytes(exponent).Uint64()) + + // Turn the modulus into a *big.Int. + publicKey.N = big.NewInt(0).SetBytes(modulus) + + // Keep the public key so it won't have to be computed every time. + j.precomputed = publicKey + + return publicKey, nil +} diff --git a/jwks.go b/jwks.go new file mode 100644 index 0000000..b952cdd --- /dev/null +++ b/jwks.go @@ -0,0 +1,336 @@ +package jwtware + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "sync" + "time" + + "github.com/golang-jwt/jwt/v4" +) + +var ( // ErrKID indicates that the JWT had an invalid kid. + errMissingKeySet = errors.New("not able to download JWKs") + + // errKID indicates that the JWT had an invalid kid. + errKID = errors.New("the JWT has an invalid kid") + + // errUnsupportedKeyType indicates the JWT key type is an unsupported type. + errUnsupportedKeyType = errors.New("the JWT key type is unsupported") + + // errKIDNotFound indicates that the given key ID was not found in the JWKs. + errKIDNotFound = errors.New("the given key ID was not found in the JWKs") + + // errMissingAssets indicates there are required assets missing to create a public key. + errMissingAssets = errors.New("required assets are missing to create a public key") +) + +// rawJWK represents a raw key inside a JWKs. +type rawJWK struct { + Curve string `json:"crv"` + Exponent string `json:"e"` + ID string `json:"kid"` + Modulus string `json:"n"` + X string `json:"x"` + Y string `json:"y"` + precomputed interface{} +} + +// rawJWKs represents a JWKs in JSON format. +type rawJWKs struct { + Keys []rawJWK `json:"keys"` +} + +// KeySet represents a JSON Web Key Set. +type KeySet struct { + Keys map[string]*rawJWK + Config *Config + cancel context.CancelFunc + client *http.Client + ctx context.Context + mux sync.RWMutex + refreshRequests chan context.CancelFunc +} + +// keyFunc is a compatibility function that matches the signature of github.com/dgrijalva/jwt-go's keyFunc function. +func (j *KeySet) keyFunc() jwt.Keyfunc { + return func(token *jwt.Token) (interface{}, error) { + if j.Keys == nil { + err := j.downloadKeySet() + if err != nil { + return nil, fmt.Errorf("%w: key set URL is not accessible", errMissingKeySet) + } + } + + // Get the kid from the token header. + kidInter, ok := token.Header["kid"] + if !ok { + return nil, fmt.Errorf("%w: could not find kid in JWT header", errKID) + } + kid, ok := kidInter.(string) + if !ok { + return nil, fmt.Errorf("%w: could not convert kid in JWT header to string", errKID) + } + + // Get the JSONKey. + jsonKey, err := j.getKey(kid) + if err != nil { + return nil, err + } + + // Determine the key's algorithm and return the appropriate public key. + switch keyAlg := token.Header["alg"]; keyAlg { + case ES256, ES384, ES512: + return jsonKey.getECDSA() + case PS256, PS384, PS512, RS256, RS384, RS512: + return jsonKey.getRSA() + default: + return nil, fmt.Errorf("%w: %s: feel free to add a feature request or contribute to https://github.com/MicahParks/keyfunc", errUnsupportedKeyType, keyAlg) + } + } +} + +// downloadKeySet loads the JWKs at the given URL. +func (j *KeySet) downloadKeySet() (err error) { + // Apply some defaults if options were not provided. + if j.client == nil { + j.client = http.DefaultClient + } + + // Get the keys for the JWKs. + if err = j.refresh(); err != nil { + return err + } + + // Check to see if a background refresh of the JWKs should happen. + if j.Config.KeyRefreshInterval != nil || j.Config.KeyRefreshRateLimit != nil { + // Attach a context used to end the background goroutine. + j.ctx, j.cancel = context.WithCancel(context.Background()) + + // Create a channel that will accept requests to refresh the JWKs. + j.refreshRequests = make(chan context.CancelFunc, 1) + + // Start the background goroutine for data refresh. + go j.startRefreshing() + } + + return nil +} + +// New creates a new JWKs from a raw JSON message. +func parseKeySet(jwksBytes json.RawMessage) (keys map[string]*rawJWK, err error) { + // Turn the raw JWKs into the correct Go type. + var rawKS rawJWKs + if err = json.Unmarshal(jwksBytes, &rawKS); err != nil { + return nil, err + } + + // Iterate through the keys in the raw JWKs. Add them to the JWKs. + keys = make(map[string]*rawJWK, len(rawKS.Keys)) + for _, key := range rawKS.Keys { + key := key + keys[key.ID] = &key + } + + return keys, nil +} + +// getKey gets the JSONKey from the given KID from the JWKs. It may refresh the JWKs if configured to. +func (j *KeySet) getKey(kid string) (jsonKey *rawJWK, err error) { + // Get the JSONKey from the JWKs. + var ok bool + j.mux.RLock() + jsonKey, ok = j.Keys[kid] + j.mux.RUnlock() + + // Check if the key was present. + if !ok { + // Check to see if configured to refresh on unknown kid. + if *j.Config.KeyRefreshUnknownKID { + // Create a context for refreshing the JWKs. + ctx, cancel := context.WithCancel(j.ctx) + + // Refresh the JWKs. + select { + case <-j.ctx.Done(): + return + case j.refreshRequests <- cancel: + default: + + // If the j.refreshRequests channel is full, return the error early. + return nil, errKIDNotFound + } + + // Wait for the JWKs refresh to done. + <-ctx.Done() + + // Lock the JWKs for async safe use. + j.mux.RLock() + defer j.mux.RUnlock() + + // Check if the JWKs refresh contained the requested key. + if jsonKey, ok = j.Keys[kid]; ok { + return jsonKey, nil + } + } + + return nil, errKIDNotFound + } + + return jsonKey, nil +} + +// startRefreshing is meant to be a separate goroutine that will update the keys in a JWKs over a given interval of +// time. +func (j *KeySet) startRefreshing() { + // Create some rate limiting assets. + var lastRefresh time.Time + var queueOnce sync.Once + var refreshMux sync.Mutex + if j.Config.KeyRefreshRateLimit != nil { + lastRefresh = time.Now().Add(-*j.Config.KeyRefreshRateLimit) + } + + // Create a channel that will never send anything unless there is a refresh interval. + refreshInterval := make(<-chan time.Time) + + // Enter an infinite loop that ends when the background ends. + for { + // If there is a refresh interval, create the channel for it. + if j.Config.KeyRefreshInterval != nil { + refreshInterval = time.After(*j.Config.KeyRefreshInterval) + } + + // Wait for a refresh to occur or the background to end. + select { + + // Send a refresh request the JWKs after the given interval. + case <-refreshInterval: + select { + case <-j.ctx.Done(): + return + case j.refreshRequests <- func() {}: + default: // If the j.refreshRequests channel is full, don't don't send another request. + } + + // Accept refresh requests. + case cancel := <-j.refreshRequests: + // Rate limit, if needed. + refreshMux.Lock() + if j.Config.KeyRefreshRateLimit != nil && lastRefresh.Add(*j.Config.KeyRefreshRateLimit).After(time.Now()) { + // Don't make the JWT parsing goroutine wait for the JWKs to refresh. + cancel() + + // Only queue a refresh once. + queueOnce.Do(func() { + + // Launch a goroutine that will get a reservation for a JWKs refresh or fail to and immediately return. + go func() { + // Wait for the next time to refresh. + refreshMux.Lock() + wait := time.Until(lastRefresh.Add(*j.Config.KeyRefreshRateLimit)) + refreshMux.Unlock() + select { + case <-j.ctx.Done(): + return + case <-time.After(wait): + } + + // Refresh the JWKs. + refreshMux.Lock() + defer refreshMux.Unlock() + if err := j.refresh(); err != nil && j.Config.KeyRefreshErrorHandler != nil { + j.Config.KeyRefreshErrorHandler(j, err) + } else if err == nil && j.Config.KeyRefreshSuccessHandler != nil { + j.Config.KeyRefreshSuccessHandler(j) + } + + // Reset the last time for the refresh to now. + lastRefresh = time.Now() + + // Allow another queue. + queueOnce = sync.Once{} + }() + }) + } else { + // Refresh the JWKs. + if err := j.refresh(); err != nil && j.Config.KeyRefreshErrorHandler != nil { + j.Config.KeyRefreshErrorHandler(j, err) + } else if err == nil && j.Config.KeyRefreshSuccessHandler != nil { + j.Config.KeyRefreshSuccessHandler(j) + } + + // Reset the last time for the refresh to now. + lastRefresh = time.Now() + + // Allow the JWT parsing goroutine to continue with the refreshed JWKs. + cancel() + } + refreshMux.Unlock() + + // Clean up this goroutine when its context expires. + case <-j.ctx.Done(): + return + } + } +} + +// refresh does an HTTP GET on the JWKs URL to rebuild the JWKs. +func (j *KeySet) refresh() (err error) { + // Create a context for the request. + var ctx context.Context + var cancel context.CancelFunc + if j.ctx != nil { + ctx, cancel = context.WithTimeout(j.ctx, *j.Config.KeyRefreshTimeout) + } else { + ctx, cancel = context.WithTimeout(context.Background(), *j.Config.KeyRefreshTimeout) + } + defer cancel() + + // Create the HTTP request. + var req *http.Request + if req, err = http.NewRequestWithContext(ctx, http.MethodGet, j.Config.KeySetURL, bytes.NewReader(nil)); err != nil { + return err + } + + // Get the JWKs as JSON from the given URL. + var resp *http.Response + if resp, err = j.client.Do(req); err != nil { + return err + } + defer resp.Body.Close() // Ignore any error. + + // Read the raw JWKs from the body of the response. + var jwksBytes []byte + if jwksBytes, err = ioutil.ReadAll(resp.Body); err != nil { + return err + } + + // Create an updated JWKs. + var keys map[string]*rawJWK + if keys, err = parseKeySet(jwksBytes); err != nil { + return err + } + + // Lock the JWKs for async safe usage. + j.mux.Lock() + defer j.mux.Unlock() + + // Update the keys. + j.Keys = keys + + return nil +} + +// StopRefreshing ends the background goroutine to update the JWKs. It can only happen once and is only effective if the +// JWKs has a background goroutine refreshing the JWKs keys. +func (j *KeySet) StopRefreshing() { + if j.cancel != nil { + j.cancel() + } +} diff --git a/jwt.go b/jwt.go index a47c71f..93f31a0 100644 --- a/jwt.go +++ b/jwt.go @@ -1,178 +1,32 @@ -// 🚀 Fiber is an Express inspired web framework written in Go with 💖 -// 📌 API Documentation: https://fiber.wiki -// 📝 Github Repository: https://github.com/gofiber/fiber -// Special thanks to Echo: https://github.com/labstack/echo/blob/master/middleware/jwt.go - package jwtware import ( "errors" "fmt" - "reflect" "strings" "github.com/gofiber/fiber/v2" "github.com/golang-jwt/jwt/v4" ) -// Config defines the config for BasicAuth middleware -type Config struct { - // Filter defines a function to skip middleware. - // Optional. Default: nil - Filter func(*fiber.Ctx) bool - - // SuccessHandler defines a function which is executed for a valid token. - // Optional. Default: nil - SuccessHandler fiber.Handler - - // ErrorHandler defines a function which is executed for an invalid token. - // It may be used to define a custom JWT error. - // Optional. Default: 401 Invalid or expired JWT - ErrorHandler fiber.ErrorHandler - - // Signing key to validate token. Used as fallback if SigningKeys has length 0. - // Required. This or SigningKeys. - SigningKey interface{} - - // Map of signing keys to validate token with kid field usage. - // Required. This or SigningKey. - SigningKeys map[string]interface{} - - // Signing method, used to check token signing method. - // Optional. Default: "HS256". - // Possible values: "HS256", "HS384", "HS512", "ES256", "ES384", "ES512", "RS256", "RS384", "RS512" - SigningMethod string - - // Context key to store user information from the token into context. - // Optional. Default: "user". - ContextKey string - - // Claims are extendable claims data defining token content. - // Optional. Default value jwt.MapClaims - Claims jwt.Claims - - // TokenLookup is a string in the form of ":" that is used - // to extract token from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" - // - "query:" - // - "param:" - // - "cookie:" - TokenLookup string - - // AuthScheme to be used in the Authorization header. - // Optional. Default: "Bearer". - AuthScheme string +type jwtExtractor func(c *fiber.Ctx) (string, error) - keyFunc jwt.Keyfunc -} - -// New ... -func New(config ...Config) fiber.Handler { - // Init config - var cfg Config - if len(config) > 0 { - cfg = config[0] - } - if cfg.SuccessHandler == nil { - cfg.SuccessHandler = func(c *fiber.Ctx) error { - return c.Next() - } - } - if cfg.ErrorHandler == nil { - cfg.ErrorHandler = func(c *fiber.Ctx, err error) error { - if err.Error() == "Missing or malformed JWT" { - return c.Status(fiber.StatusBadRequest).SendString("Missing or malformed JWT") - } else { - return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired JWT") - } - } - } - if cfg.SigningKey == nil && len(cfg.SigningKeys) == 0 { - panic("Fiber: JWT middleware requires signing key") - } - if cfg.SigningMethod == "" { - cfg.SigningMethod = "HS256" - } - if cfg.ContextKey == "" { - cfg.ContextKey = "user" - } - if cfg.Claims == nil { - cfg.Claims = jwt.MapClaims{} - } - if cfg.TokenLookup == "" { - cfg.TokenLookup = "header:" + fiber.HeaderAuthorization - } - if cfg.AuthScheme == "" { - cfg.AuthScheme = "Bearer" - } - cfg.keyFunc = func(t *jwt.Token) (interface{}, error) { +// jwtKeyFunc returns a function that returns signing key for given token. +func jwtKeyFunc(config Config) jwt.Keyfunc { + return func(t *jwt.Token) (interface{}, error) { // Check the signing method - if t.Method.Alg() != cfg.SigningMethod { + if t.Method.Alg() != config.SigningMethod { return nil, fmt.Errorf("Unexpected jwt signing method=%v", t.Header["alg"]) } - if len(cfg.SigningKeys) > 0 { + if len(config.SigningKeys) > 0 { if kid, ok := t.Header["kid"].(string); ok { - if key, ok := cfg.SigningKeys[kid]; ok { + if key, ok := config.SigningKeys[kid]; ok { return key, nil } } return nil, fmt.Errorf("Unexpected jwt key id=%v", t.Header["kid"]) } - return cfg.SigningKey, nil - } - // Initialize - extractors := make([]func(c *fiber.Ctx) (string, error), 0) - rootParts := strings.Split(cfg.TokenLookup, ",") - for _, rootPart := range rootParts { - parts := strings.Split(strings.TrimSpace(rootPart), ":") - - switch parts[0] { - case "header": - extractors = append(extractors, jwtFromHeader(parts[1], cfg.AuthScheme)) - case "query": - extractors = append(extractors, jwtFromQuery(parts[1])) - case "param": - extractors = append(extractors, jwtFromParam(parts[1])) - case "cookie": - extractors = append(extractors, jwtFromCookie(parts[1])) - } - } - // Return middleware handler - return func(c *fiber.Ctx) error { - // Filter request to skip middleware - if cfg.Filter != nil && cfg.Filter(c) { - return c.Next() - } - var auth string - var err error - - for _, extractor := range extractors { - auth, err = extractor(c) - if auth != "" && err == nil { - break - } - } - - if err != nil { - return cfg.ErrorHandler(c, err) - } - token := new(jwt.Token) - - if _, ok := cfg.Claims.(jwt.MapClaims); ok { - token, err = jwt.Parse(auth, cfg.keyFunc) - } else { - t := reflect.ValueOf(cfg.Claims).Type().Elem() - claims := reflect.New(t).Interface().(jwt.Claims) - token, err = jwt.ParseWithClaims(auth, claims, cfg.keyFunc) - } - if err == nil && token.Valid { - // Store user information from token into context. - c.Locals(cfg.ContextKey, token) - return cfg.SuccessHandler(c) - } - return cfg.ErrorHandler(c, err) + return config.SigningKey, nil } } diff --git a/main.go b/main.go new file mode 100644 index 0000000..94b385a --- /dev/null +++ b/main.go @@ -0,0 +1,64 @@ +// 🚀 Fiber is an Express inspired web framework written in Go with 💖 +// 📌 API Documentation: https://fiber.wiki +// 📝 Github Repository: https://github.com/gofiber/fiber +// Special thanks to Echo: https://github.com/labstack/echo/blob/master/middleware/jwt.go + +package jwtware + +import ( + "reflect" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/golang-jwt/jwt/v4" +) + +var ( + // defaultRefreshTimeout is the default duration for the context used to create the HTTP request for a refresh of + // the JWKs. + defaultKeyRefreshTimeout = time.Minute + + defaultTokenLookup = "header:" + fiber.HeaderAuthorization +) + +// New ... +func New(config ...Config) fiber.Handler { + cfg := makeCfg(config) + + extractors := cfg.getExtractors() + + // Return middleware handler + return func(c *fiber.Ctx) error { + // Filter request to skip middleware + if cfg.Filter != nil && cfg.Filter(c) { + return c.Next() + } + var auth string + var err error + + for _, extractor := range extractors { + auth, err = extractor(c) + if auth != "" && err == nil { + break + } + } + if err != nil { + return cfg.ErrorHandler(c, err) + } + var token *jwt.Token + + if _, ok := cfg.Claims.(jwt.MapClaims); ok { + token, err = jwt.Parse(auth, cfg.keyFunc) + } else { + t := reflect.ValueOf(cfg.Claims).Type().Elem() + claims := reflect.New(t).Interface().(jwt.Claims) + token, err = jwt.ParseWithClaims(auth, claims, cfg.keyFunc) + } + if err == nil && token.Valid { + // Store user information from token into context. + c.Locals(cfg.ContextKey, token) + return cfg.SuccessHandler(c) + } + return cfg.ErrorHandler(c, err) + } +} diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..ecf233a --- /dev/null +++ b/main_test.go @@ -0,0 +1,233 @@ +package jwtware_test + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/utils" + + jwtware "github.com/gofiber/jwt/v3" +) + +type TestToken struct { + SigningMethod string + Token string +} + +var ( + hamac = []TestToken{ + { + SigningMethod: jwtware.HS256, + Token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.XbPfbIHMI6arZ3Y922BhjWgQzWXcXNrz0ogtVhfEd2o", + }, + { + SigningMethod: jwtware.HS384, + Token: "eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.hO2sthNQUSfvI9ylUdMKDxcrm8jB3KL6Rtkd3FOskL-jVqYh2CK1es8FKCQO8_tW", + }, + { + SigningMethod: jwtware.HS512, + Token: "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.wUVS6tazE2N98_J4SH_djkEe1igXPu0qILAvVXCiO6O20gdf5vZ2sYFWX3c-Hy6L4TD47b3DSAAO9XjSqpJfag", + }, + } + + rsa = []TestToken{ + { + SigningMethod: jwtware.RS256, + Token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImdvZmliZXItcnNhIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.gvWLzl1sYUXdYqAPqFYLEJYtqPce8YxrV6LPiyWX2147llj1YfquFySnC8KOUTykCAxZHe6tFkyyZOp35HOqV3P-jxW2rw05mpNhld79f-O2sAFEzV7qxJXuYi4TL-Qn1gaLWP7i9B6B9c-0xLzYUmtLdrmlM2pxfPkXwG0oSao", + }, + { + SigningMethod: jwtware.RS384, + Token: "eyJhbGciOiJSUzM4NCIsInR5cCI6IkpXVCIsImtpZCI6ImdvZmliZXItcnNhIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.IIFu5jNRT5fIe91we3ARLTpE8hGu4tK6gsWtrJ1lAWzCxUYsVE02yOi3ya9RJsh-37GN8LdfVw74ZQzr4dwuq8SorycVatA2bc_OfkWpioOoPCqGMBFgsEdue0qtL1taflA-YSNG-Qntpqx_ciCGfI1DhiqikLaL-LSe8H9YOWk", + }, + { + SigningMethod: jwtware.RS512, + Token: "eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCIsImtpZCI6ImdvZmliZXItcnNhIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.DKY-VXa6JJUZpupEUcmXETwaV2jfLydyeBfhSP8pIEW9g52fQ3g5hrHCNstxG2yy9yU68yrFqrBDetDX_yJ6qSHAOInwGWYot8W4D0lJvqsHJe0W0IPi03xiaWjwKO26xENCUzNNLvSPKPox5DPcg31gzCFBrIUgVX-TkpajuSE", + }, + } + + ecdsa = []TestToken{ + { + SigningMethod: jwtware.ES256, + Token: "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImdvZmliZXItcC0yNTYifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.n6iJptkq2i6Y6gbuc92f2ExT9oXbg7hdMlR5MvkCZjayxBAyfpIGGoQAjMriwEs4rjF5F-DSU8T6eUcDxNhonA", + }, + { + SigningMethod: jwtware.ES384, + Token: "eyJhbGciOiJFUzM4NCIsInR5cCI6IkpXVCIsImtpZCI6ImdvZmliZXItcC0zODQifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.WYGFC6NTSzD1E3Zv7Lyy3m_1l0zoF2tZqvDBxQBXqJN-bStTBzNYnpWZDMN6XMI7OqFbPGlh_Jff4Z4dlf0bieEfenURdtpoLIQI1zPNXoIfaY7TH8BTAXQKtoBk89Ed", + }, + { + SigningMethod: jwtware.ES512, + Token: "eyJhbGciOiJFUzUxMiIsInR5cCI6IkpXVCIsImtpZCI6ImdvZmliZXItcC01MjEifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.ADwlteggILiCM_oCkxsyJTRK6BpQyH2FBQD_Tw_ph0vpLPRrpAkyh_CZIY9uZqqpb3J_eohscCzj5Vo9jrhP9DFRAdvLZCgehLj6N8P9aro2uy9jAl7kowxe0nEErv1SrD9qlyLWJh80jJVHRBVHXXysQ2WUD0KiRBq4x1p8jdEw5vHy", + }, + } +) + +const ( + defaultSigningKey = "secret" + defaultKeySet = ` +{ + "keys":[ + { + "e": "AQAB", + "kid": "gofiber-rsa", + "kty": "RSA", + "n": "2IPZysef6KVySrb_RPopuwWy1C7KRfE96zQ9jIRwPghlvs0yfj9VK4rqeYbuHp5k9ghbjm1Bn2LMLR-JzqYWbchxzVrV58ay4nRHYUSjyzdbNcG0J4W-NxHnVqK0UUOl59uikRDqGHh3eRen_jVO_B8lvhqM57HQhA-czHbsmeU" + }, + { + "crv": "P-256", + "kid": "gofiber-p-256", + "kty": "EC", + "x": "nLZJMz-8B6p2A1-owmTrCZqZx87_Y5soNPW74dQ8EDw", + "y": "RvuLyi0tS-Tcx35IMy6aL_ID0K-cJFXmkFR8t9XJ4pc" + }, + { + "crv": "P-384", + "kid": "gofiber-p-384", + "kty": "EC", + "x": "wvSt-v7az1qbz493ToTSvNcXgdIGqTtlcLzW7B1Ko3QWVgmtBYWQr_Q311_QX9DY", + "y": "DvvBgCVjsDyttGAF8cmTP5maV46PrxACZFLvC1OEiZh-Ul0obSGXqG2xu8ulINPy" + }, + { + "crv": "P-521", + "kid": "gofiber-p-521", + "kty": "EC", + "x": "AZhzdsnk9Dx5fLdPDnYJOI3ClkghbyFvpSq2ExzyPNgjZz_7iBUjyyLtr6QDn9BAaeFvSQFHvhZUylIQZ9wdIinq", + "y": "AC2Me0tRqydVv7d23_0xdjiDndGuk0XpSZL5jeDWQ1_Tuty28-pJrFx38QQmWnosC0lBEdOUjxq-71YP7e4TzRMR" + } + ] +} +` +) + +func TestJwtFromHeader(t *testing.T) { + t.Parallel() + + defer func() { + // Assert + if err := recover(); err != nil { + t.Fatalf("Middleware should not panic") + } + }() + + for _, test := range hamac { + // Arrange + app := fiber.New() + + app.Use(jwtware.New(jwtware.Config{ + SigningKey: []byte(defaultSigningKey), + SigningMethod: test.SigningMethod, + })) + + app.Get("/ok", func(c *fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest("GET", "/ok", nil) + req.Header.Add("Authorization", "Bearer "+test.Token) + + // Act + resp, err := app.Test(req) + + // Assert + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 200, resp.StatusCode) + } +} + +func TestJwtFromCookie(t *testing.T) { + t.Parallel() + + defer func() { + // Assert + if err := recover(); err != nil { + t.Fatalf("Middleware should not panic") + } + }() + + for _, test := range hamac { + // Arrange + app := fiber.New() + + app.Use(jwtware.New(jwtware.Config{ + SigningKey: []byte(defaultSigningKey), + SigningMethod: test.SigningMethod, + TokenLookup: "cookie:Token", + })) + + app.Get("/ok", func(c *fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest("GET", "/ok", nil) + cookie := &http.Cookie{ + Name: "Token", + Value: test.Token, + } + req.AddCookie(cookie) + + // Act + resp, err := app.Test(req) + + // Assert + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 200, resp.StatusCode) + } +} + +// TestJWKs performs a table test on the JWKs code. +func TestJwkFromServer(t *testing.T) { + // Could add a test with an invalid JWKs endpoint. + // Create a temporary directory to serve the JWKs from. + tempDir, err := ioutil.TempDir("", "*") + if err != nil { + t.Errorf("Failed to create a temporary directory.\nError:%s\n", err.Error()) + t.FailNow() + } + defer func() { + if err = os.RemoveAll(tempDir); err != nil { + t.Errorf("Failed to remove temporary directory.\nError:%s\n", err.Error()) + t.FailNow() + } + }() + + // Create the JWKs file path. + jwksFile := filepath.Join(tempDir, "jwks.json") + + // Write the empty JWKs. + if err = ioutil.WriteFile(jwksFile, []byte(defaultKeySet), 0600); err != nil { + t.Errorf("Failed to write JWKs file to temporary directory.\nError:%s\n", err.Error()) + t.FailNow() + } + + // Create the HTTP test server. + server := httptest.NewServer(http.FileServer(http.Dir(tempDir))) + defer server.Close() + + // Iterate through the test cases. + for _, test := range append(rsa, ecdsa...) { + // Arrange + app := fiber.New() + + app.Use(jwtware.New(jwtware.Config{ + KeySetURL: server.URL + "/jwks.json", + })) + + app.Get("/ok", func(c *fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest("GET", "/ok", nil) + req.Header.Add("Authorization", "Bearer "+test.Token) + + // Act + resp, err := app.Test(req) + + // Assert + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 200, resp.StatusCode) + } +}