From e1cc31a24ebecf3209befc38ce227b5a4a97b00f Mon Sep 17 00:00:00 2001 From: Justin Hiemstra Date: Mon, 23 Oct 2023 21:53:03 +0000 Subject: [PATCH] Fix broken Director auth headers and add regression test --- director/advertise.go | 90 ++++++++++---------------- director/advertise_test.go | 129 +++++++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+), 55 deletions(-) create mode 100644 director/advertise_test.go diff --git a/director/advertise.go b/director/advertise.go index 98708e705..b96972231 100644 --- a/director/advertise.go +++ b/director/advertise.go @@ -65,6 +65,38 @@ type ( } ) +func parseServerAd(server Server, serverType ServerType) ServerAd { + serverAd := ServerAd{} + serverAd.Type = serverType + serverAd.Name = server.Resource + + // url.Parse requires that the scheme be present before the hostname, + // but endpoints do not have a scheme. As such, we need to add one for the. + // correct parsing. Luckily, we don't use this anywhere else (it's just to + // make the url.Parse function behave as expected) + if !strings.HasPrefix(server.AuthEndpoint, "http") { // just in case there's already an http(s) tacked in front + server.AuthEndpoint = "https://" + server.AuthEndpoint + } + if !strings.HasPrefix(server.Endpoint, "http") { // just in case there's already an http(s) tacked in front + server.Endpoint = "http://" + server.Endpoint + } + serverAuthUrl, err := url.Parse(server.AuthEndpoint) + if err != nil { + log.Warningf("Namespace JSON returned server %s with invalid authenticated URL %s", + server.Resource, server.AuthEndpoint) + } + serverAd.AuthURL = *serverAuthUrl + + serverUrl, err := url.Parse(server.Endpoint) + if err != nil { + log.Warningf("Namespace JSON returned server %s with invalid unauthenticated URL %s", + server.Resource, server.Endpoint) + } + serverAd.URL = *serverUrl + + return serverAd +} + // Populate internal cache with origin/cache ads func AdvertiseOSDF() error { topoNamespaceUrl := param.Federation_TopologyNamespaceUrl.GetString() @@ -122,65 +154,13 @@ func AdvertiseOSDF() error { // they're listed as inactive by topology). These namespaces will all be mapped to the // same useless origin ad, resulting in a 404 for queries to those namespaces for _, origin := range ns.Origins { - originAd := ServerAd{} - originAd.Type = OriginType - originAd.Name = origin.Resource - // url.Parse requires that the scheme be present before the hostname, - // but endpoints do not have a scheme. As such, we need to add one for the. - // correct parsing. Luckily, we don't use this anywhere else (it's just to - // make the url.Parse function behave as expected) - if !strings.HasPrefix(origin.AuthEndpoint, "http") { // just in case there's already an http(s) tacked in front - origin.AuthEndpoint = "https://" + origin.AuthEndpoint - } - if !strings.HasPrefix(origin.Endpoint, "http") { // just in case there's already an http(s) tacked in front - origin.Endpoint = "http://" + origin.Endpoint - } - originAuthURL, err := url.Parse(origin.AuthEndpoint) - if err != nil { - log.Warningf("Namespace JSON returned origin %s with invalid authenticated URL %s", - origin.Resource, origin.AuthEndpoint) - } - originAd.AuthURL = *originAuthURL - originURL, err := url.Parse(origin.Endpoint) - if err != nil { - log.Warningf("Namespace JSON returned origin %s with invalid unauthenticated URL %s", - origin.Resource, origin.Endpoint) - } - originAd.URL = *originURL - + originAd := parseServerAd(origin, OriginType) originAdMap[originAd] = append(originAdMap[originAd], nsAd) } for _, cache := range ns.Caches { - cacheAd := ServerAd{} - cacheAd.Type = CacheType - cacheAd.Name = cache.Resource - - if !strings.HasPrefix(cache.AuthEndpoint, "http") { // just in case there's already an http(s) tacked in front - cache.AuthEndpoint = "https://" + cache.AuthEndpoint - } - if !strings.HasPrefix(cache.Endpoint, "http") { // just in case there's already an http(s) tacked in front - cache.Endpoint = "http://" + cache.Endpoint - } - cacheAuthURL, err := url.Parse(cache.AuthEndpoint) - if err != nil { - log.Warningf("Namespace JSON returned cache %s with invalid authenticated URL %s", - cache.Resource, cache.AuthEndpoint) - } - cacheAd.AuthURL = *cacheAuthURL - - cacheURL, err := url.Parse(cache.Endpoint) - if err != nil { - log.Warningf("Namespace JSON returned cache %s with invalid unauthenticated URL %s", - cache.Resource, cache.Endpoint) - } - cacheAd.URL = *cacheURL - - cacheNS := NamespaceAd{} - cacheNS.Path = ns.Path - cacheNS.RequireToken = ns.UseTokenOnRead - cacheAdMap[cacheAd] = append(cacheAdMap[cacheAd], cacheNS) - + cacheAd := parseServerAd(cache, CacheType) + cacheAdMap[cacheAd] = append(cacheAdMap[cacheAd], nsAd) } } diff --git a/director/advertise_test.go b/director/advertise_test.go new file mode 100644 index 000000000..7cfd5bda8 --- /dev/null +++ b/director/advertise_test.go @@ -0,0 +1,129 @@ +package director + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" +) + +func TestParseServerAd(t *testing.T) { + + server := Server{ + AuthEndpoint: "https://my-auth-endpoint.com", + Endpoint: "http://my-endpoint.com", + Resource: "MY_SERVER", + } + + // Check that we populate all of the fields correctly -- note that lat/long don't get updated + // until right before the ad is recorded, so we don't check for that here. + ad := parseServerAd(server, OriginType) + assert.Equal(t, ad.AuthURL.String(), "https://my-auth-endpoint.com") + assert.Equal(t, ad.URL.String(), "http://my-endpoint.com") + assert.Equal(t, ad.Name, "MY_SERVER") + assert.True(t, ad.Type == OriginType) + + // A quick check that type is set correctly + ad = parseServerAd(server, CacheType) + assert.True(t, ad.Type == CacheType) +} + +func JSONHandler(w http.ResponseWriter, r *http.Request) { + jsonResponse := ` + { + "caches": [ + { + "auth_endpoint": "https://cache-auth-endpoint.com", + "endpoint": "http://cache-endpoint.com", + "resource": "MY_CACHE" + } + ], + "namespaces": [ + { + "caches": [ + { + "auth_endpoint": "https://cache-auth-endpoint.com", + "endpoint": "http://cache-endpoint.com", + "resource": "MY_CACHE" + } + ], + "credential_generation": { + "base_path": "/server", + "issuer": "https://my-issuer.com", + "max_scope_depth": 3, + "strategy": "OAuth2", + "vault_issuer": null, + "vault_server": null + }, + "dirlisthost": null, + "origins": [ + { + "auth_endpoint": "https://origin1-auth-endpoint.com", + "endpoint": "http://origin1-endpoint.com", + "resource": "MY_ORIGIN1" + } + ], + "path": "/my/server", + "readhttps": true, + "usetokenonread": true, + "writebackhost": "https://writeback.my-server.com" + }, + { + "caches": [ + { + "auth_endpoint": "https://cache-auth-endpoint.com", + "endpoint": "http://cache-endpoint.com", + "resource": "MY_CACHE" + } + ], + "credential_generation": null, + "dirlisthost": null, + "origins": [ + { + "auth_endpoint": "https://origin2-auth-endpoint.com", + "endpoint": "http://origin2-endpoint.com", + "resource": "MY_ORIGIN2" + } + ], + "path": "/my/server/2", + "readhttps": true, + "usetokenonread": false, + "writebackhost": null + } + ] + } + ` + + // Set the Content-Type header to indicate JSON. + w.Header().Set("Content-Type", "application/json") + + // Write the JSON response to the response body. + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(jsonResponse)) +} +func TestAdvertiseOSDF(t *testing.T) { + viper.Reset() + topoServer := httptest.NewServer(http.HandlerFunc(JSONHandler)) + defer topoServer.Close() + viper.Set("Federation.TopologyNamespaceUrl", topoServer.URL) + + err := AdvertiseOSDF() + if err != nil { + t.Fatal(err) + } + + // Test a few values. If they're correct, it indicates the whole process likely succeeded + nsAd, oAds, cAds := GetAdsForPath("/my/server/path/to/file") + assert.Equal(t, nsAd.Path, "/my/server") + assert.Equal(t, nsAd.MaxScopeDepth, uint(3)) + assert.Equal(t, oAds[0].AuthURL.String(), "https://origin1-auth-endpoint.com") + assert.Equal(t, cAds[0].URL.String(), "http://cache-endpoint.com") + + nsAd, oAds, cAds = GetAdsForPath("/my/server/2/path/to/file") + assert.Equal(t, nsAd.Path, "/my/server/2") + assert.Equal(t, nsAd.RequireToken, false) + assert.Equal(t, oAds[0].AuthURL.String(), "https://origin2-auth-endpoint.com") + assert.Equal(t, cAds[0].URL.String(), "http://cache-endpoint.com") +}