From 717725988c004f84fa00a4c3c837a13f0361a50a Mon Sep 17 00:00:00 2001 From: Emma Turetsky Date: Wed, 27 Sep 2023 16:29:31 +0000 Subject: [PATCH] Added token checking for prometheus metrics -- Added a checkPromToken function in prometheus.go to check the tokens -- Added unit tests for checking the metrics -- Fixed missing "Bearer " in redirect_test --- director/redirect_test.go | 2 +- web_ui/prometheus.go | 50 ++++++++++- web_ui/prometheus_test.go | 183 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 233 insertions(+), 2 deletions(-) create mode 100644 web_ui/prometheus_test.go diff --git a/director/redirect_test.go b/director/redirect_test.go index ab566aabc..ddc16b8fc 100644 --- a/director/redirect_test.go +++ b/director/redirect_test.go @@ -181,7 +181,7 @@ func TestDirectorRegistration(t *testing.T) { rInv.POST("/", RegisterOrigin) cInv.Request, _ = http.NewRequest(http.MethodPost, "/", bytes.NewBuffer([]byte(`{"Namespaces": [{"Path": "/foo/bar", "URL": "https://get-your-tokens.org"}]}`))) - cInv.Request.Header.Set("Authorization", string(signedInv)) + cInv.Request.Header.Set("Authorization", "Bearer "+string(signedInv)) cInv.Request.Header.Set("Content-Type", "application/json") rInv.ServeHTTP(wInv, cInv.Request) diff --git a/web_ui/prometheus.go b/web_ui/prometheus.go index 414bd1f17..120537897 100644 --- a/web_ui/prometheus.go +++ b/web_ui/prometheus.go @@ -17,6 +17,7 @@ package web_ui import ( "context" + "crypto/ecdsa" "errors" "fmt" "math" @@ -35,6 +36,8 @@ import ( "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/grafana/regexp" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwt" "github.com/mwitkow/go-conntrack" "github.com/oklog/run" pelican_config "github.com/pelicanplatform/pelican/config" @@ -138,6 +141,51 @@ func runtimeInfo() (api_v1.RuntimeInfo, error) { return api_v1.RuntimeInfo{}, nil } +func checkPromToken(av1 *route.Router) gin.HandlerFunc { + /* A function which wraps around the av1 router to force a jwk token check using + * the origin's private key. It will check the request's URL and Header for a token + * and if found it will then attempt to validate the token. If valid, it will continue + * the routing as normal, otherwise it will return an error" + */ + return func(c *gin.Context) { + req := c.Request + var token string + if authzQuery := req.URL.Query()["authz"]; len(authzQuery) > 0 { + token = authzQuery[0] + } else if authzHeader := req.Header["Authorization"]; len(authzHeader) > 0 { + token = strings.TrimPrefix(authzHeader[0], "Bearer ") + } else { + c.JSON(403, gin.H{"error": "Permission Denied: Missing token"}) + } + + privKey, err := pelican_config.GetOriginJWK() + if err != nil { + c.JSON(400, gin.H{"error": "Failed to retrieve private key"}) + return + } + + var raw ecdsa.PrivateKey + if err = (*privKey).Raw(&raw); err != nil { + c.JSON(400, gin.H{"error": "Failed to extract signing key"}) + return + } + + if err != nil { + c.JSON(400, gin.H{"error": "Private Key Retrieval Failed"}) + return + } + + _, err = jwt.Parse([]byte(token), jwt.WithKey(jwa.ES512, raw.PublicKey), jwt.WithValidate(true)) + + if err != nil { + c.JSON(403, gin.H{"error": "Permission Denied: Invalid token"}) + return + } + + av1.ServeHTTP(c.Writer, c.Request) + } +} + func ConfigureEmbeddedPrometheus(engine *gin.Engine) error { cfg := flagConfig{} @@ -341,7 +389,7 @@ func ConfigureEmbeddedPrometheus(engine *gin.Engine) error { //WithInstrumentation(setPathWithPrefix("/api/v1")) apiV1.Register(av1) - engine.GET("/api/v1.0/prometheus/*any", gin.WrapH(av1)) + engine.GET("/api/v1.0/prometheus/*any", checkPromToken(av1)) reloaders := []reloader{ { diff --git a/web_ui/prometheus_test.go b/web_ui/prometheus_test.go new file mode 100644 index 000000000..f24eaa0d4 --- /dev/null +++ b/web_ui/prometheus_test.go @@ -0,0 +1,183 @@ +package web_ui + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "net/http" + "net/http/httptest" + "net/url" + "path/filepath" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/pelicanplatform/pelican/config" + "github.com/prometheus/common/route" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" +) + +func TestPrometheusProtection(t *testing.T) { + + /* + * Tests that prometheus metrics are behind the origin's token. Specifically it signs a token + * with the origin's keyand invokes a prometheus GET endpoint with both URL and Header authorization, + * it then does so again with an invalid token and confirms that the correct error is returned + */ + + // Setup httptest recorder and context for the the unit test + viper.Reset() + + av1 := route.New().WithPrefix("/api/v1.0/prometheus") + + w := httptest.NewRecorder() + c, r := gin.CreateTestContext(w) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + assert.Equal(t, "GET", req.Method, "Not GET Method") + _, err := w.Write([]byte(":)")) + assert.NoError(t, err) + })) + defer ts.Close() + c.Request = &http.Request{ + URL: &url.URL{}, + } + + // Create temp dir for the origin key file + tDir := t.TempDir() + kfile := filepath.Join(tDir, "testKey") + + //Setup a private key and a token + viper.Set("IssuerKey", kfile) + + viper.Set("NamespaceURL", "https://get-your-tokens.org") + viper.Set("DirectorURL", "https://director-url.org") + + // Generate the origin private and public keys + _, err := config.LoadPublicKey("", kfile) + + if err != nil { + t.Fatal(err) + } + + privKey, err := config.LoadPrivateKey(kfile) + if err != nil { + t.Fatal(err) + } + + // Create a token + issuerURL := url.URL{} + issuerURL.Scheme = "https" + issuerURL.Host = "test-host" + now := time.Now() + tok, err := jwt.NewBuilder(). + Issuer(issuerURL.String()). + IssuedAt(now). + Expiration(now.Add(30 * time.Minute)). + NotBefore(now). + Build() + + if err != nil { + t.Fatal(err) + } + + // Sign the token with the origin private key + signed, err := jwt.Sign(tok, jwt.WithKey(jwa.ES512, privKey)) + + if err != nil { + t.Fatal(err) + } + + // Set the request to run through the checkPromToken function + r.GET("/api/v1.0/prometheus/*any", checkPromToken(av1)) + c.Request, _ = http.NewRequest(http.MethodGet, "/api/v1.0/prometheus/test", bytes.NewBuffer([]byte(`{}`))) + + // Puts the token within the URL + new_query := c.Request.URL.Query() + new_query.Add("authz", string(signed)) + c.Request.URL.RawQuery = new_query.Encode() + + r.ServeHTTP(w, c.Request) + + // Check to see that the code exits with status code 404 after giving it a good token + assert.Equal(t, 404, w.Result().StatusCode, "Expected status code of 404 representing failure due to minimal server setup, not token check") + + // Create a new Recorder and Context for the next HTTPtest call + wH := httptest.NewRecorder() + cH, rH := gin.CreateTestContext(wH) + tsH := httptest.NewServer(http.HandlerFunc(func(wH http.ResponseWriter, req *http.Request) { + assert.Equal(t, "GET", req.Method, "Not GET Method") + _, err := wH.Write([]byte(":)")) + assert.NoError(t, err) + })) + defer tsH.Close() + + // Set the request to go through the checkPromToken function + rH.GET("/api/v1.0/prometheus/*any", checkPromToken(av1)) + cH.Request, _ = http.NewRequest(http.MethodGet, "/api/v1.0/prometheus/test", bytes.NewBuffer([]byte(`{}`))) + + // Put the signed token within the header + cH.Request.Header.Set("Authorization", "Bearer "+string(signed)) + cH.Request.Header.Set("Content-Type", "application/json") + + rH.ServeHTTP(wH, cH.Request) + // Check to see that the code exits with status code 404 after given it a good token + assert.Equal(t, 404, wH.Result().StatusCode, "Expected status code of 404 representing failure due to minimal server setup, not token check") + + // Create a new Recorder and Context for testing an invalid token + wI := httptest.NewRecorder() + cI, rI := gin.CreateTestContext(wI) + tsI := httptest.NewServer(http.HandlerFunc(func(wI http.ResponseWriter, req *http.Request) { + assert.Equal(t, "GET", req.Method, "Not GET Method") + _, err := wI.Write([]byte(":)")) + assert.NoError(t, err) + })) + defer tsI.Close() + c.Request = &http.Request{ + URL: &url.URL{}, + } + + // Create a private key to use for the test + privateKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + assert.NoError(t, err, "Error generating private key") + + // Convert from raw ecdsa to jwk.Key + pKey, err := jwk.FromRaw(privateKey) + assert.NoError(t, err, "Unable to convert ecdsa.PrivateKey to jwk.Key") + + //Assign Key id to the private key + err = jwk.AssignKeyID(pKey) + assert.NoError(t, err, "Error assigning kid to private key") + + //Set an algorithm for the key + err = pKey.Set(jwk.AlgorithmKey, jwa.ES512) + assert.NoError(t, err, "Unable to set algorithm for pKey") + + // Create a new token to be used + tok, err = jwt.NewBuilder(). + Issuer(issuerURL.String()). + IssuedAt(now). + Expiration(now.Add(30 * time.Minute)). + NotBefore(now). + Build() + + assert.NoError(t, err, "Error creating token") + + // Sign token with private key (not the origin) + signed, err = jwt.Sign(tok, jwt.WithKey(jwa.ES512, pKey)) + assert.NoError(t, err, "Error signing token") + + rI.GET("/api/v1.0/prometheus/*any", checkPromToken(av1)) + cI.Request, _ = http.NewRequest(http.MethodGet, "/api/v1.0/prometheus/test", bytes.NewBuffer([]byte(`{}`))) + + cI.Request.Header.Set("Authorization", "Bearer "+string(signed)) + cI.Request.Header.Set("Content-Type", "application/json") + + rI.ServeHTTP(wI, cI.Request) + // Assert that it gets the correct Permission Denied 403 code + assert.Equal(t, 403, wI.Result().StatusCode, "Expected failing status code of 403: Permission Denied") +}