diff --git a/auth/providers/azure/azure.go b/auth/providers/azure/azure.go index 536370c8c..124b761d9 100644 --- a/auth/providers/azure/azure.go +++ b/auth/providers/azure/azure.go @@ -21,20 +21,17 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" "strings" - "time" "go.kubeguard.dev/guard/auth" "go.kubeguard.dev/guard/auth/providers/azure/graph" - "go.kubeguard.dev/guard/util/httpclient" + azureutils "go.kubeguard.dev/guard/util/azure" "github.com/Azure/go-autorest/autorest/azure" "github.com/coreos/go-oidc" "github.com/hashicorp/go-retryablehttp" "github.com/pkg/errors" - "golang.org/x/oauth2" authv1 "k8s.io/api/authentication/v1" "k8s.io/klog/v2" ) @@ -90,7 +87,7 @@ func New(ctx context.Context, opts Options) (auth.Interface, error) { klog.V(3).Infof("Using issuer url: %v", authInfoVal.Issuer) - ctx = withRetryableHttpClient(ctx, c.HttpClientRetryCount) + ctx = azureutils.WithRetryableHttpClient(ctx, c.HttpClientRetryCount) provider, err := oidc.NewProvider(ctx, authInfoVal.Issuer) if err != nil { return nil, errors.Wrap(err, "failed to create provider for azure") @@ -117,35 +114,6 @@ func New(ctx context.Context, opts Options) (auth.Interface, error) { return c, nil } -// makeRetryableHttpClient creates an HTTP client which attempts the request -// (1 + retryCount) times and has a 3 second timeout per attempt. -func makeRetryableHttpClient(retryCount int) retryablehttp.Client { - // Copy the default HTTP client so we can set a timeout. - // (It uses the same transport since the pointer gets copied) - httpClient := *httpclient.DefaultHTTPClient - httpClient.Timeout = 3 * time.Second - - // Attempt the request up to 3 times - return retryablehttp.Client{ - HTTPClient: &httpClient, - RetryWaitMin: 500 * time.Millisecond, - RetryWaitMax: 2 * time.Second, - RetryMax: retryCount, // initial + retryCount retries = (1 + retryCount) attempts - CheckRetry: retryablehttp.DefaultRetryPolicy, - Backoff: retryablehttp.DefaultBackoff, - Logger: log.Default(), - } -} - -// withRetryableHttpClient sets the oauth2.HTTPClient key of the context to an -// *http.Client made from makeRetryableHttpClient. -// Some of the libraries we use will take the client out of the context via -// oauth2.HTTPClient and use it, so this way we can add retries to external code. -func withRetryableHttpClient(ctx context.Context, retryCount int) context.Context { - retryClient := makeRetryableHttpClient(retryCount) - return context.WithValue(ctx, oauth2.HTTPClient, retryClient.StandardClient()) -} - type metadataJSON struct { Issuer string `json:"issuer"` MsgraphHost string `json:"msgraph_host"` @@ -154,7 +122,7 @@ type metadataJSON struct { // https://docs.microsoft.com/en-us/azure/active-directory/develop/howto-convert-app-to-be-multi-tenant func getMetadata(ctx context.Context, aadEndpoint, tenantID string, retryCount int) (*metadataJSON, error) { metadataURL := aadEndpoint + tenantID + "/.well-known/openid-configuration" - retryClient := makeRetryableHttpClient(retryCount) + retryClient := azureutils.MakeRetryableHttpClient(retryCount) request, err := retryablehttp.NewRequest("GET", metadataURL, nil) if err != nil { @@ -198,7 +166,7 @@ func (s Authenticator) Check(ctx context.Context, token string) (*authv1.UserInf } } - ctx = withRetryableHttpClient(ctx, s.HttpClientRetryCount) + ctx = azureutils.WithRetryableHttpClient(ctx, s.HttpClientRetryCount) idToken, err := s.verifier.Verify(ctx, token) if err != nil { return nil, errors.Wrap(err, "failed to verify token for azure") diff --git a/authz/providers/azure/azure.go b/authz/providers/azure/azure.go index dbfa4bf7b..e6d72bac5 100644 --- a/authz/providers/azure/azure.go +++ b/authz/providers/azure/azure.go @@ -26,6 +26,7 @@ import ( "go.kubeguard.dev/guard/authz" authzOpts "go.kubeguard.dev/guard/authz/providers/azure/options" "go.kubeguard.dev/guard/authz/providers/azure/rbac" + azureutils "go.kubeguard.dev/guard/util/azure" errutils "go.kubeguard.dev/guard/util/error" "github.com/Azure/go-autorest/autorest/azure" @@ -49,7 +50,8 @@ func init() { } type Authorizer struct { - rbacClient *rbac.AccessInfo + rbacClient *rbac.AccessInfo + httpClientRetryCount int } func New(opts authzOpts.Options, authopts auth.Options) (authz.Interface, error) { @@ -64,7 +66,9 @@ func New(opts authzOpts.Options, authopts auth.Options) (authz.Interface, error) } func newAuthzClient(opts authzOpts.Options, authopts auth.Options) (authz.Interface, error) { - c := &Authorizer{} + c := &Authorizer{ + httpClientRetryCount: authopts.HttpClientRetryCount, + } authzInfoVal, err := getAuthzInfo(authopts.Environment) if err != nil { @@ -120,6 +124,8 @@ func (s Authorizer) Check(ctx context.Context, request *authzv1.SubjectAccessRev return &authzv1.SubjectAccessReviewStatus{Allowed: true, Reason: rbac.AccessAllowedVerdict}, nil } + ctx = azureutils.WithRetryableHttpClient(ctx, s.httpClientRetryCount) + if s.rbacClient.IsTokenExpired() { if err := s.rbacClient.RefreshToken(ctx); err != nil { return nil, errutils.WithCode(err, http.StatusInternalServerError) diff --git a/authz/providers/azure/azure_test.go b/authz/providers/azure/azure_test.go index 7dc2bb4d6..7c08f8671 100644 --- a/authz/providers/azure/azure_test.go +++ b/authz/providers/azure/azure_test.go @@ -37,7 +37,8 @@ import ( ) const ( - loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}` + loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}` + httpClientRetryCount = 2 ) func clientSetup(serverUrl, mode string) (*Authorizer, error) { @@ -52,9 +53,10 @@ func clientSetup(serverUrl, mode string) (*Authorizer, error) { } authOpts := auth.Options{ - ClientID: "client_id", - ClientSecret: "client_secret", - TenantID: "tenant_id", + ClientID: "client_id", + ClientSecret: "client_secret", + TenantID: "tenant_id", + HttpClientRetryCount: httpClientRetryCount, } authzInfo := rbac.AuthzInfo{ diff --git a/authz/providers/azure/rbac/rbac.go b/authz/providers/azure/rbac/rbac.go index 8f771f625..af5608baf 100644 --- a/authz/providers/azure/rbac/rbac.go +++ b/authz/providers/azure/rbac/rbac.go @@ -86,6 +86,7 @@ type AccessInfo struct { skipAuthzForNonAADUsers bool allowNonResDiscoveryPathAccess bool useNamespaceResourceScopeFormat bool + httpClientRetryCount int lock sync.RWMutex } @@ -155,7 +156,7 @@ func getClusterType(clsType string) string { } } -func newAccessInfo(tokenProvider graph.TokenProvider, rbacURL *url.URL, opts authzOpts.Options) (*AccessInfo, error) { +func newAccessInfo(tokenProvider graph.TokenProvider, rbacURL *url.URL, opts authzOpts.Options, authopts auth.Options) (*AccessInfo, error) { u := &AccessInfo{ client: httpclient.DefaultHTTPClient, headers: http.Header{ @@ -169,6 +170,7 @@ func newAccessInfo(tokenProvider graph.TokenProvider, rbacURL *url.URL, opts aut skipAuthzForNonAADUsers: opts.SkipAuthzForNonAADUsers, allowNonResDiscoveryPathAccess: opts.AllowNonResDiscoveryPathAccess, useNamespaceResourceScopeFormat: opts.UseNamespaceResourceScopeFormat, + httpClientRetryCount: authopts.HttpClientRetryCount, } u.skipCheck = make(map[string]void, len(opts.SkipAuthzCheck)) @@ -207,7 +209,7 @@ func New(opts authzOpts.Options, authopts auth.Options, authzInfo *AuthzInfo) (* tokenProvider = graph.NewAKSTokenProvider(opts.AKSAuthzTokenURL, authopts.TenantID) } - return newAccessInfo(tokenProvider, rbacURL, opts) + return newAccessInfo(tokenProvider, rbacURL, opts, authopts) } func (a *AccessInfo) RefreshToken(ctx context.Context) error { @@ -326,6 +328,7 @@ func (a *AccessInfo) CheckAccess(request *authzv1.SubjectAccessReviewSpec) (*aut // create a request id for every checkaccess request requestUUID := uuid.New() reqContext := context.WithValue(egCtx, correlationRequestIDKey(correlationRequestIDHeader), []string{requestUUID.String()}) + reqContext = azureutils.WithRetryableHttpClient(reqContext, a.httpClientRetryCount) err := a.sendCheckAccessRequest(reqContext, checkAccessURL, body, ch) if err != nil { code := http.StatusInternalServerError diff --git a/util/azure/utils.go b/util/azure/utils.go index f0de92042..a1c13b49d 100644 --- a/util/azure/utils.go +++ b/util/azure/utils.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "log" "net/http" "path" "strconv" @@ -32,9 +33,11 @@ import ( "go.kubeguard.dev/guard/util/httpclient" "github.com/Azure/go-autorest/autorest/azure" + "github.com/hashicorp/go-retryablehttp" jsoniter "github.com/json-iterator/go" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" + "golang.org/x/oauth2" v "gomodules.xyz/x/version" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" @@ -502,6 +505,35 @@ func fetchDataActionsList(ctx context.Context) ([]Operation, error) { return finalOperations, nil } +// MakeRetryableHttpClient creates an HTTP client which attempts the request +// (1 + retryCount) times and has a 3 second timeout per attempt. +func MakeRetryableHttpClient(retryCount int) retryablehttp.Client { + // Copy the default HTTP client so we can set a timeout. + // (It uses the same transport since the pointer gets copied) + httpClient := *httpclient.DefaultHTTPClient + httpClient.Timeout = 3 * time.Second + + // Attempt the request up to 3 times + return retryablehttp.Client{ + HTTPClient: &httpClient, + RetryWaitMin: 500 * time.Millisecond, + RetryWaitMax: 2 * time.Second, + RetryMax: retryCount, // initial + retryCount retries = (1 + retryCount) attempts + CheckRetry: retryablehttp.DefaultRetryPolicy, + Backoff: retryablehttp.DefaultBackoff, + Logger: log.Default(), + } +} + +// WithRetryableHttpClient sets the oauth2.HTTPClient key of the context to an +// *http.Client made from makeRetryableHttpClient. +// Some of the libraries we use will take the client out of the context via +// oauth2.HTTPClient and use it, so this way we can add retries to external code. +func WithRetryableHttpClient(ctx context.Context, retryCount int) context.Context { + retryClient := MakeRetryableHttpClient(retryCount) + return context.WithValue(ctx, oauth2.HTTPClient, retryClient.StandardClient()) +} + func init() { prometheus.MustRegister(DiscoverResourcesTotalDuration, discoverResourcesAzureCallDuration, discoverResourcesApiServerCallDuration, counterDiscoverResources, counterGetOperationsResources) }