From da820484a340bedccfb8fab48b3e566383a15efe Mon Sep 17 00:00:00 2001 From: Bin Xia Date: Sat, 5 Aug 2023 13:31:24 +0000 Subject: [PATCH] Use retrable http client in Azure authz provider Signed-off-by: Bin Xia --- auth/providers/azure/azure.go | 40 ++----------- authz/providers/azure/azure.go | 10 +++- authz/providers/azure/azure_test.go | 87 +++++++++++++++++++++++++---- authz/providers/azure/rbac/rbac.go | 10 +++- util/azure/utils.go | 40 +++++++++++++ 5 files changed, 134 insertions(+), 53 deletions(-) diff --git a/auth/providers/azure/azure.go b/auth/providers/azure/azure.go index 386709990..9e1781f94 100644 --- a/auth/providers/azure/azure.go +++ b/auth/providers/azure/azure.go @@ -21,22 +21,19 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" "strings" "sync" - "time" "go.kubeguard.dev/guard/auth" "go.kubeguard.dev/guard/auth/providers/azure/graph" - "go.kubeguard.dev/guard/util/httpclient" + azureutils "go.kubeguard.dev/guard/util/azure" "github.com/Azure/go-autorest/autorest/azure" "github.com/coreos/go-oidc" "github.com/golang-jwt/jwt/v4" "github.com/hashicorp/go-retryablehttp" "github.com/pkg/errors" - "golang.org/x/oauth2" authv1 "k8s.io/api/authentication/v1" "k8s.io/klog/v2" ) @@ -120,7 +117,7 @@ func getOIDCIssuerProvider(issuerURL string, issuerGetRetryCount int) (*oidc.Pro // NOTE: we start a root context here to allow background remote key set refresh ctx := context.Background() - ctx = withRetryableHttpClient(ctx, issuerGetRetryCount) + ctx = azureutils.WithRetryableHttpClient(ctx, issuerGetRetryCount) provider, err := oidc.NewProvider(ctx, issuerURL) if err != nil { // failed in this attempt, let other attempts retry @@ -180,35 +177,6 @@ func New(ctx context.Context, opts Options) (auth.Interface, error) { return c, nil } -// makeRetryableHttpClient creates an HTTP client which attempts the request -// (1 + retryCount) times and has a 3 second timeout per attempt. -func makeRetryableHttpClient(retryCount int) retryablehttp.Client { - // Copy the default HTTP client so we can set a timeout. - // (It uses the same transport since the pointer gets copied) - httpClient := *httpclient.DefaultHTTPClient - httpClient.Timeout = 3 * time.Second - - // Attempt the request up to 3 times - return retryablehttp.Client{ - HTTPClient: &httpClient, - RetryWaitMin: 500 * time.Millisecond, - RetryWaitMax: 2 * time.Second, - RetryMax: retryCount, // initial + retryCount retries = (1 + retryCount) attempts - CheckRetry: retryablehttp.DefaultRetryPolicy, - Backoff: retryablehttp.DefaultBackoff, - Logger: log.Default(), - } -} - -// withRetryableHttpClient sets the oauth2.HTTPClient key of the context to an -// *http.Client made from makeRetryableHttpClient. -// Some of the libraries we use will take the client out of the context via -// oauth2.HTTPClient and use it, so this way we can add retries to external code. -func withRetryableHttpClient(ctx context.Context, retryCount int) context.Context { - retryClient := makeRetryableHttpClient(retryCount) - return context.WithValue(ctx, oauth2.HTTPClient, retryClient.StandardClient()) -} - type metadataJSON struct { Issuer string `json:"issuer"` MsgraphHost string `json:"msgraph_host"` @@ -217,7 +185,7 @@ type metadataJSON struct { // https://docs.microsoft.com/en-us/azure/active-directory/develop/howto-convert-app-to-be-multi-tenant func getMetadata(ctx context.Context, aadEndpoint, tenantID string, retryCount int) (*metadataJSON, error) { metadataURL := aadEndpoint + tenantID + "/.well-known/openid-configuration" - retryClient := makeRetryableHttpClient(retryCount) + retryClient := azureutils.MakeRetryableHttpClient(retryCount) request, err := retryablehttp.NewRequest("GET", metadataURL, nil) if err != nil { @@ -261,7 +229,7 @@ func (s Authenticator) Check(ctx context.Context, token string) (*authv1.UserInf } } - ctx = withRetryableHttpClient(ctx, s.HttpClientRetryCount) + ctx = azureutils.WithRetryableHttpClient(ctx, s.HttpClientRetryCount) idToken, err := s.verifier.Verify(ctx, token) if err != nil { if klog.V(7).Enabled() { diff --git a/authz/providers/azure/azure.go b/authz/providers/azure/azure.go index dbfa4bf7b..e6d72bac5 100644 --- a/authz/providers/azure/azure.go +++ b/authz/providers/azure/azure.go @@ -26,6 +26,7 @@ import ( "go.kubeguard.dev/guard/authz" authzOpts "go.kubeguard.dev/guard/authz/providers/azure/options" "go.kubeguard.dev/guard/authz/providers/azure/rbac" + azureutils "go.kubeguard.dev/guard/util/azure" errutils "go.kubeguard.dev/guard/util/error" "github.com/Azure/go-autorest/autorest/azure" @@ -49,7 +50,8 @@ func init() { } type Authorizer struct { - rbacClient *rbac.AccessInfo + rbacClient *rbac.AccessInfo + httpClientRetryCount int } func New(opts authzOpts.Options, authopts auth.Options) (authz.Interface, error) { @@ -64,7 +66,9 @@ func New(opts authzOpts.Options, authopts auth.Options) (authz.Interface, error) } func newAuthzClient(opts authzOpts.Options, authopts auth.Options) (authz.Interface, error) { - c := &Authorizer{} + c := &Authorizer{ + httpClientRetryCount: authopts.HttpClientRetryCount, + } authzInfoVal, err := getAuthzInfo(authopts.Environment) if err != nil { @@ -120,6 +124,8 @@ func (s Authorizer) Check(ctx context.Context, request *authzv1.SubjectAccessRev return &authzv1.SubjectAccessReviewStatus{Allowed: true, Reason: rbac.AccessAllowedVerdict}, nil } + ctx = azureutils.WithRetryableHttpClient(ctx, s.httpClientRetryCount) + if s.rbacClient.IsTokenExpired() { if err := s.rbacClient.RefreshToken(ctx); err != nil { return nil, errutils.WithCode(err, http.StatusInternalServerError) diff --git a/authz/providers/azure/azure_test.go b/authz/providers/azure/azure_test.go index 7dc2bb4d6..9e4a799cd 100644 --- a/authz/providers/azure/azure_test.go +++ b/authz/providers/azure/azure_test.go @@ -18,9 +18,13 @@ package azure import ( "context" + "fmt" + "io/fs" "net" "net/http" "net/http/httptest" + "os" + "strconv" "testing" "time" @@ -32,12 +36,14 @@ import ( errutils "go.kubeguard.dev/guard/util/error" "github.com/go-chi/chi/v5" + "github.com/google/uuid" "github.com/stretchr/testify/assert" authzv1 "k8s.io/api/authorization/v1" ) const ( - loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}` + loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}` + httpClientRetryCount = 2 ) func clientSetup(serverUrl, mode string) (*Authorizer, error) { @@ -52,9 +58,10 @@ func clientSetup(serverUrl, mode string) (*Authorizer, error) { } authOpts := auth.Options{ - ClientID: "client_id", - ClientSecret: "client_secret", - TenantID: "tenant_id", + ClientID: "client_id", + ClientSecret: "client_secret", + TenantID: "tenant_id", + HttpClientRetryCount: httpClientRetryCount, } authzInfo := rbac.AuthzInfo{ @@ -70,7 +77,7 @@ func clientSetup(serverUrl, mode string) (*Authorizer, error) { return c, nil } -func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStatus int, sleepFor time.Duration) (*httptest.Server, error) { +func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStatus int, sleepFor time.Duration, calledTimesFile string) (*httptest.Server, error) { listener, err := net.Listen("tcp", "127.0.0.1:") if err != nil { return nil, err @@ -85,6 +92,9 @@ func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStat m.Post("/arm/*", func(w http.ResponseWriter, r *http.Request) { time.Sleep(sleepFor) + if calledTimesFile != "" { + _ = incCalledTimes(calledTimesFile) + } w.WriteHeader(checkaccessStatus) _, _ = w.Write([]byte(checkaccessResp)) }) @@ -98,8 +108,8 @@ func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStat return srv, nil } -func getServerAndClient(t *testing.T, loginResp, checkaccessResp string, checkaccessStatus int, sleepFor time.Duration) (*httptest.Server, *Authorizer, authz.Store) { - srv, err := serverSetup(loginResp, checkaccessResp, http.StatusOK, checkaccessStatus, sleepFor) +func getServerAndClient(t *testing.T, loginResp, checkaccessResp string, checkaccessStatus int, sleepFor time.Duration, calledTimesFile string) (*httptest.Server, *Authorizer, authz.Store) { // nolint: unparam + srv, err := serverSetup(loginResp, checkaccessResp, http.StatusOK, checkaccessStatus, sleepFor, calledTimesFile) if err != nil { t.Fatalf("Error when creating server, reason: %v", err) } @@ -123,13 +133,32 @@ func getServerAndClient(t *testing.T, loginResp, checkaccessResp string, checkac return srv, client, dataStore } +func createCalledTimesFile() (string, error) { + calledTimesFile := uuid.New().String() + err := os.WriteFile(calledTimesFile, []byte(strconv.Itoa(0)), fs.ModeTemporary) + if err != nil { + return "", err + } + return calledTimesFile, nil +} + +func incCalledTimes(calledTimesFile string) error { + content, _ := os.ReadFile(calledTimesFile) + calledTimes, _ := strconv.Atoi(string(content)) + return os.WriteFile(calledTimesFile, []byte(strconv.Itoa(calledTimes+1)), fs.ModeTemporary) +} + +func deleteCalledTimesFile(calledTimesFile string) error { + return os.Remove(calledTimesFile) +} + func TestCheck(t *testing.T) { t.Run("successful request", func(t *testing.T) { validBody := `[{"accessDecision":"Allowed", "actionId":"Microsoft.Kubernetes/connectedClusters/pods/delete", "isDataAction":true,"roleAssignment":null,"denyAssignment":null,"timeToLiveInMs":300000}]` - srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusOK, 1*time.Second) + srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusOK, 1*time.Second, "") defer srv.Close() defer store.Close() @@ -154,7 +183,7 @@ func TestCheck(t *testing.T) { t.Run("unsuccessful request", func(t *testing.T) { validBody := `""` - srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 1*time.Second) + srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 1*time.Second, "") defer srv.Close() defer store.Close() @@ -170,15 +199,49 @@ func TestCheck(t *testing.T) { resp, err := client.Check(ctx, request, store) assert.Nilf(t, resp, "response should be nil") assert.NotNilf(t, err, "should get error") - assert.Contains(t, err.Error(), "Error occured during authorization check") + assert.Contains(t, err.Error(), "Error occured during authorization checkdfdf") if v, ok := err.(errutils.HttpStatusCode); ok { assert.Equal(t, v.Code(), http.StatusInternalServerError) } }) + t.Run("unsuccessful request - check retry count", func(t *testing.T) { + calledTimesFile, err := createCalledTimesFile() + assert.Nilf(t, err, "Should not have got error") + + validBody := `""` + srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 1*time.Second, calledTimesFile) + defer srv.Close() + defer store.Close() + + request := &authzv1.SubjectAccessReviewSpec{ + User: "beta@bing.com", + ResourceAttributes: &authzv1.ResourceAttributes{ + Namespace: "dev", Group: "", Resource: "pods", + Subresource: "status", Version: "v1", Name: "test", Verb: "delete", + }, Extra: map[string]authzv1.ExtraValue{"oid": {"00000000-0000-0000-0000-000000000000"}}, + } + + ctx := context.Background() + resp, err := client.Check(ctx, request, store) + assert.Nilf(t, resp, "response should be nil") + assert.NotNilf(t, err, "should get error") + assert.Contains(t, err.Error(), "Error occured during authorization checkdfdf") + if v, ok := err.(errutils.HttpStatusCode); ok { + assert.Equal(t, v.Code(), http.StatusInternalServerError) + } + + content, _ := os.ReadFile(calledTimesFile) + calledTimes, _ := strconv.Atoi(string(content)) + assert.Equal(t, httpClientRetryCount+1, calledTimes, fmt.Sprintf("The server should be called %d times", httpClientRetryCount+1)) + + err = deleteCalledTimesFile(calledTimesFile) + assert.Nilf(t, err, "Should not have got error") + }) + t.Run("context timeout request", func(t *testing.T) { validBody := `""` - srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 25*time.Second) + srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 25*time.Second, "") defer srv.Close() defer store.Close() @@ -194,7 +257,7 @@ func TestCheck(t *testing.T) { resp, err := client.Check(ctx, request, store) assert.Nilf(t, resp, "response should be nil") assert.NotNilf(t, err, "should get error") - assert.Contains(t, err.Error(), "Checkaccess requests have timed out") + assert.Contains(t, err.Error(), "context deadline exceeded") if v, ok := err.(errutils.HttpStatusCode); ok { assert.Equal(t, v.Code(), http.StatusInternalServerError) } diff --git a/authz/providers/azure/rbac/rbac.go b/authz/providers/azure/rbac/rbac.go index 50227096d..88c726d96 100644 --- a/authz/providers/azure/rbac/rbac.go +++ b/authz/providers/azure/rbac/rbac.go @@ -86,6 +86,7 @@ type AccessInfo struct { skipAuthzForNonAADUsers bool allowNonResDiscoveryPathAccess bool useNamespaceResourceScopeFormat bool + httpClientRetryCount int lock sync.RWMutex } @@ -155,7 +156,7 @@ func getClusterType(clsType string) string { } } -func newAccessInfo(tokenProvider graph.TokenProvider, rbacURL *url.URL, opts authzOpts.Options) (*AccessInfo, error) { +func newAccessInfo(tokenProvider graph.TokenProvider, rbacURL *url.URL, opts authzOpts.Options, authopts auth.Options) (*AccessInfo, error) { u := &AccessInfo{ client: httpclient.DefaultHTTPClient, headers: http.Header{ @@ -169,6 +170,7 @@ func newAccessInfo(tokenProvider graph.TokenProvider, rbacURL *url.URL, opts aut skipAuthzForNonAADUsers: opts.SkipAuthzForNonAADUsers, allowNonResDiscoveryPathAccess: opts.AllowNonResDiscoveryPathAccess, useNamespaceResourceScopeFormat: opts.UseNamespaceResourceScopeFormat, + httpClientRetryCount: authopts.HttpClientRetryCount, } u.skipCheck = make(map[string]void, len(opts.SkipAuthzCheck)) @@ -207,7 +209,7 @@ func New(opts authzOpts.Options, authopts auth.Options, authzInfo *AuthzInfo) (* tokenProvider = graph.NewAKSTokenProvider(opts.AKSAuthzTokenURL, authopts.TenantID) } - return newAccessInfo(tokenProvider, rbacURL, opts) + return newAccessInfo(tokenProvider, rbacURL, opts, authopts) } func (a *AccessInfo) RefreshToken(ctx context.Context) error { @@ -328,6 +330,7 @@ func (a *AccessInfo) CheckAccess(request *authzv1.SubjectAccessReviewSpec) (*aut // create a request id for every checkaccess request requestUUID := uuid.New() reqContext := context.WithValue(egCtx, correlationRequestIDKey(correlationRequestIDHeader), []string{requestUUID.String()}) + reqContext = azureutils.WithRetryableHttpClient(reqContext, a.httpClientRetryCount) err := a.sendCheckAccessRequest(reqContext, checkAccessUsername, checkAccessURL, body, ch) if err != nil { code := http.StatusInternalServerError @@ -397,7 +400,8 @@ func (a *AccessInfo) sendCheckAccessRequest(ctx context.Context, checkAccessUser // start time to calculate checkaccess duration start := time.Now() klog.V(5).Infof("Sending checkAccess request with correlationID: %s", correlationID[0]) - resp, err := a.client.Do(req) + client := azureutils.LoadClientWithContext(ctx, a.client) + resp, err := client.Do(req) duration := time.Since(start).Seconds() if err != nil { checkAccessTotal.WithLabelValues(internalServerCode).Inc() diff --git a/util/azure/utils.go b/util/azure/utils.go index f0de92042..9ce878be9 100644 --- a/util/azure/utils.go +++ b/util/azure/utils.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "log" "net/http" "path" "strconv" @@ -32,9 +33,11 @@ import ( "go.kubeguard.dev/guard/util/httpclient" "github.com/Azure/go-autorest/autorest/azure" + "github.com/hashicorp/go-retryablehttp" jsoniter "github.com/json-iterator/go" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" + "golang.org/x/oauth2" v "gomodules.xyz/x/version" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" @@ -502,6 +505,43 @@ func fetchDataActionsList(ctx context.Context) ([]Operation, error) { return finalOperations, nil } +// MakeRetryableHttpClient creates an HTTP client which attempts the request +// (1 + retryCount) times and has a 3 second timeout per attempt. +func MakeRetryableHttpClient(retryCount int) retryablehttp.Client { + // Copy the default HTTP client so we can set a timeout. + // (It uses the same transport since the pointer gets copied) + httpClient := *httpclient.DefaultHTTPClient + httpClient.Timeout = 3 * time.Second + + // Attempt the request up to 3 times + return retryablehttp.Client{ + HTTPClient: &httpClient, + RetryWaitMin: 500 * time.Millisecond, + RetryWaitMax: 2 * time.Second, + RetryMax: retryCount, // initial + retryCount retries = (1 + retryCount) attempts + CheckRetry: retryablehttp.DefaultRetryPolicy, + Backoff: retryablehttp.DefaultBackoff, + Logger: log.Default(), + } +} + +// WithRetryableHttpClient sets the oauth2.HTTPClient key of the context to an +// *http.Client made from makeRetryableHttpClient. +// Some of the libraries we use will take the client out of the context via +// oauth2.HTTPClient and use it, so this way we can add retries to external code. +func WithRetryableHttpClient(ctx context.Context, retryCount int) context.Context { + retryClient := MakeRetryableHttpClient(retryCount) + return context.WithValue(ctx, oauth2.HTTPClient, retryClient.StandardClient()) +} + +func LoadClientWithContext(ctx context.Context, defaultClient *http.Client) *http.Client { + if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { + return c + } + + return defaultClient +} + func init() { prometheus.MustRegister(DiscoverResourcesTotalDuration, discoverResourcesAzureCallDuration, discoverResourcesApiServerCallDuration, counterDiscoverResources, counterGetOperationsResources) }