Skip to content

Commit

Permalink
Added token checking for prometheus metrics
Browse files Browse the repository at this point in the history
    -- Added a checkPromToken function in prometheus.go to check the tokens
    -- Added unit tests for checking the metrics
    -- Fixed missing "Bearer " in redirect_test
  • Loading branch information
turetske committed Oct 2, 2023
1 parent e0c7d48 commit 7177259
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 2 deletions.
2 changes: 1 addition & 1 deletion director/redirect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func TestDirectorRegistration(t *testing.T) {
rInv.POST("/", RegisterOrigin)
cInv.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBuffer([]byte(`{"Namespaces": [{"Path": "/foo/bar", "URL": "https://get-your-tokens.org"}]}`)))

cInv.Request.Header.Set("Authorization", string(signedInv))
cInv.Request.Header.Set("Authorization", "Bearer "+string(signedInv))
cInv.Request.Header.Set("Content-Type", "application/json")

rInv.ServeHTTP(wInv, cInv.Request)
Expand Down
50 changes: 49 additions & 1 deletion web_ui/prometheus.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package web_ui

import (
"context"
"crypto/ecdsa"
"errors"
"fmt"
"math"
Expand All @@ -35,6 +36,8 @@ import (
"github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/grafana/regexp"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/mwitkow/go-conntrack"
"github.com/oklog/run"
pelican_config "github.com/pelicanplatform/pelican/config"
Expand Down Expand Up @@ -138,6 +141,51 @@ func runtimeInfo() (api_v1.RuntimeInfo, error) {
return api_v1.RuntimeInfo{}, nil
}

func checkPromToken(av1 *route.Router) gin.HandlerFunc {
/* A function which wraps around the av1 router to force a jwk token check using
* the origin's private key. It will check the request's URL and Header for a token
* and if found it will then attempt to validate the token. If valid, it will continue
* the routing as normal, otherwise it will return an error"
*/
return func(c *gin.Context) {
req := c.Request
var token string
if authzQuery := req.URL.Query()["authz"]; len(authzQuery) > 0 {
token = authzQuery[0]
} else if authzHeader := req.Header["Authorization"]; len(authzHeader) > 0 {
token = strings.TrimPrefix(authzHeader[0], "Bearer ")
} else {
c.JSON(403, gin.H{"error": "Permission Denied: Missing token"})
}

privKey, err := pelican_config.GetOriginJWK()
if err != nil {
c.JSON(400, gin.H{"error": "Failed to retrieve private key"})
return
}

var raw ecdsa.PrivateKey
if err = (*privKey).Raw(&raw); err != nil {
c.JSON(400, gin.H{"error": "Failed to extract signing key"})
return
}

if err != nil {
c.JSON(400, gin.H{"error": "Private Key Retrieval Failed"})
return
}

_, err = jwt.Parse([]byte(token), jwt.WithKey(jwa.ES512, raw.PublicKey), jwt.WithValidate(true))

if err != nil {
c.JSON(403, gin.H{"error": "Permission Denied: Invalid token"})
return
}

av1.ServeHTTP(c.Writer, c.Request)
}
}

func ConfigureEmbeddedPrometheus(engine *gin.Engine) error {

cfg := flagConfig{}
Expand Down Expand Up @@ -341,7 +389,7 @@ func ConfigureEmbeddedPrometheus(engine *gin.Engine) error {
//WithInstrumentation(setPathWithPrefix("/api/v1"))
apiV1.Register(av1)

engine.GET("/api/v1.0/prometheus/*any", gin.WrapH(av1))
engine.GET("/api/v1.0/prometheus/*any", checkPromToken(av1))

reloaders := []reloader{
{
Expand Down
183 changes: 183 additions & 0 deletions web_ui/prometheus_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package web_ui

import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"net/http"
"net/http/httptest"
"net/url"
"path/filepath"
"testing"
"time"

"github.com/gin-gonic/gin"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/pelicanplatform/pelican/config"
"github.com/prometheus/common/route"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
)

func TestPrometheusProtection(t *testing.T) {

/*
* Tests that prometheus metrics are behind the origin's token. Specifically it signs a token
* with the origin's keyand invokes a prometheus GET endpoint with both URL and Header authorization,
* it then does so again with an invalid token and confirms that the correct error is returned
*/

// Setup httptest recorder and context for the the unit test
viper.Reset()

av1 := route.New().WithPrefix("/api/v1.0/prometheus")

w := httptest.NewRecorder()
c, r := gin.CreateTestContext(w)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.Equal(t, "GET", req.Method, "Not GET Method")
_, err := w.Write([]byte(":)"))
assert.NoError(t, err)
}))
defer ts.Close()
c.Request = &http.Request{
URL: &url.URL{},
}

// Create temp dir for the origin key file
tDir := t.TempDir()
kfile := filepath.Join(tDir, "testKey")

//Setup a private key and a token
viper.Set("IssuerKey", kfile)

viper.Set("NamespaceURL", "https://get-your-tokens.org")
viper.Set("DirectorURL", "https://director-url.org")

// Generate the origin private and public keys
_, err := config.LoadPublicKey("", kfile)

if err != nil {
t.Fatal(err)
}

privKey, err := config.LoadPrivateKey(kfile)
if err != nil {
t.Fatal(err)
}

// Create a token
issuerURL := url.URL{}
issuerURL.Scheme = "https"
issuerURL.Host = "test-host"
now := time.Now()
tok, err := jwt.NewBuilder().
Issuer(issuerURL.String()).
IssuedAt(now).
Expiration(now.Add(30 * time.Minute)).
NotBefore(now).
Build()

if err != nil {
t.Fatal(err)
}

// Sign the token with the origin private key
signed, err := jwt.Sign(tok, jwt.WithKey(jwa.ES512, privKey))

if err != nil {
t.Fatal(err)
}

// Set the request to run through the checkPromToken function
r.GET("/api/v1.0/prometheus/*any", checkPromToken(av1))
c.Request, _ = http.NewRequest(http.MethodGet, "/api/v1.0/prometheus/test", bytes.NewBuffer([]byte(`{}`)))

// Puts the token within the URL
new_query := c.Request.URL.Query()
new_query.Add("authz", string(signed))
c.Request.URL.RawQuery = new_query.Encode()

r.ServeHTTP(w, c.Request)

// Check to see that the code exits with status code 404 after giving it a good token
assert.Equal(t, 404, w.Result().StatusCode, "Expected status code of 404 representing failure due to minimal server setup, not token check")

// Create a new Recorder and Context for the next HTTPtest call
wH := httptest.NewRecorder()
cH, rH := gin.CreateTestContext(wH)
tsH := httptest.NewServer(http.HandlerFunc(func(wH http.ResponseWriter, req *http.Request) {
assert.Equal(t, "GET", req.Method, "Not GET Method")
_, err := wH.Write([]byte(":)"))
assert.NoError(t, err)
}))
defer tsH.Close()

// Set the request to go through the checkPromToken function
rH.GET("/api/v1.0/prometheus/*any", checkPromToken(av1))
cH.Request, _ = http.NewRequest(http.MethodGet, "/api/v1.0/prometheus/test", bytes.NewBuffer([]byte(`{}`)))

// Put the signed token within the header
cH.Request.Header.Set("Authorization", "Bearer "+string(signed))
cH.Request.Header.Set("Content-Type", "application/json")

rH.ServeHTTP(wH, cH.Request)
// Check to see that the code exits with status code 404 after given it a good token
assert.Equal(t, 404, wH.Result().StatusCode, "Expected status code of 404 representing failure due to minimal server setup, not token check")

// Create a new Recorder and Context for testing an invalid token
wI := httptest.NewRecorder()
cI, rI := gin.CreateTestContext(wI)
tsI := httptest.NewServer(http.HandlerFunc(func(wI http.ResponseWriter, req *http.Request) {
assert.Equal(t, "GET", req.Method, "Not GET Method")
_, err := wI.Write([]byte(":)"))
assert.NoError(t, err)
}))
defer tsI.Close()
c.Request = &http.Request{
URL: &url.URL{},
}

// Create a private key to use for the test
privateKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
assert.NoError(t, err, "Error generating private key")

// Convert from raw ecdsa to jwk.Key
pKey, err := jwk.FromRaw(privateKey)
assert.NoError(t, err, "Unable to convert ecdsa.PrivateKey to jwk.Key")

//Assign Key id to the private key
err = jwk.AssignKeyID(pKey)
assert.NoError(t, err, "Error assigning kid to private key")

//Set an algorithm for the key
err = pKey.Set(jwk.AlgorithmKey, jwa.ES512)
assert.NoError(t, err, "Unable to set algorithm for pKey")

// Create a new token to be used
tok, err = jwt.NewBuilder().
Issuer(issuerURL.String()).
IssuedAt(now).
Expiration(now.Add(30 * time.Minute)).
NotBefore(now).
Build()

assert.NoError(t, err, "Error creating token")

// Sign token with private key (not the origin)
signed, err = jwt.Sign(tok, jwt.WithKey(jwa.ES512, pKey))
assert.NoError(t, err, "Error signing token")

rI.GET("/api/v1.0/prometheus/*any", checkPromToken(av1))
cI.Request, _ = http.NewRequest(http.MethodGet, "/api/v1.0/prometheus/test", bytes.NewBuffer([]byte(`{}`)))

cI.Request.Header.Set("Authorization", "Bearer "+string(signed))
cI.Request.Header.Set("Content-Type", "application/json")

rI.ServeHTTP(wI, cI.Request)
// Assert that it gets the correct Permission Denied 403 code
assert.Equal(t, 403, wI.Result().StatusCode, "Expected failing status code of 403: Permission Denied")
}

0 comments on commit 7177259

Please sign in to comment.