Skip to content

Commit

Permalink
Addressed CR comments
Browse files Browse the repository at this point in the history
    -- Tokens are now in WLCG format
    -- Added scope check
    -- Added comment clarifying issuer claim retrieval
    -- Parse federation jwks correctly (as json)
    -- Changed ES512 to ES256
  • Loading branch information
turetske committed Oct 5, 2023
1 parent 1312653 commit 209e3a0
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 48 deletions.
38 changes: 36 additions & 2 deletions web_ui/prometheus.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ func checkPromToken(av1 *route.Router) gin.HandlerFunc {
return
}

// Parsing the token (unverified) in order to get its issuer without having the jwks
token, err = jwt.Parse([]byte(strToken), jwt.WithVerify(false))

if err != nil {
Expand All @@ -189,11 +190,16 @@ func checkPromToken(av1 *route.Router) gin.HandlerFunc {
c.JSON(400, gin.H{"error": "Failed to read federation key file"})
return
}
key, err := jwk.ParseKey(contents, jwk.WithPEM(true))
keys, err := jwk.Parse(contents)
//key, err := jwk.ParseKey(contents, jwk.WithPEM(true))
if err != nil {
c.JSON(400, gin.H{"error": "Failed to parse Federation key file"})
return
}
key, ok := keys.Key(0)
if !ok {
c.JSON(400, gin.H{"error": "No key in keyset"})
}
bKey = &key
} else {
bKey, err = pelican_config.GetOriginJWK()
Expand All @@ -209,13 +215,41 @@ func checkPromToken(av1 *route.Router) gin.HandlerFunc {
return
}

_, err = jwt.Parse([]byte(strToken), jwt.WithKey(jwa.ES512, raw.PublicKey), jwt.WithValidate(true))
parsed, err := jwt.Parse([]byte(strToken), jwt.WithKey(jwa.ES256, raw.PublicKey))

if err != nil {
c.JSON(403, gin.H{"error": "Permission Denied: Invalid token"})
return
}

/*
* The signature is verified, now we need to make sure this token actually gives us
* permission to access prometheus metrics
* NOTE: The validate function also handles checking `iat` and `exp` to make sure the token
* remains valid.
*/
scopeValidator := jwt.ValidatorFunc(func(_ context.Context, tok jwt.Token) jwt.ValidationError {
scope_any, present := tok.Get("scope")
if !present {
return jwt.NewValidationError(errors.New("No scope is present; required for authorization"))
}
scope, ok := scope_any.(string)
if !ok {
return jwt.NewValidationError(errors.New("scope claim in token is not string-valued"))
}

for _, scope := range strings.Split(scope, " ") {
if scope == "prometheus.read" {
return nil
}
}
return jwt.NewValidationError(errors.New("Token does not contain prometheus access authorization"))
})
if err = jwt.Validate(parsed, jwt.WithValidator(scopeValidator)); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "server could not validate the provided access token"})
return
}

av1.ServeHTTP(c.Writer, c.Request)
}
}
Expand Down
235 changes: 189 additions & 46 deletions web_ui/prometheus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
Expand All @@ -23,21 +25,26 @@ import (
"github.com/stretchr/testify/assert"
)

func TestPrometheusProtection(t *testing.T) {
func TestPrometheusProtectionFederationURL(t *testing.T) {

/*
* Tests that prometheus metrics are behind the origin's and federation's token. Specifically it signs a token
* with the origin's key and invokes a prometheus GET endpoint with both URL and Header authorization, with the
* URL authorization, it mimics matching the Federation URL to ensure that check is done, but intercepts with
* returning the origin jwk for testing purposes.
* This then does so again with an invalid token and confirms that the correct error is returned
* Tests that prometheus metrics are behind federation's token. Specifically it signs a token
* with the a generated key o prometheus GET endpoint with both URL. It mimics matching the Federation URL
* to ensure that check is done, but intercepts with returning a generated jwk for testing purposes
*/

// Setup httptest recorder and context for the the unit test
viper.Reset()

av1 := route.New().WithPrefix("/api/v1.0/prometheus")

// Create temp dir for the origin key file
tDir := t.TempDir()
kfile := filepath.Join(tDir, "testKey")

//Setup a private key and a token
viper.Set("IssuerKey", kfile)

w := httptest.NewRecorder()
c, r := gin.CreateTestContext(w)
// Note, this handler function intercepts the "http.Get call to the federation uri
Expand All @@ -57,21 +64,27 @@ func TestPrometheusProtection(t *testing.T) {
URL: &url.URL{},
}

// Create temp dir for the origin key file
tDir := t.TempDir()
kfile := filepath.Join(tDir, "testKey")
privateKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
assert.NoError(t, err, "Error generating private key")

//Setup a private key and a token
viper.Set("IssuerKey", kfile)
// Convert from raw ecdsa to jwk.Key
pKey, err := jwk.FromRaw(privateKey)
assert.NoError(t, err, "Unable to convert ecdsa.PrivateKey to jwk.Key")

// Generate the origin private and public keys
_, err := config.LoadPublicKey("", kfile)
//Assign Key id to the private key
err = jwk.AssignKeyID(pKey)
assert.NoError(t, err, "Error assigning kid to private key")

//Set an algorithm for the key
err = pKey.Set(jwk.AlgorithmKey, jwa.ES512)
assert.NoError(t, err, "Unable to set algorithm for pKey")

buf, err := json.MarshalIndent(pKey, "", " ")
if err != nil {
t.Fatal(err)
}
err = os.WriteFile(kfile, buf, 0644)

privKey, err := config.LoadPrivateKey(kfile)
if err != nil {
t.Fatal(err)
}
Expand All @@ -80,20 +93,32 @@ func TestPrometheusProtection(t *testing.T) {
issuerURL := url.URL{}
issuerURL.Scheme = "https"
issuerURL.Host = "test-http"
now := time.Now()

jti_bytes := make([]byte, 16)
_, err = rand.Read(jti_bytes)
if err != nil {
t.Fatal(err)
}
jti := base64.RawURLEncoding.EncodeToString(jti_bytes)

originUrl := viper.GetString("OriginURL")
tok, err := jwt.NewBuilder().
Claim("scope", "prometheus.read").
Claim("wlcg.ver", "1.0").
JwtID(jti).
Issuer(issuerURL.String()).
IssuedAt(now).
Expiration(now.Add(30 * time.Minute)).
NotBefore(now).
Audience([]string{originUrl}).
Subject("sub").
Expiration(time.Now().Add(time.Minute)).
IssuedAt(time.Now()).
Build()

if err != nil {
t.Fatal(err)
}

// Sign the token with the origin private key
signed, err := jwt.Sign(tok, jwt.WithKey(jwa.ES512, privKey))
signed, err := jwt.Sign(tok, jwt.WithKey(jwa.ES256, pKey))

if err != nil {
t.Fatal(err)
Expand All @@ -115,31 +140,96 @@ func TestPrometheusProtection(t *testing.T) {
c.Request.URL.RawQuery = new_query.Encode()

r.ServeHTTP(w, c.Request)
}

// Check to see that the code exits with status code 404 after giving it a good token
assert.Equal(t, 404, w.Result().StatusCode, "Expected status code of 404 representing failure due to minimal server setup, not token check")
func TestPrometheusProtectionOriginHeaderScope(t *testing.T) {
/*
* Tests that the prometheus protections are behind the origin's token and tests that the token is accessable from
* the header function. It signs a token with the origin's jwks key and adds it to the header before attempting
* to access the prometheus metrics. It then attempts to access the metrics with a token with an invalid scope.
* It attempts to do so again with a token signed by a bad key. Both these are expected to fail.
*/

// Create a new Recorder and Context for the next HTTPtest call
wH := httptest.NewRecorder()
cH, rH := gin.CreateTestContext(wH)
viper.Reset()

av1 := route.New().WithPrefix("/api/v1.0/prometheus")

// Create temp dir for the origin key file
tDir := t.TempDir()
kfile := filepath.Join(tDir, "testKey")

//Setup a private key and a token
viper.Set("IssuerKey", kfile)

w := httptest.NewRecorder()
c, r := gin.CreateTestContext(w)

c.Request = &http.Request{
URL: &url.URL{},
}

// Generate the origin private and public keys
_, err := config.LoadPublicKey("", kfile)

if err != nil {
t.Fatal(err)
}

// Load the private key
privKey, err := config.LoadPrivateKey(kfile)
if err != nil {
t.Fatal(err)
}

// Create a token
issuerURL := url.URL{}
issuerURL.Scheme = "https"
issuerURL.Host = "test-http"

jti_bytes := make([]byte, 16)
_, err = rand.Read(jti_bytes)
if err != nil {
t.Fatal(err)
}
jti := base64.RawURLEncoding.EncodeToString(jti_bytes)

originUrl := viper.GetString("OriginURL")
tok, err := jwt.NewBuilder().
Claim("scope", "prometheus.read").
Claim("wlcg.ver", "1.0").
JwtID(jti).
Issuer(issuerURL.String()).
Audience([]string{originUrl}).
Subject("sub").
Expiration(time.Now().Add(time.Minute)).
IssuedAt(time.Now()).
Build()

if err != nil {
t.Fatal(err)
}

// Sign the token with the origin private key
signed, err := jwt.Sign(tok, jwt.WithKey(jwa.ES256, privKey))
if err != nil {
t.Fatal(err)
}

// Set the request to go through the checkPromToken function
rH.GET("/api/v1.0/prometheus/*any", checkPromToken(av1))
cH.Request, _ = http.NewRequest(http.MethodGet, "/api/v1.0/prometheus/test", bytes.NewBuffer([]byte(`{}`)))
r.GET("/api/v1.0/prometheus/*any", checkPromToken(av1))
c.Request, _ = http.NewRequest(http.MethodGet, "/api/v1.0/prometheus/test", bytes.NewBuffer([]byte(`{}`)))

// Put the signed token within the header
cH.Request.Header.Set("Authorization", "Bearer "+string(signed))
cH.Request.Header.Set("Content-Type", "application/json")
c.Request.Header.Set("Authorization", "Bearer "+string(signed))
c.Request.Header.Set("Content-Type", "application/json")

viper.Set("FederationURL", "")
r.ServeHTTP(w, c.Request)

rH.ServeHTTP(wH, cH.Request)
// Check to see that the code exits with status code 404 after given it a good token
assert.Equal(t, 404, wH.Result().StatusCode, "Expected status code of 404 representing failure due to minimal server setup, not token check")
assert.Equal(t, 404, w.Result().StatusCode, "Expected status code of 404 representing failure due to minimal server setup, not token check")

// Create a new Recorder and Context for testing an invalid token
wI := httptest.NewRecorder()
cI, rI := gin.CreateTestContext(wI)
// Create a new Recorder and Context for the next HTTPtest call
w = httptest.NewRecorder()
c, r = gin.CreateTestContext(w)

c.Request = &http.Request{
URL: &url.URL{},
Expand All @@ -158,30 +248,83 @@ func TestPrometheusProtection(t *testing.T) {
assert.NoError(t, err, "Error assigning kid to private key")

//Set an algorithm for the key
err = pKey.Set(jwk.AlgorithmKey, jwa.ES512)
err = pKey.Set(jwk.AlgorithmKey, jwa.ES256)
assert.NoError(t, err, "Unable to set algorithm for pKey")

jti_bytes = make([]byte, 16)
_, err = rand.Read(jti_bytes)
if err != nil {
t.Fatal(err)
}
jti = base64.RawURLEncoding.EncodeToString(jti_bytes)

// Create a new token to be used
tok, err = jwt.NewBuilder().
Claim("scope", "prometheus.read").
Claim("wlcg.ver", "1.0").
JwtID(jti).
Issuer(issuerURL.String()).
IssuedAt(now).
Expiration(now.Add(30 * time.Minute)).
NotBefore(now).
Audience([]string{originUrl}).
Subject("sub").
Expiration(time.Now().Add(time.Minute)).
IssuedAt(time.Now()).
Build()

assert.NoError(t, err, "Error creating token")

// Sign token with private key (not the origin)
signed, err = jwt.Sign(tok, jwt.WithKey(jwa.ES512, pKey))
signed, err = jwt.Sign(tok, jwt.WithKey(jwa.ES256, pKey))
assert.NoError(t, err, "Error signing token")

rI.GET("/api/v1.0/prometheus/*any", checkPromToken(av1))
cI.Request, _ = http.NewRequest(http.MethodGet, "/api/v1.0/prometheus/test", bytes.NewBuffer([]byte(`{}`)))
r.GET("/api/v1.0/prometheus/*any", checkPromToken(av1))
c.Request, _ = http.NewRequest(http.MethodGet, "/api/v1.0/prometheus/test", bytes.NewBuffer([]byte(`{}`)))

cI.Request.Header.Set("Authorization", "Bearer "+string(signed))
cI.Request.Header.Set("Content-Type", "application/json")
c.Request.Header.Set("Authorization", "Bearer "+string(signed))
c.Request.Header.Set("Content-Type", "application/json")

rI.ServeHTTP(wI, cI.Request)
r.ServeHTTP(w, c.Request)
// Assert that it gets the correct Permission Denied 403 code
assert.Equal(t, 403, wI.Result().StatusCode, "Expected failing status code of 403: Permission Denied")
assert.Equal(t, 403, w.Result().StatusCode, "Expected failing status code of 403: Permission Denied")

// Create a new Recorder and Context for the next HTTPtest call
w = httptest.NewRecorder()
c, r = gin.CreateTestContext(w)

c.Request = &http.Request{
URL: &url.URL{},
}

// Create a new token to be used
tok, err = jwt.NewBuilder().
Claim("scope", "not.prometheus").
Claim("wlcg.ver", "1.0").
JwtID(jti).
Issuer(issuerURL.String()).
Audience([]string{originUrl}).
Subject("sub").
Expiration(time.Now().Add(time.Minute)).
IssuedAt(time.Now()).
Build()

if err != nil {
t.Fatal(err)
}

// Sign the token with the origin private key
signed, err = jwt.Sign(tok, jwt.WithKey(jwa.ES256, privKey))
if err != nil {
t.Fatal(err)
}

// Set the request to go through the checkPromToken function
r.GET("/api/v1.0/prometheus/*any", checkPromToken(av1))
c.Request, _ = http.NewRequest(http.MethodGet, "/api/v1.0/prometheus/test", bytes.NewBuffer([]byte(`{}`)))

// Put the signed token within the header
c.Request.Header.Set("Authorization", "Bearer "+string(signed))
c.Request.Header.Set("Content-Type", "application/json")

r.ServeHTTP(w, c.Request)

assert.Equal(t, 500, w.Result().StatusCode, "Expected status code of 500 representing failure to validate token scope")
}

0 comments on commit 209e3a0

Please sign in to comment.