Skip to content

Commit

Permalink
Enhance token checking (woodpecker-ci#3842)
Browse files Browse the repository at this point in the history
(cherry picked from commit b8b6efb)
  • Loading branch information
anbraten authored and 0x1def committed Jun 27, 2024
1 parent f845420 commit e6e9d57
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 31 deletions.
2 changes: 1 addition & 1 deletion server/api/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func PostHook(c *gin.Context) {
//

// get the token and verify the hook is authorized
parsedToken, err := token.ParseRequest(c.Request, func(_ *token.Token) (string, error) {
parsedToken, err := token.ParseRequest([]token.Type{token.HookToken}, c.Request, func(_ *token.Token) (string, error) {
return repo.Hash, nil
})
if err != nil {
Expand Down
13 changes: 5 additions & 8 deletions server/router/middleware/session/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,14 @@ func AuthorizeAgent(c *gin.Context) {
return
}

parsed, err := token.ParseRequest(c.Request, func(_ *token.Token) (string, error) {
_, err := token.ParseRequest([]token.Type{token.AgentToken}, c.Request, func(_ *token.Token) (string, error) {
return secret, nil
})
switch {
case err != nil:
if err != nil {
c.String(http.StatusInternalServerError, "invalid or empty token. %s", err)
c.Abort()
case parsed.Kind != token.AgentToken:
c.String(http.StatusForbidden, "invalid token. please use an agent token")
c.Abort()
default:
c.Next()
return
}

c.Next()
}
4 changes: 2 additions & 2 deletions server/router/middleware/session/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func SetUser() gin.HandlerFunc {
return func(c *gin.Context) {
var user *model.User

t, err := token.ParseRequest(c.Request, func(t *token.Token) (string, error) {
t, err := token.ParseRequest([]token.Type{token.UserToken, token.SessToken}, c.Request, func(t *token.Token) (string, error) {
var err error
userID, err := strconv.ParseInt(t.Get("user-id"), 10, 64)
if err != nil {
Expand All @@ -58,7 +58,7 @@ func SetUser() gin.HandlerFunc {
// if this is a session token (ie not the API token)
// this means the user is accessing with a web browser,
// so we should implement CSRF protection measures.
if t.Kind == token.SessToken {
if t.Type == token.SessToken {
err = token.CheckCsrf(c.Request, func(_ *token.Token) (string, error) {
return user.Hash, nil
})
Expand Down
56 changes: 36 additions & 20 deletions shared/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,36 +24,52 @@ import (

type SecretFunc func(*Token) (string, error)

type Type string

const (
UserToken = "user"
SessToken = "sess"
HookToken = "hook"
CsrfToken = "csrf"
AgentToken = "agent"
UserToken Type = "user" // user token (exp cli)
SessToken Type = "sess" // session token (ui token requires csrf check)
HookToken Type = "hook" // repo hook token
CsrfToken Type = "csrf"
AgentToken Type = "agent"
)

// SignerAlgo id default algorithm used to sign JWT tokens.
const SignerAlgo = "HS256"

type Token struct {
Kind string
Type Type
claims jwt.MapClaims
}

func parse(raw string, fn SecretFunc) (*Token, error) {
func Parse(allowedTypes []Type, raw string, fn SecretFunc) (*Token, error) {
token := &Token{
claims: jwt.MapClaims{},
}
parsed, err := jwt.Parse(raw, keyFunc(token, fn))
if err != nil {
return nil, err
} else if !parsed.Valid {
}
if !parsed.Valid {
return nil, jwt.ErrTokenUnverifiable
}

hasAllowedType := false
for _, k := range allowedTypes {
if k == token.Type {
hasAllowedType = true
break
}
}

if !hasAllowedType {
return nil, jwt.ErrInvalidType
}

return token, nil
}

func ParseRequest(r *http.Request, fn SecretFunc) (*Token, error) {
func ParseRequest(allowedTypes []Type, r *http.Request, fn SecretFunc) (*Token, error) {
// first we attempt to get the token from the
// authorization header.
token := r.Header.Get("Authorization")
Expand All @@ -63,19 +79,19 @@ func ParseRequest(r *http.Request, fn SecretFunc) (*Token, error) {
if _, err := fmt.Sscanf(token, "Bearer %s", &bearer); err != nil {
return nil, err
}
return parse(bearer, fn)
return Parse(allowedTypes, bearer, fn)
}

token = r.Header.Get("X-Gitlab-Token")
if len(token) != 0 {
return parse(token, fn)
return Parse(allowedTypes, token, fn)
}

// then we attempt to get the token from the
// access_token url query parameter
token = r.FormValue("access_token")
if len(token) != 0 {
return parse(token, fn)
return Parse(allowedTypes, token, fn)
}

// and finally we attempt to get the token from
Expand All @@ -84,7 +100,7 @@ func ParseRequest(r *http.Request, fn SecretFunc) (*Token, error) {
if err != nil {
return nil, err
}
return parse(cookie.Value, fn)
return Parse(allowedTypes, cookie.Value, fn)
}

func CheckCsrf(r *http.Request, fn SecretFunc) error {
Expand All @@ -97,12 +113,12 @@ func CheckCsrf(r *http.Request, fn SecretFunc) error {

// parse the raw CSRF token value and validate
raw := r.Header.Get("X-CSRF-TOKEN")
_, err := parse(raw, fn)
_, err := Parse([]Type{CsrfToken}, raw, fn)
return err
}

func New(kind string) *Token {
return &Token{Kind: kind, claims: jwt.MapClaims{}}
func New(tokenType Type) *Token {
return &Token{Type: tokenType, claims: jwt.MapClaims{}}
}

// Sign signs the token using the given secret hash
Expand All @@ -124,7 +140,7 @@ func (t *Token) SignExpires(secret string, exp int64) (string, error) {
claims[k] = v
}

claims["type"] = t.Kind
claims["type"] = t.Type
if exp > 0 {
claims["exp"] = float64(exp)
}
Expand Down Expand Up @@ -157,12 +173,12 @@ func keyFunc(token *Token, fn SecretFunc) jwt.Keyfunc {
return nil, jwt.ErrSignatureInvalid
}

// extract the token kind and cast to the expected type
kind, ok := claims["type"]
// extract the token type and cast to the expected type
tokenType, ok := claims["type"].(string)
if !ok {
return nil, jwt.ErrInvalidType
}
token.Kind, _ = kind.(string)
token.Type = Type(tokenType)

// copy custom claims
for k, v := range claims {
Expand Down
62 changes: 62 additions & 0 deletions shared/token/token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package token_test

import (
"testing"

"github.com/franela/goblin"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"

"go.woodpecker-ci.org/woodpecker/v2/shared/token"
)

func TestToken(t *testing.T) {
gin.SetMode(gin.TestMode)

g := goblin.Goblin(t)
g.Describe("Token", func() {
jwtSecret := "secret-to-sign-the-token"

g.It("should parse a valid token", func() {
_token := token.New(token.UserToken)
_token.Set("user-id", "1")
signedToken, err := _token.Sign(jwtSecret)
assert.NoError(g, err)

parsed, err := token.Parse([]token.Type{token.UserToken}, signedToken, func(_ *token.Token) (string, error) {
return jwtSecret, nil
})

assert.NoError(g, err)
assert.NotNil(g, parsed)
assert.Equal(g, "1", parsed.Get("user-id"))
})

g.It("should fail to parse a token with a wrong type", func() {
_token := token.New(token.UserToken)
_token.Set("user-id", "1")
signedToken, err := _token.Sign(jwtSecret)
assert.NoError(g, err)

_, err = token.Parse([]token.Type{token.AgentToken}, signedToken, func(_ *token.Token) (string, error) {
return jwtSecret, nil
})

assert.ErrorIs(g, err, jwt.ErrInvalidType)
})

g.It("should fail to parse a token with a wrong secret", func() {
_token := token.New(token.UserToken)
_token.Set("user-id", "1")
signedToken, err := _token.Sign(jwtSecret)
assert.NoError(g, err)

_, err = token.Parse([]token.Type{token.UserToken}, signedToken, func(_ *token.Token) (string, error) {
return "this-is-a-wrong-secret", nil
})

assert.ErrorIs(g, err, jwt.ErrSignatureInvalid)
})
})
}

0 comments on commit e6e9d57

Please sign in to comment.