Skip to content

Commit

Permalink
Use retryable HTTP client in Azure authz provider
Browse files Browse the repository at this point in the history
Signed-off-by: Bin Xia <binxi@microsoft.com>
  • Loading branch information
bingosummer committed Apr 1, 2024
1 parent 286709a commit da82048
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 53 deletions.
40 changes: 4 additions & 36 deletions auth/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,19 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync"
"time"

"go.kubeguard.dev/guard/auth"
"go.kubeguard.dev/guard/auth/providers/azure/graph"
"go.kubeguard.dev/guard/util/httpclient"
azureutils "go.kubeguard.dev/guard/util/azure"

"github.com/Azure/go-autorest/autorest/azure"
"github.com/coreos/go-oidc"
"github.com/golang-jwt/jwt/v4"
"github.com/hashicorp/go-retryablehttp"
"github.com/pkg/errors"
"golang.org/x/oauth2"
authv1 "k8s.io/api/authentication/v1"
"k8s.io/klog/v2"
)
Expand Down Expand Up @@ -120,7 +117,7 @@ func getOIDCIssuerProvider(issuerURL string, issuerGetRetryCount int) (*oidc.Pro

// NOTE: we start a root context here to allow background remote key set refresh
ctx := context.Background()
ctx = withRetryableHttpClient(ctx, issuerGetRetryCount)
ctx = azureutils.WithRetryableHttpClient(ctx, issuerGetRetryCount)
provider, err := oidc.NewProvider(ctx, issuerURL)
if err != nil {
// failed in this attempt, let other attempts retry
Expand Down Expand Up @@ -180,35 +177,6 @@ func New(ctx context.Context, opts Options) (auth.Interface, error) {
return c, nil
}

// makeRetryableHttpClient builds a retrying HTTP client.
//
// Each request is attempted at most (1 + retryCount) times: the initial
// attempt plus up to retryCount retries. Every individual attempt is bounded
// by a 3 second timeout, and the wait between attempts is kept between
// 500ms and 2s using the retryablehttp default retry policy and backoff.
// Retry activity is logged through the standard library's default logger.
func makeRetryableHttpClient(retryCount int) retryablehttp.Client {
	const (
		perAttemptTimeout = 3 * time.Second
		minRetryWait      = 500 * time.Millisecond
		maxRetryWait      = 2 * time.Second
	)

	// Work on a shallow copy of the shared default client so that setting the
	// timeout does not modify the shared value. The copy still points at the
	// same transport because only the struct itself is duplicated.
	base := *httpclient.DefaultHTTPClient
	base.Timeout = perAttemptTimeout

	// Returned as a composite literal so the client struct is never copied
	// from an existing variable.
	return retryablehttp.Client{
		HTTPClient:   &base,
		Logger:       log.Default(),
		RetryMax:     retryCount, // retries only; the first attempt is not counted
		RetryWaitMin: minRetryWait,
		RetryWaitMax: maxRetryWait,
		CheckRetry:   retryablehttp.DefaultRetryPolicy,
		Backoff:      retryablehttp.DefaultBackoff,
	}
}

// withRetryableHttpClient returns a child of ctx that carries a retrying
// *http.Client under the oauth2.HTTPClient context key.
//
// The client comes from makeRetryableHttpClient(retryCount). Libraries that
// look up their HTTP client via oauth2.HTTPClient will pick it up from the
// returned context, which is how retries get added to code we do not own.
func withRetryableHttpClient(ctx context.Context, retryCount int) context.Context {
	rc := makeRetryableHttpClient(retryCount)
	// StandardClient exposes the retrying client as a plain *http.Client.
	stdClient := rc.StandardClient()
	return context.WithValue(ctx, oauth2.HTTPClient, stdClient)
}

type metadataJSON struct {
Issuer string `json:"issuer"`
MsgraphHost string `json:"msgraph_host"`
Expand All @@ -217,7 +185,7 @@ type metadataJSON struct {
// https://docs.microsoft.com/en-us/azure/active-directory/develop/howto-convert-app-to-be-multi-tenant
func getMetadata(ctx context.Context, aadEndpoint, tenantID string, retryCount int) (*metadataJSON, error) {
metadataURL := aadEndpoint + tenantID + "/.well-known/openid-configuration"
retryClient := makeRetryableHttpClient(retryCount)
retryClient := azureutils.MakeRetryableHttpClient(retryCount)

request, err := retryablehttp.NewRequest("GET", metadataURL, nil)
if err != nil {
Expand Down Expand Up @@ -261,7 +229,7 @@ func (s Authenticator) Check(ctx context.Context, token string) (*authv1.UserInf
}
}

ctx = withRetryableHttpClient(ctx, s.HttpClientRetryCount)
ctx = azureutils.WithRetryableHttpClient(ctx, s.HttpClientRetryCount)
idToken, err := s.verifier.Verify(ctx, token)
if err != nil {
if klog.V(7).Enabled() {
Expand Down
10 changes: 8 additions & 2 deletions authz/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"go.kubeguard.dev/guard/authz"
authzOpts "go.kubeguard.dev/guard/authz/providers/azure/options"
"go.kubeguard.dev/guard/authz/providers/azure/rbac"
azureutils "go.kubeguard.dev/guard/util/azure"
errutils "go.kubeguard.dev/guard/util/error"

"github.com/Azure/go-autorest/autorest/azure"
Expand All @@ -49,7 +50,8 @@ func init() {
}

type Authorizer struct {
rbacClient *rbac.AccessInfo
rbacClient *rbac.AccessInfo
httpClientRetryCount int
}

func New(opts authzOpts.Options, authopts auth.Options) (authz.Interface, error) {
Expand All @@ -64,7 +66,9 @@ func New(opts authzOpts.Options, authopts auth.Options) (authz.Interface, error)
}

func newAuthzClient(opts authzOpts.Options, authopts auth.Options) (authz.Interface, error) {
c := &Authorizer{}
c := &Authorizer{
httpClientRetryCount: authopts.HttpClientRetryCount,
}

authzInfoVal, err := getAuthzInfo(authopts.Environment)
if err != nil {
Expand Down Expand Up @@ -120,6 +124,8 @@ func (s Authorizer) Check(ctx context.Context, request *authzv1.SubjectAccessRev
return &authzv1.SubjectAccessReviewStatus{Allowed: true, Reason: rbac.AccessAllowedVerdict}, nil
}

ctx = azureutils.WithRetryableHttpClient(ctx, s.httpClientRetryCount)

if s.rbacClient.IsTokenExpired() {
if err := s.rbacClient.RefreshToken(ctx); err != nil {
return nil, errutils.WithCode(err, http.StatusInternalServerError)
Expand Down
87 changes: 75 additions & 12 deletions authz/providers/azure/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ package azure

import (
"context"
"fmt"
"io/fs"
"net"
"net/http"
"net/http/httptest"
"os"
"strconv"
"testing"
"time"

Expand All @@ -32,12 +36,14 @@ import (
errutils "go.kubeguard.dev/guard/util/error"

"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
authzv1 "k8s.io/api/authorization/v1"
)

const (
loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}`
loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}`
httpClientRetryCount = 2
)

func clientSetup(serverUrl, mode string) (*Authorizer, error) {
Expand All @@ -52,9 +58,10 @@ func clientSetup(serverUrl, mode string) (*Authorizer, error) {
}

authOpts := auth.Options{
ClientID: "client_id",
ClientSecret: "client_secret",
TenantID: "tenant_id",
ClientID: "client_id",
ClientSecret: "client_secret",
TenantID: "tenant_id",
HttpClientRetryCount: httpClientRetryCount,
}

authzInfo := rbac.AuthzInfo{
Expand All @@ -70,7 +77,7 @@ func clientSetup(serverUrl, mode string) (*Authorizer, error) {
return c, nil
}

func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStatus int, sleepFor time.Duration) (*httptest.Server, error) {
func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStatus int, sleepFor time.Duration, calledTimesFile string) (*httptest.Server, error) {
listener, err := net.Listen("tcp", "127.0.0.1:")
if err != nil {
return nil, err
Expand All @@ -85,6 +92,9 @@ func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStat

m.Post("/arm/*", func(w http.ResponseWriter, r *http.Request) {
time.Sleep(sleepFor)
if calledTimesFile != "" {
_ = incCalledTimes(calledTimesFile)
}
w.WriteHeader(checkaccessStatus)
_, _ = w.Write([]byte(checkaccessResp))
})
Expand All @@ -98,8 +108,8 @@ func serverSetup(loginResp, checkaccessResp string, loginStatus, checkaccessStat
return srv, nil
}

func getServerAndClient(t *testing.T, loginResp, checkaccessResp string, checkaccessStatus int, sleepFor time.Duration) (*httptest.Server, *Authorizer, authz.Store) {
srv, err := serverSetup(loginResp, checkaccessResp, http.StatusOK, checkaccessStatus, sleepFor)
func getServerAndClient(t *testing.T, loginResp, checkaccessResp string, checkaccessStatus int, sleepFor time.Duration, calledTimesFile string) (*httptest.Server, *Authorizer, authz.Store) { // nolint: unparam
srv, err := serverSetup(loginResp, checkaccessResp, http.StatusOK, checkaccessStatus, sleepFor, calledTimesFile)
if err != nil {
t.Fatalf("Error when creating server, reason: %v", err)
}
Expand All @@ -123,13 +133,32 @@ func getServerAndClient(t *testing.T, loginResp, checkaccessResp string, checkac
return srv, client, dataStore
}

// createCalledTimesFile creates a call-counter file initialised to "0" and
// returns its name.
//
// The file is named after a freshly generated UUID and, because the name is
// relative, is created in the current working directory. Callers are
// responsible for removing it (see deleteCalledTimesFile).
//
// On failure an empty name and the error from os.WriteFile are returned.
func createCalledTimesFile() (string, error) {
	calledTimesFile := uuid.New().String()

	// The file must be readable and writable by its owner. fs.ModeTemporary
	// is a file-type bit with no permission bits set, so on Unix-like systems
	// it creates a mode-0000 file that a non-root user can neither read back
	// nor rewrite, leaving the counter stuck at 0.
	if err := os.WriteFile(calledTimesFile, []byte(strconv.Itoa(0)), fs.FileMode(0o600)); err != nil {
		return "", err
	}
	return calledTimesFile, nil
}

// incCalledTimes increments the decimal counter stored in calledTimesFile by
// one and writes the new value back to the same file.
//
// It returns an error, leaving the file untouched, when the file cannot be
// read or when its content is not a valid integer; otherwise it returns the
// result of writing the updated value. The read-modify-write sequence is not
// locked, so concurrent callers may lose updates.
func incCalledTimes(calledTimesFile string) error {
	content, err := os.ReadFile(calledTimesFile)
	if err != nil {
		return fmt.Errorf("reading called-times file %q: %w", calledTimesFile, err)
	}

	calledTimes, err := strconv.Atoi(string(content))
	if err != nil {
		return fmt.Errorf("parsing called-times file %q: %w", calledTimesFile, err)
	}

	// Owner read/write permissions: fs.ModeTemporary has no permission bits,
	// which would produce an inaccessible file if it had to be created here.
	return os.WriteFile(calledTimesFile, []byte(strconv.Itoa(calledTimes+1)), fs.FileMode(0o600))
}

// deleteCalledTimesFile removes the call-counter file with the given name.
// It returns whatever error os.Remove reports, for example when the file
// does not exist.
func deleteCalledTimesFile(calledTimesFile string) (err error) {
	err = os.Remove(calledTimesFile)
	return err
}

func TestCheck(t *testing.T) {
t.Run("successful request", func(t *testing.T) {
validBody := `[{"accessDecision":"Allowed",
"actionId":"Microsoft.Kubernetes/connectedClusters/pods/delete",
"isDataAction":true,"roleAssignment":null,"denyAssignment":null,"timeToLiveInMs":300000}]`

srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusOK, 1*time.Second)
srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusOK, 1*time.Second, "")
defer srv.Close()
defer store.Close()

Expand All @@ -154,7 +183,7 @@ func TestCheck(t *testing.T) {

t.Run("unsuccessful request", func(t *testing.T) {
validBody := `""`
srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 1*time.Second)
srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 1*time.Second, "")
defer srv.Close()
defer store.Close()

Expand All @@ -170,15 +199,49 @@ func TestCheck(t *testing.T) {
resp, err := client.Check(ctx, request, store)
assert.Nilf(t, resp, "response should be nil")
assert.NotNilf(t, err, "should get error")
assert.Contains(t, err.Error(), "Error occured during authorization check")
assert.Contains(t, err.Error(), "Error occured during authorization checkdfdf")
if v, ok := err.(errutils.HttpStatusCode); ok {
assert.Equal(t, v.Code(), http.StatusInternalServerError)
}
})

t.Run("unsuccessful request - check retry count", func(t *testing.T) {
calledTimesFile, err := createCalledTimesFile()
assert.Nilf(t, err, "Should not have got error")

validBody := `""`
srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 1*time.Second, calledTimesFile)
defer srv.Close()
defer store.Close()

request := &authzv1.SubjectAccessReviewSpec{
User: "beta@bing.com",
ResourceAttributes: &authzv1.ResourceAttributes{
Namespace: "dev", Group: "", Resource: "pods",
Subresource: "status", Version: "v1", Name: "test", Verb: "delete",
}, Extra: map[string]authzv1.ExtraValue{"oid": {"00000000-0000-0000-0000-000000000000"}},
}

ctx := context.Background()
resp, err := client.Check(ctx, request, store)
assert.Nilf(t, resp, "response should be nil")
assert.NotNilf(t, err, "should get error")
assert.Contains(t, err.Error(), "Error occured during authorization checkdfdf")
if v, ok := err.(errutils.HttpStatusCode); ok {
assert.Equal(t, v.Code(), http.StatusInternalServerError)
}

content, _ := os.ReadFile(calledTimesFile)
calledTimes, _ := strconv.Atoi(string(content))
assert.Equal(t, httpClientRetryCount+1, calledTimes, fmt.Sprintf("The server should be called %d times", httpClientRetryCount+1))

err = deleteCalledTimesFile(calledTimesFile)
assert.Nilf(t, err, "Should not have got error")
})

t.Run("context timeout request", func(t *testing.T) {
validBody := `""`
srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 25*time.Second)
srv, client, store := getServerAndClient(t, loginResp, validBody, http.StatusInternalServerError, 25*time.Second, "")
defer srv.Close()
defer store.Close()

Expand All @@ -194,7 +257,7 @@ func TestCheck(t *testing.T) {
resp, err := client.Check(ctx, request, store)
assert.Nilf(t, resp, "response should be nil")
assert.NotNilf(t, err, "should get error")
assert.Contains(t, err.Error(), "Checkaccess requests have timed out")
assert.Contains(t, err.Error(), "context deadline exceeded")
if v, ok := err.(errutils.HttpStatusCode); ok {
assert.Equal(t, v.Code(), http.StatusInternalServerError)
}
Expand Down
10 changes: 7 additions & 3 deletions authz/providers/azure/rbac/rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ type AccessInfo struct {
skipAuthzForNonAADUsers bool
allowNonResDiscoveryPathAccess bool
useNamespaceResourceScopeFormat bool
httpClientRetryCount int
lock sync.RWMutex
}

Expand Down Expand Up @@ -155,7 +156,7 @@ func getClusterType(clsType string) string {
}
}

func newAccessInfo(tokenProvider graph.TokenProvider, rbacURL *url.URL, opts authzOpts.Options) (*AccessInfo, error) {
func newAccessInfo(tokenProvider graph.TokenProvider, rbacURL *url.URL, opts authzOpts.Options, authopts auth.Options) (*AccessInfo, error) {
u := &AccessInfo{
client: httpclient.DefaultHTTPClient,
headers: http.Header{
Expand All @@ -169,6 +170,7 @@ func newAccessInfo(tokenProvider graph.TokenProvider, rbacURL *url.URL, opts aut
skipAuthzForNonAADUsers: opts.SkipAuthzForNonAADUsers,
allowNonResDiscoveryPathAccess: opts.AllowNonResDiscoveryPathAccess,
useNamespaceResourceScopeFormat: opts.UseNamespaceResourceScopeFormat,
httpClientRetryCount: authopts.HttpClientRetryCount,
}

u.skipCheck = make(map[string]void, len(opts.SkipAuthzCheck))
Expand Down Expand Up @@ -207,7 +209,7 @@ func New(opts authzOpts.Options, authopts auth.Options, authzInfo *AuthzInfo) (*
tokenProvider = graph.NewAKSTokenProvider(opts.AKSAuthzTokenURL, authopts.TenantID)
}

return newAccessInfo(tokenProvider, rbacURL, opts)
return newAccessInfo(tokenProvider, rbacURL, opts, authopts)
}

func (a *AccessInfo) RefreshToken(ctx context.Context) error {
Expand Down Expand Up @@ -328,6 +330,7 @@ func (a *AccessInfo) CheckAccess(request *authzv1.SubjectAccessReviewSpec) (*aut
// create a request id for every checkaccess request
requestUUID := uuid.New()
reqContext := context.WithValue(egCtx, correlationRequestIDKey(correlationRequestIDHeader), []string{requestUUID.String()})
reqContext = azureutils.WithRetryableHttpClient(reqContext, a.httpClientRetryCount)
err := a.sendCheckAccessRequest(reqContext, checkAccessUsername, checkAccessURL, body, ch)
if err != nil {
code := http.StatusInternalServerError
Expand Down Expand Up @@ -397,7 +400,8 @@ func (a *AccessInfo) sendCheckAccessRequest(ctx context.Context, checkAccessUser
// start time to calculate checkaccess duration
start := time.Now()
klog.V(5).Infof("Sending checkAccess request with correlationID: %s", correlationID[0])
resp, err := a.client.Do(req)
client := azureutils.LoadClientWithContext(ctx, a.client)
resp, err := client.Do(req)
duration := time.Since(start).Seconds()
if err != nil {
checkAccessTotal.WithLabelValues(internalServerCode).Inc()
Expand Down
Loading

0 comments on commit da82048

Please sign in to comment.