Skip to content

Commit

Permalink
Add unit tests, only discover origins, more labels
Browse files Browse the repository at this point in the history
  • Loading branch information
haoming29 committed Oct 27, 2023
1 parent 1f7ebde commit c9aa107
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 1 deletion.
10 changes: 9 additions & 1 deletion director/redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,11 +382,19 @@ func DiscoverOrigins(ctx *gin.Context) {
serverAds := serverAds.Keys()
promDiscoveryRes := make([]PromDiscoveryItem, 0)
for _, ad := range serverAds {
// We don't include caches in this discovery for right now
if ad.Type != OriginType {
continue
}
promDiscoveryRes = append(promDiscoveryRes, PromDiscoveryItem{
// TODO: change to ad.WebURL when #285 is ready
Targets: []string{ad.URL.Hostname() + ":" + ad.URL.Port()},
Labels: map[string]string{
"job": ad.Name,
"origin_name": ad.Name,
"origin_auth_url": ad.AuthURL.String(),
"origin_url": ad.URL.String(),
"origin_lat": fmt.Sprintf("%.4f", ad.Latitude),
"origin_long": fmt.Sprintf("%.4f", ad.Longitude),
},
})
}
Expand Down
161 changes: 161 additions & 0 deletions director/redirect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"path/filepath"
"testing"
"time"

Expand All @@ -18,6 +21,7 @@ import (
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/pelicanplatform/pelican/config"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -248,3 +252,160 @@ func TestGetAuthzEscaped(t *testing.T) {
escapedToken = getAuthzEscaped(req)
assert.Equal(t, escapedToken, "tokenstring")
}

func TestDiscoverOrigins(t *testing.T) {
mockOriginServerAd := ServerAd{
Name: "test-origin-server",
AuthURL: url.URL{},
URL: url.URL{
Scheme: "https",
Host: "fake-origin.org:8444",
},
Type: OriginType,
Latitude: 123.05,
Longitude: 456.78,
}

mockCacheServerAd := ServerAd{
Name: "test-cache-server",
AuthURL: url.URL{},
URL: url.URL{
Scheme: "https",
Host: "fake-cache.org:8444",
},
Type: CacheType,
Latitude: 45.67,
Longitude: 123.05,
}

mockNamespaceAd := NamespaceAd{
RequireToken: true,
Path: "/foo/bar/",
Issuer: url.URL{},
MaxScopeDepth: 1,
Strategy: "",
BasePath: "",
VaultServer: "",
}

mockDirectorUrl := "https://fake-director.org:8888"
viper.Reset()
viper.Set("Federation.DirectorUrl", mockDirectorUrl)

tDir := t.TempDir()
kfile := filepath.Join(tDir, "testKey")
viper.Set("IssuerKey", kfile)

// Generate a private key to use for the test
_, err := config.LoadPublicKey("", kfile)
assert.NoError(t, err, "Error generating private key")
// Get private key
privateKey, err := config.GetOriginJWK()
assert.NoError(t, err, "Error loading private key")

// Batch set up different tokens
setupToken := func(wrongIssuer string) []byte {
issuerURL, err := url.Parse(mockDirectorUrl)
assert.NoError(t, err, "Error parsing director's URL")
tokenIssuerString := ""
if wrongIssuer != "" {
tokenIssuerString = wrongIssuer
} else {
tokenIssuerString = issuerURL.String()
}

tok, err := jwt.NewBuilder().
Issuer(tokenIssuerString).
Claim("scope", "pelican.directorSD").
Audience([]string{"director.test"}).
Subject("director").
Expiration(time.Now().Add(time.Hour)).
Build()
assert.NoError(t, err, "Error creating token")

err = jwk.AssignKeyID(*privateKey)
assert.NoError(t, err, "Error assigning key id")

// Sign token with previously created private key
signed, err := jwt.Sign(tok, jwt.WithKey(jwa.ES256, *privateKey))
assert.NoError(t, err, "Error signing token")
return signed
}

r := gin.Default()
r.GET("/test", DiscoverOrigins)

t.Run("no-token-should-give-401", func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/test", nil)
if err != nil {
t.Fatalf("Could not make a GET request: %v", err)
}

w := httptest.NewRecorder()
r.ServeHTTP(w, req)

assert.Equal(t, 401, w.Code)
assert.Equal(t, `{"error":"Bearer token not present in the 'Authorization' header"}`, w.Body.String())
})
t.Run("token-present-with-wrong-issuer-should-give-401", func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/test", nil)
if err != nil {
t.Fatalf("Could not make a GET request: %v", err)
}

req.Header.Set("Authorization", "Bearer "+string(setupToken("https://wrong-issuer.org")))

w := httptest.NewRecorder()
r.ServeHTTP(w, req)

assert.Equal(t, 401, w.Code)
assert.Equal(t, `{"error":"Authorization token verification failed: Token issuer is not a director\n"}`, w.Body.String())
})
t.Run("token-present-valid-should-give-200-and-empty-array", func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/test", nil)
if err != nil {
t.Fatalf("Could not make a GET request: %v", err)
}

req.Header.Set("Authorization", "Bearer "+string(setupToken("")))

w := httptest.NewRecorder()
r.ServeHTTP(w, req)

assert.Equal(t, 200, w.Code)
assert.Equal(t, `[]`, w.Body.String())
})
t.Run("response-origin-should-match-cache", func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/test", nil)
if err != nil {
t.Fatalf("Could not make a GET request: %v", err)
}

serverAdMutex.Lock()
serverAds.Set(mockOriginServerAd, []NamespaceAd{mockNamespaceAd}, time.Duration(10))
serverAds.Set(mockCacheServerAd, []NamespaceAd{mockNamespaceAd}, time.Duration(10))
serverAdMutex.Unlock()

expectedRes := []PromDiscoveryItem{{
Targets: []string{mockOriginServerAd.URL.Hostname() + ":" + mockOriginServerAd.URL.Port()},
Labels: map[string]string{
"origin_name": mockOriginServerAd.Name,
"origin_auth_url": mockOriginServerAd.AuthURL.String(),
"origin_url": mockOriginServerAd.URL.String(),
"origin_lat": fmt.Sprintf("%.4f", mockOriginServerAd.Latitude),
"origin_long": fmt.Sprintf("%.4f", mockOriginServerAd.Longitude),
},
}}

resStr, err := json.Marshal(expectedRes)
assert.NoError(t, err, "Could not marshal json response")

req.Header.Set("Authorization", "Bearer "+string(setupToken("")))

w := httptest.NewRecorder()
r.ServeHTTP(w, req)

assert.Equal(t, 200, w.Code)
assert.Equal(t, string(resStr), w.Body.String())
})
}

0 comments on commit c9aa107

Please sign in to comment.