Skip to content

Commit

Permalink
Use retrable http client in Azure authz provider
Browse files Browse the repository at this point in the history
Signed-off-by: Bin Xia <binxi@microsoft.com>
  • Loading branch information
bingosummer committed Aug 7, 2023
1 parent 68273e5 commit ad88d63
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 44 deletions.
40 changes: 4 additions & 36 deletions auth/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,17 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"

"go.kubeguard.dev/guard/auth"
"go.kubeguard.dev/guard/auth/providers/azure/graph"
"go.kubeguard.dev/guard/util/httpclient"
azureutils "go.kubeguard.dev/guard/util/azure"

"github.com/Azure/go-autorest/autorest/azure"
"github.com/coreos/go-oidc"
"github.com/hashicorp/go-retryablehttp"
"github.com/pkg/errors"
"golang.org/x/oauth2"
authv1 "k8s.io/api/authentication/v1"
"k8s.io/klog/v2"
)
Expand Down Expand Up @@ -90,7 +87,7 @@ func New(ctx context.Context, opts Options) (auth.Interface, error) {

klog.V(3).Infof("Using issuer url: %v", authInfoVal.Issuer)

ctx = withRetryableHttpClient(ctx, c.HttpClientRetryCount)
ctx = azureutils.WithRetryableHttpClient(ctx, c.HttpClientRetryCount)
provider, err := oidc.NewProvider(ctx, authInfoVal.Issuer)
if err != nil {
return nil, errors.Wrap(err, "failed to create provider for azure")
Expand All @@ -117,35 +114,6 @@ func New(ctx context.Context, opts Options) (auth.Interface, error) {
return c, nil
}

// makeRetryableHttpClient creates an HTTP client which attempts the request
// (1 + retryCount) times and has a 3 second timeout per attempt.
func makeRetryableHttpClient(retryCount int) retryablehttp.Client {
// Copy the default HTTP client so we can set a timeout.
// (It uses the same transport since the pointer gets copied)
httpClient := *httpclient.DefaultHTTPClient
httpClient.Timeout = 3 * time.Second

// Attempt the request up to 3 times
return retryablehttp.Client{
HTTPClient: &httpClient,
RetryWaitMin: 500 * time.Millisecond,
RetryWaitMax: 2 * time.Second,
RetryMax: retryCount, // initial + retryCount retries = (1 + retryCount) attempts
CheckRetry: retryablehttp.DefaultRetryPolicy,
Backoff: retryablehttp.DefaultBackoff,
Logger: log.Default(),
}
}

// withRetryableHttpClient sets the oauth2.HTTPClient key of the context to an
// *http.Client made from makeRetryableHttpClient.
// Some of the libraries we use will take the client out of the context via
// oauth2.HTTPClient and use it, so this way we can add retries to external code.
func withRetryableHttpClient(ctx context.Context, retryCount int) context.Context {
retryClient := makeRetryableHttpClient(retryCount)
return context.WithValue(ctx, oauth2.HTTPClient, retryClient.StandardClient())
}

type metadataJSON struct {
Issuer string `json:"issuer"`
MsgraphHost string `json:"msgraph_host"`
Expand All @@ -154,7 +122,7 @@ type metadataJSON struct {
// https://docs.microsoft.com/en-us/azure/active-directory/develop/howto-convert-app-to-be-multi-tenant
func getMetadata(ctx context.Context, aadEndpoint, tenantID string, retryCount int) (*metadataJSON, error) {
metadataURL := aadEndpoint + tenantID + "/.well-known/openid-configuration"
retryClient := makeRetryableHttpClient(retryCount)
retryClient := azureutils.MakeRetryableHttpClient(retryCount)

request, err := retryablehttp.NewRequest("GET", metadataURL, nil)
if err != nil {
Expand Down Expand Up @@ -198,7 +166,7 @@ func (s Authenticator) Check(ctx context.Context, token string) (*authv1.UserInf
}
}

ctx = withRetryableHttpClient(ctx, s.HttpClientRetryCount)
ctx = azureutils.WithRetryableHttpClient(ctx, s.HttpClientRetryCount)
idToken, err := s.verifier.Verify(ctx, token)
if err != nil {
return nil, errors.Wrap(err, "failed to verify token for azure")
Expand Down
10 changes: 8 additions & 2 deletions authz/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"go.kubeguard.dev/guard/authz"
authzOpts "go.kubeguard.dev/guard/authz/providers/azure/options"
"go.kubeguard.dev/guard/authz/providers/azure/rbac"
azureutils "go.kubeguard.dev/guard/util/azure"
errutils "go.kubeguard.dev/guard/util/error"

"github.com/Azure/go-autorest/autorest/azure"
Expand All @@ -49,7 +50,8 @@ func init() {
}

type Authorizer struct {
rbacClient *rbac.AccessInfo
rbacClient *rbac.AccessInfo
HttpClientRetryCount int
}

func New(opts authzOpts.Options, authopts auth.Options) (authz.Interface, error) {
Expand All @@ -64,7 +66,9 @@ func New(opts authzOpts.Options, authopts auth.Options) (authz.Interface, error)
}

func newAuthzClient(opts authzOpts.Options, authopts auth.Options) (authz.Interface, error) {
c := &Authorizer{}
c := &Authorizer{
HttpClientRetryCount: authopts.HttpClientRetryCount,
}

authzInfoVal, err := getAuthzInfo(authopts.Environment)
if err != nil {
Expand Down Expand Up @@ -120,6 +124,8 @@ func (s Authorizer) Check(ctx context.Context, request *authzv1.SubjectAccessRev
return &authzv1.SubjectAccessReviewStatus{Allowed: true, Reason: rbac.AccessAllowedVerdict}, nil
}

ctx = azureutils.WithRetryableHttpClient(ctx, s.HttpClientRetryCount)

if s.rbacClient.IsTokenExpired() {
if err := s.rbacClient.RefreshToken(ctx); err != nil {
return nil, errutils.WithCode(err, http.StatusInternalServerError)
Expand Down
10 changes: 6 additions & 4 deletions authz/providers/azure/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ import (
)

const (
loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}`
loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}`
httpClientRetryCount = 2
)

func clientSetup(serverUrl, mode string) (*Authorizer, error) {
Expand All @@ -52,9 +53,10 @@ func clientSetup(serverUrl, mode string) (*Authorizer, error) {
}

authOpts := auth.Options{
ClientID: "client_id",
ClientSecret: "client_secret",
TenantID: "tenant_id",
ClientID: "client_id",
ClientSecret: "client_secret",
TenantID: "tenant_id",
HttpClientRetryCount: httpClientRetryCount,
}

authzInfo := rbac.AuthzInfo{
Expand Down
7 changes: 5 additions & 2 deletions authz/providers/azure/rbac/rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ type AccessInfo struct {
skipAuthzForNonAADUsers bool
allowNonResDiscoveryPathAccess bool
useNamespaceResourceScopeFormat bool
httpClientRetryCount int
lock sync.RWMutex
}

Expand Down Expand Up @@ -155,7 +156,7 @@ func getClusterType(clsType string) string {
}
}

func newAccessInfo(tokenProvider graph.TokenProvider, rbacURL *url.URL, opts authzOpts.Options) (*AccessInfo, error) {
func newAccessInfo(tokenProvider graph.TokenProvider, rbacURL *url.URL, opts authzOpts.Options, authopts auth.Options) (*AccessInfo, error) {
u := &AccessInfo{
client: httpclient.DefaultHTTPClient,
headers: http.Header{
Expand All @@ -169,6 +170,7 @@ func newAccessInfo(tokenProvider graph.TokenProvider, rbacURL *url.URL, opts aut
skipAuthzForNonAADUsers: opts.SkipAuthzForNonAADUsers,
allowNonResDiscoveryPathAccess: opts.AllowNonResDiscoveryPathAccess,
useNamespaceResourceScopeFormat: opts.UseNamespaceResourceScopeFormat,
httpClientRetryCount: authopts.HttpClientRetryCount,
}

u.skipCheck = make(map[string]void, len(opts.SkipAuthzCheck))
Expand Down Expand Up @@ -207,7 +209,7 @@ func New(opts authzOpts.Options, authopts auth.Options, authzInfo *AuthzInfo) (*
tokenProvider = graph.NewAKSTokenProvider(opts.AKSAuthzTokenURL, authopts.TenantID)
}

return newAccessInfo(tokenProvider, rbacURL, opts)
return newAccessInfo(tokenProvider, rbacURL, opts, authopts)
}

func (a *AccessInfo) RefreshToken(ctx context.Context) error {
Expand Down Expand Up @@ -326,6 +328,7 @@ func (a *AccessInfo) CheckAccess(request *authzv1.SubjectAccessReviewSpec) (*aut
// create a request id for every checkaccess request
requestUUID := uuid.New()
reqContext := context.WithValue(egCtx, correlationRequestIDKey(correlationRequestIDHeader), []string{requestUUID.String()})
reqContext = azureutils.WithRetryableHttpClient(reqContext, a.httpClientRetryCount)
err := a.sendCheckAccessRequest(reqContext, checkAccessURL, body, ch)
if err != nil {
code := http.StatusInternalServerError
Expand Down
32 changes: 32 additions & 0 deletions util/azure/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"
"io"
"log"
"net/http"
"path"
"strconv"
Expand All @@ -32,9 +33,11 @@ import (
"go.kubeguard.dev/guard/util/httpclient"

"github.com/Azure/go-autorest/autorest/azure"
"github.com/hashicorp/go-retryablehttp"
jsoniter "github.com/json-iterator/go"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/oauth2"
v "gomodules.xyz/x/version"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
Expand Down Expand Up @@ -502,6 +505,35 @@ func fetchDataActionsList(ctx context.Context) ([]Operation, error) {
return finalOperations, nil
}

// MakeRetryableHttpClient creates an HTTP client which attempts the request
// (1 + retryCount) times and has a 3 second timeout per attempt.
func MakeRetryableHttpClient(retryCount int) retryablehttp.Client {
// Copy the default HTTP client so we can set a timeout.
// (It uses the same transport since the pointer gets copied)
httpClient := *httpclient.DefaultHTTPClient
httpClient.Timeout = 3 * time.Second

// Attempt the request up to 3 times
return retryablehttp.Client{
HTTPClient: &httpClient,
RetryWaitMin: 500 * time.Millisecond,
RetryWaitMax: 2 * time.Second,
RetryMax: retryCount, // initial + retryCount retries = (1 + retryCount) attempts
CheckRetry: retryablehttp.DefaultRetryPolicy,
Backoff: retryablehttp.DefaultBackoff,
Logger: log.Default(),
}
}

// WithRetryableHttpClient sets the oauth2.HTTPClient key of the context to an
// *http.Client made from makeRetryableHttpClient.
// Some of the libraries we use will take the client out of the context via
// oauth2.HTTPClient and use it, so this way we can add retries to external code.
func WithRetryableHttpClient(ctx context.Context, retryCount int) context.Context {
retryClient := MakeRetryableHttpClient(retryCount)
return context.WithValue(ctx, oauth2.HTTPClient, retryClient.StandardClient())
}

func init() {
prometheus.MustRegister(DiscoverResourcesTotalDuration, discoverResourcesAzureCallDuration, discoverResourcesApiServerCallDuration, counterDiscoverResources, counterGetOperationsResources)
}

0 comments on commit ad88d63

Please sign in to comment.