Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP - Add nonce validation in PoP token verifier #367

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion auth/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func New(ctx context.Context, opts Options) (auth.Interface, error) {

c.verifier = provider.Verifier(&oidc.Config{SkipClientIDCheck: !opts.VerifyClientID, ClientID: opts.ClientID})
if opts.EnablePOP {
c.popTokenVerifier = NewPoPVerifier(c.POPTokenHostname, c.PoPTokenValidityDuration)
c.popTokenVerifier = NewPoPVerifier(c.POPTokenHostname, c.PoPTokenValidityDuration, 1*time.Minute)
}

switch opts.AuthMode {
Expand Down
57 changes: 56 additions & 1 deletion auth/providers/azure/pop_tokenverifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,56 @@ import (
"math/big"
"strconv"
"strings"
"sync"
"time"

"github.com/pkg/errors"
"gopkg.in/square/go-jose.v2/jwt"
"k8s.io/klog/v2"
)

// create a cache to save nonce claim to make sure the nonce is not reused
var nonceMap = nonceCache{v: make(map[string]time.Time)}

type nonceCache struct {
mu sync.Mutex
v map[string]time.Time
}

func (c *nonceCache) AddToCache(key string) error {
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.v[key]; ok {
return errors.Errorf("nonce claim already exists")
}
c.v[key] = time.Now()
return nil
}

func (c *nonceCache) RemoveFromCache(key string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.v, key)
}

func (c *nonceCache) GetCounter() int {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.v)
}

// PopTokenVerifier is validator for PoP tokens.
type PoPTokenVerifier struct {
hostName string
PoPTokenValidityDuration time.Duration
mapCacheRetentionBuffer time.Duration
julienstroheker marked this conversation as resolved.
Show resolved Hide resolved
}

func NewPoPVerifier(hostName string, popTokenValidityDuration time.Duration) *PoPTokenVerifier {
func NewPoPVerifier(hostName string, popTokenValidityDuration, cacheRetentionBuffer time.Duration) *PoPTokenVerifier {
return &PoPTokenVerifier{
PoPTokenValidityDuration: popTokenValidityDuration,
hostName: hostName,
mapCacheRetentionBuffer: cacheRetentionBuffer,
}
}

Expand Down Expand Up @@ -210,6 +243,28 @@ func (p *PoPTokenVerifier) ValidatePopToken(token string) (string, error) {
return "", errors.Errorf("RSA verify err: %+v", err)
}

// Verify host 'nonce' claim
var nonce string
if nonceClaim, ok := claims["nonce"]; ok {
if _, ok := nonceClaim.(string); !ok {
return "", errors.Errorf("Invalid token. 'nonce' claim should be of type string")
}
nonce = nonceClaim.(string)
} else {
return "", errors.Errorf("Invalid token. 'nonce' claim is missing")
}
// Making sure nonce is not reused
err = nonceMap.AddToCache(nonce)
if err != nil {
return "", errors.Errorf("Invalid token. 'nonce' claim is reused")
}
klog.V(6).Infof("nonce claim added to the cache. Cache size is: %d", nonceMap.GetCounter())
// Cleaning cached nonce token after PoPTokenValidityDuration minutes + 1 minute by default
go func() {
julienstroheker marked this conversation as resolved.
Show resolved Hide resolved
time.Sleep(p.PoPTokenValidityDuration + p.mapCacheRetentionBuffer)
nonceMap.RemoveFromCache(nonce)
}()

return claims["at"].(string), nil
}

Expand Down
32 changes: 31 additions & 1 deletion auth/providers/azure/pop_tokenverifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
)

func TestPopTokenVerifier_Verify(t *testing.T) {
verifier := azure.NewPoPVerifier("testHostname", 15*time.Minute)
verifier := azure.NewPoPVerifier("testHostname", 15*time.Minute, 1*time.Minute)

// Test cases where no error is expected
noErrorTestCases := []struct {
Expand Down Expand Up @@ -158,6 +158,18 @@ func TestPopTokenVerifier_Verify(t *testing.T) {
hostname: "testHostname",
errString: "Invalid token. access token missing",
},
{
desc: "'nonce' claim in the payload is missing",
kid: azure.NonceClaimMissing,
hostname: "testHostname",
errString: "Invalid token. 'nonce' claim is missing",
},
{
desc: "'nonce' claim in the payload is not a string",
kid: azure.NonceClaimNotString,
hostname: "testHostname",
errString: "Invalid token. 'nonce' claim should be of type string",
},
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
Expand Down Expand Up @@ -219,4 +231,22 @@ func TestPopTokenVerifier_Verify(t *testing.T) {
_, err := verifier.ValidatePopToken(invalidToken)
assert.Containsf(t, err.Error(), "Token is expired", "Error message is not as expected")
})

t.Run("'nonce' claim been reused - Cache validation", func(t *testing.T) {
// Setting up the verifier expiration time to 4 seconds and cache expiration time to -3 seconds.
// This will ensure that the cache is expired before the verifier on the second call after 1sec.
nonceVerifier := azure.NewPoPVerifier("testHostname", 4*time.Second, -3*time.Second)
// Generating a first valid token with a nonce claim hard coded to be reused. This should pass the validation.
validToken, _ := azure.NewPoPTokenBuilder().SetTimestamp(time.Now().Unix()).SetHostName("testHostname").SetKid(azure.NonceClaimHardcoded).GetToken()
_, err := nonceVerifier.ValidatePopToken(validToken)
assert.NoError(t, err)
// Generating a second valid token with a nonce claim hard coded to be reused. This should fail the validation because the value is cached.
validToken, _ = azure.NewPoPTokenBuilder().SetTimestamp(time.Now().Unix()).SetHostName("testHostname").SetKid(azure.NonceClaimHardcoded).GetToken()
_, err = nonceVerifier.ValidatePopToken(validToken)
assert.Containsf(t, err.Error(), "Invalid token. 'nonce' claim is reused", "Error message is not as expected")
// Sleeping for 2 seconds to ensure that the cache is expired before the verifier is called again.
time.Sleep(2 * time.Second)
_, err = nonceVerifier.ValidatePopToken(validToken)
julienstroheker marked this conversation as resolved.
Show resolved Hide resolved
assert.NoError(t, err)
})
}
43 changes: 28 additions & 15 deletions auth/providers/azure/pop_tokenverifier_test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ const (
TsClaimsTypeUnknown = "tsClaimsTypeUnknown"
UClaimsWrongType = "uClaimsWrongType"
SignatureWrongType = "signatureWrongType"
NonceClaimMissing = "nonceMissing"
NonceClaimHardcoded = "nonceHardcoded"
NonceClaimNotString = "nonceNotString"
)

// A struct that represents a PoP token
Expand All @@ -173,6 +176,7 @@ type PoPTokenBuilderImpl struct {
hostName string
kid string // used for testing purposes
token PoPToken
nonce string
}

// A constructor function that returns a new PoPTokenBuilderImpl
Expand Down Expand Up @@ -231,8 +235,11 @@ func (b *PoPTokenBuilderImpl) SetPayload() error {
cnf = b.popKey.KeyID()
}

nonce := uuid.New().String()
nonce = strings.Replace(nonce, "-", "", -1)
if b.kid == NonceClaimHardcoded {
b.nonce = "hardcodedNonce"
} else {
b.nonce = strings.Replace(uuid.New().String(), "-", "", -1)
}

accessTokenData := fmt.Sprintf(popAccessToken, time.Now().Add(time.Minute*5).Unix(), cnf)
if b.kid == AtCnfClaimMissing {
Expand All @@ -247,42 +254,48 @@ func (b *PoPTokenBuilderImpl) SetPayload() error {
return fmt.Errorf("Error when generating token. Error:%+v", err)
}

payload := fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": "%s", "cnf":{"jwk":%s}, "nonce":"%s"}`, at, b.ts, b.hostName, b.popKey.Jwk(), nonce)
payload := fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": "%s", "cnf":{"jwk":%s}, "nonce":"%s"}`, at, b.ts, b.hostName, b.popKey.Jwk(), b.nonce)
if b.kid == TsClaimsMissing {
payload = fmt.Sprintf(`{ "at" : "%s", "u": "%d", "cnf":{"jwk":%s}, "nonce":"%s"}`, at, 1, b.popKey.Jwk(), nonce)
payload = fmt.Sprintf(`{ "at" : "%s", "u": "%d", "cnf":{"jwk":%s}, "nonce":"%s"}`, at, 1, b.popKey.Jwk(), b.nonce)
}
if b.kid == UClaimsMissing {
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "cnf":{"jwk":%s}, "nonce":"%s"}`, at, b.ts, b.popKey.Jwk(), nonce)
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "cnf":{"jwk":%s}, "nonce":"%s"}`, at, b.ts, b.popKey.Jwk(), b.nonce)
}
if b.kid == CnfJwkClaimsEmpty {
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": "%s", "cnf":{}, "nonce":"%s"}`, at, b.ts, b.hostName, nonce)
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": "%s", "cnf":{}, "nonce":"%s"}`, at, b.ts, b.hostName, b.nonce)
}
if b.kid == CnfJwkClaimsMissing {
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": "%s", "cnf":1, "nonce":"%s"}`, at, b.ts, b.hostName, nonce)
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": "%s", "cnf":1, "nonce":"%s"}`, at, b.ts, b.hostName, b.nonce)
}
if b.kid == CnfJwkClaimsWrong {
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": "%s", "cnf":{"jwk":1}, "nonce":"%s"}`, at, b.ts, b.hostName, nonce)
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": "%s", "cnf":{"jwk":1}, "nonce":"%s"}`, at, b.ts, b.hostName, b.nonce)
}
if b.kid == CnfClaimsMissing {
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": "%s", "nonce": "%s"}`, at, b.ts, b.hostName, nonce)
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": "%s", "nonce": "%s"}`, at, b.ts, b.hostName, b.nonce)
}
if b.kid == TsClaimsTypeString {
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : "%s", "u": "%s", "cnf":{"jwk":%s}, "nonce":"%s"}`, at, strconv.FormatInt(b.ts, 10), b.hostName, b.popKey.Jwk(), nonce)
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : "%s", "u": "%s", "cnf":{"jwk":%s}, "nonce":"%s"}`, at, strconv.FormatInt(b.ts, 10), b.hostName, b.popKey.Jwk(), b.nonce)
}
if b.kid == TsClaimsTypeUnknown {
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %t, "u": "%s", "cnf":{"jwk":%s}, "nonce":"%s"}`, at, bool(true), b.hostName, b.popKey.Jwk(), nonce)
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %t, "u": "%s", "cnf":{"jwk":%s}, "nonce":"%s"}`, at, bool(true), b.hostName, b.popKey.Jwk(), b.nonce)
}
if b.kid == AtClaimsWrongType {
payload = fmt.Sprintf(`{ "at" : %d, "ts" : %d, "u": "%s", "cnf":{"jwk":%s}, "nonce":"%s"}`, 12, b.ts, b.hostName, b.popKey.Jwk(), nonce)
payload = fmt.Sprintf(`{ "at" : %d, "ts" : %d, "u": "%s", "cnf":{"jwk":%s}, "nonce":"%s"}`, 12, b.ts, b.hostName, b.popKey.Jwk(), b.nonce)
}
if b.kid == UClaimsWrongType {
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": %d, "cnf":{"jwk":%s}, "nonce":"%s"}`, at, b.ts, 1, b.popKey.Jwk(), nonce)
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": %d, "cnf":{"jwk":%s}, "nonce":"%s"}`, at, b.ts, 1, b.popKey.Jwk(), b.nonce)
}
if b.kid == AtClaimsMissing {
payload = fmt.Sprintf(`{ "ts" : %d, "u": "%s", "cnf":{"jwk":%s}, "nonce":"%s"}`, b.ts, b.hostName, b.popKey.Jwk(), nonce)
payload = fmt.Sprintf(`{ "ts" : %d, "u": "%s", "cnf":{"jwk":%s}, "nonce":"%s"}`, b.ts, b.hostName, b.popKey.Jwk(), b.nonce)
}
if b.kid == AtClaimIncorrect {
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": "%s", "cnf":{"jwk":%s}, "nonce":"%s"}`, fmt.Sprintf("%s.%s.%s", BadTokenKey, BadTokenKey, BadTokenKey), b.ts, b.hostName, b.popKey.Jwk(), nonce)
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": "%s", "cnf":{"jwk":%s}, "nonce":"%s"}`, fmt.Sprintf("%s.%s.%s", BadTokenKey, BadTokenKey, BadTokenKey), b.ts, b.hostName, b.popKey.Jwk(), b.nonce)
}
if b.kid == NonceClaimMissing {
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": "%s", "cnf":{"jwk":%s}}`, at, b.ts, b.hostName, b.popKey.Jwk())
}
if b.kid == NonceClaimNotString {
payload = fmt.Sprintf(`{ "at" : "%s", "ts" : %d, "u": "%s", "cnf":{"jwk":%s}, "nonce":%d}`, at, b.ts, b.hostName, b.popKey.Jwk(), 1)
}
b.token.Payload = base64.RawURLEncoding.EncodeToString([]byte(payload))
return nil
Expand Down
Loading