Skip to content

Commit

Permalink
Add custom jwt extractor to jwt config
Browse files Browse the repository at this point in the history
  • Loading branch information
RashadAnsari authored and aldas committed Dec 20, 2021
1 parent 6b5e62b commit 4fffee2
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 21 deletions.
49 changes: 28 additions & 21 deletions middleware/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,14 @@ type (
// - "form:<name>"
// Multiply sources example:
// - "header: Authorization,cookie: myowncookie"

TokenLookup string

// TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context.
// This is one of the two options to provide a token extractor.
// The order of precedence is user-defined TokenLookupFuncs, and TokenLookup.
// You can also provide both if you want.
TokenLookupFuncs []TokenLookupFunc

// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
Expand Down Expand Up @@ -103,7 +108,8 @@ type (
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
JWTErrorHandlerWithContext func(error, echo.Context) error

jwtExtractor func(echo.Context) (string, error)
// TokenLookupFunc defines a function for extracting JWT token from the given context.
TokenLookupFunc func(echo.Context) (string, error)
)

// Algorithms
Expand All @@ -120,13 +126,14 @@ var (
var (
// DefaultJWTConfig is the default JWT auth middleware config.
DefaultJWTConfig = JWTConfig{
Skipper: DefaultSkipper,
SigningMethod: AlgorithmHS256,
ContextKey: "user",
TokenLookup: "header:" + echo.HeaderAuthorization,
AuthScheme: "Bearer",
Claims: jwt.MapClaims{},
KeyFunc: nil,
Skipper: DefaultSkipper,
SigningMethod: AlgorithmHS256,
ContextKey: "user",
TokenLookup: "header:" + echo.HeaderAuthorization,
TokenLookupFuncs: nil,
AuthScheme: "Bearer",
Claims: jwt.MapClaims{},
KeyFunc: nil,
}
)

Expand Down Expand Up @@ -163,7 +170,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.Claims == nil {
config.Claims = DefaultJWTConfig.Claims
}
if config.TokenLookup == "" {
if config.TokenLookup == "" && len(config.TokenLookupFuncs) == 0 {
config.TokenLookup = DefaultJWTConfig.TokenLookup
}
if config.AuthScheme == "" {
Expand All @@ -179,7 +186,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
// Initialize
// Split sources
sources := strings.Split(config.TokenLookup, ",")
var extractors []jwtExtractor
var extractors = config.TokenLookupFuncs
for _, source := range sources {
parts := strings.Split(source, ":")

Expand Down Expand Up @@ -290,8 +297,8 @@ func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
return config.SigningKey, nil
}

// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header.
func jwtFromHeader(header string, authScheme string) jwtExtractor {
// jwtFromHeader returns a `TokenLookupFunc` that extracts token from the request header.
func jwtFromHeader(header string, authScheme string) TokenLookupFunc {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
l := len(authScheme)
Expand All @@ -302,8 +309,8 @@ func jwtFromHeader(header string, authScheme string) jwtExtractor {
}
}

// jwtFromQuery returns a `jwtExtractor` that extracts token from the query string.
func jwtFromQuery(param string) jwtExtractor {
// jwtFromQuery returns a `TokenLookupFunc` that extracts token from the query string.
func jwtFromQuery(param string) TokenLookupFunc {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {
Expand All @@ -313,8 +320,8 @@ func jwtFromQuery(param string) jwtExtractor {
}
}

// jwtFromParam returns a `jwtExtractor` that extracts token from the url param string.
func jwtFromParam(param string) jwtExtractor {
// jwtFromParam returns a `TokenLookupFunc` that extracts token from the url param string.
func jwtFromParam(param string) TokenLookupFunc {
return func(c echo.Context) (string, error) {
token := c.Param(param)
if token == "" {
Expand All @@ -324,8 +331,8 @@ func jwtFromParam(param string) jwtExtractor {
}
}

// jwtFromCookie returns a `jwtExtractor` that extracts token from the named cookie.
func jwtFromCookie(name string) jwtExtractor {
// jwtFromCookie returns a `TokenLookupFunc` that extracts token from the named cookie.
func jwtFromCookie(name string) TokenLookupFunc {
return func(c echo.Context) (string, error) {
cookie, err := c.Cookie(name)
if err != nil {
Expand All @@ -335,8 +342,8 @@ func jwtFromCookie(name string) jwtExtractor {
}
}

// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
func jwtFromForm(name string) jwtExtractor {
// jwtFromForm returns a `TokenLookupFunc` that extracts token from the form field.
func jwtFromForm(name string) TokenLookupFunc {
return func(c echo.Context) (string, error) {
field := c.FormValue(name)
if field == "" {
Expand Down
24 changes: 24 additions & 0 deletions middleware/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -603,3 +603,27 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {

assert.Equal(t, http.StatusTeapot, res.Code)
}

func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
e := echo.New()

e.GET("/", func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})

e.Use(JWTWithConfig(JWTConfig{
TokenLookupFuncs: []TokenLookupFunc{
func(c echo.Context) (string, error) {
return c.Request().Header.Get("X-API-Key"), nil
},
},
SigningKey: []byte("secret"),
}))

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-API-Key", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)

assert.Equal(t, http.StatusOK, res.Code)
}

0 comments on commit 4fffee2

Please sign in to comment.