diff --git a/auth/providers/azure/graph/graph.go b/auth/providers/azure/graph/graph.go index e20df458d..6c9d34549 100644 --- a/auth/providers/azure/graph/graph.go +++ b/auth/providers/azure/graph/graph.go @@ -408,8 +408,8 @@ func newUserInfo(tokenProvider TokenProvider, graphURL *url.URL, useGroupUID boo // New returns a new UserInfo object func New(clientID, clientSecret, tenantID string, useGroupUID bool, aadEndpoint, msgraphHost string) (*UserInfo, error) { - graphEndpoint := "https://" + msgraphHost + "/" - graphURL, _ := url.Parse(graphEndpoint + "v1.0") + graphEndpoint := "https://" + msgraphHost + "/" //nolint:goconst // expected url building + graphURL, _ := url.Parse(graphEndpoint + "v1.0") //nolint:goconst // expected url building tokenProvider := NewClientCredentialTokenProvider(clientID, clientSecret, fmt.Sprintf("%s%s/oauth2/v2.0/token", aadEndpoint, tenantID), @@ -420,8 +420,8 @@ func New(clientID, clientSecret, tenantID string, useGroupUID bool, aadEndpoint, // NewWithOBO returns a new UserInfo object func NewWithOBO(clientID, clientSecret, tenantID string, aadEndpoint, msgraphHost string) (*UserInfo, error) { - graphEndpoint := "https://" + msgraphHost + "/" - graphURL, _ := url.Parse(graphEndpoint + "v1.0") + graphEndpoint := "https://" + msgraphHost + "/" //nolint:goconst // expected url building + graphURL, _ := url.Parse(graphEndpoint + "v1.0") //nolint:goconst // expected url building tokenProvider := NewOBOTokenProvider(clientID, clientSecret, fmt.Sprintf("%s%s/oauth2/v2.0/token", aadEndpoint, tenantID), @@ -432,8 +432,8 @@ func NewWithOBO(clientID, clientSecret, tenantID string, aadEndpoint, msgraphHos // NewWithAKS returns a new UserInfo object used in AKS func NewWithAKS(tokenURL, tenantID, msgraphHost string) (*UserInfo, error) { - graphEndpoint := "https://" + msgraphHost + "/" - graphURL, _ := url.Parse(graphEndpoint + "v1.0") + graphEndpoint := "https://" + msgraphHost + "/" //nolint:goconst // expected url building + graphURL, _ := url.Parse(graphEndpoint + "v1.0") //nolint:goconst // expected url building tokenProvider := NewAKSTokenProvider(tokenURL, tenantID) diff --git a/authz/providers/azure/rbac/checkaccessreqhelper.go b/authz/providers/azure/rbac/checkaccessreqhelper.go index 8ef5c0312..b54a6ca12 100644 --- a/authz/providers/azure/rbac/checkaccessreqhelper.go +++ b/authz/providers/azure/rbac/checkaccessreqhelper.go @@ -47,10 +47,7 @@ const ( PodsResource = "pods" ) -var ( - username string - getStoredOperationsMap = azureutils.DeepCopyOperationsMap -) +var getStoredOperationsMap = azureutils.DeepCopyOperationsMap type SubjectInfoAttributes struct { ObjectId string `json:"ObjectId"` @@ -506,7 +503,6 @@ func prepareCheckAccessRequestBody(req *authzv1.SubjectAccessReviewSpec, cluster return nil, errutils.WithCode(errors.New("oid info not sent from authentication module"), http.StatusBadRequest) } groups := getValidSecurityGroups(req.Groups) - username = req.User actions, err := getDataActions(req, clusterType) if err != nil { return nil, errutils.WithCode(errors.Wrap(err, "Error while creating list of dataactions for check access call"), http.StatusInternalServerError) @@ -542,7 +538,7 @@ func getNameSpaceScope(req *authzv1.SubjectAccessReviewSpec, useNamespaceResourc return false, namespace } -func ConvertCheckAccessResponse(body []byte) (*authzv1.SubjectAccessReviewStatus, error) { +func ConvertCheckAccessResponse(username string, body []byte) (*authzv1.SubjectAccessReviewStatus, error) { var ( response []AuthorizationDecision allowed bool diff --git a/authz/providers/azure/rbac/rbac.go b/authz/providers/azure/rbac/rbac.go index 8f771f625..50227096d 100644 --- a/authz/providers/azure/rbac/rbac.go +++ b/authz/providers/azure/rbac/rbac.go @@ -298,6 +298,8 @@ func (a *AccessInfo) CheckAccess(request *authzv1.SubjectAccessReviewSpec) (*aut return nil, errors.Wrap(err, "error in preparing check access request") } + checkAccessUsername := request.User + checkAccessURL := *a.apiURL // Append the path for azure cluster resource id checkAccessURL.Path = path.Join(checkAccessURL.Path, a.azureResourceId) @@ -326,7 +328,7 @@ func (a *AccessInfo) CheckAccess(request *authzv1.SubjectAccessReviewSpec) (*aut // create a request id for every checkaccess request requestUUID := uuid.New() reqContext := context.WithValue(egCtx, correlationRequestIDKey(correlationRequestIDHeader), []string{requestUUID.String()}) - err := a.sendCheckAccessRequest(reqContext, checkAccessURL, body, ch) + err := a.sendCheckAccessRequest(reqContext, checkAccessUsername, checkAccessURL, body, ch) if err != nil { code := http.StatusInternalServerError if v, ok := err.(errutils.HttpStatusCode); ok { @@ -370,7 +372,7 @@ func (a *AccessInfo) CheckAccess(request *authzv1.SubjectAccessReviewSpec) (*aut return finalStatus, nil } -func (a *AccessInfo) sendCheckAccessRequest(ctx context.Context, checkAccessURL url.URL, checkAccessBody *CheckAccessRequest, ch chan *authzv1.SubjectAccessReviewStatus) error { +func (a *AccessInfo) sendCheckAccessRequest(ctx context.Context, checkAccessUsername string, checkAccessURL url.URL, checkAccessBody *CheckAccessRequest, ch chan *authzv1.SubjectAccessReviewStatus) error { buf := new(bytes.Buffer) if err := json.NewEncoder(buf).Encode(checkAccessBody); err != nil { return errutils.WithCode(errors.Wrap(err, "error encoding check access request"), http.StatusInternalServerError) @@ -447,7 +449,7 @@ func (a *AccessInfo) sendCheckAccessRequest(ctx context.Context, checkAccessURL } // Decode response and prepare k8s response - status, err := ConvertCheckAccessResponse(data) + status, err := ConvertCheckAccessResponse(checkAccessUsername, data) if err != nil { return err } diff --git a/authz/providers/azure/rbac/rbac_test.go b/authz/providers/azure/rbac/rbac_test.go index 851da0528..a5fd7eb5f 100644 --- a/authz/providers/azure/rbac/rbac_test.go +++ b/authz/providers/azure/rbac/rbac_test.go @@ -117,6 +117,45 @@ func TestCheckAccess(t *testing.T) { assert.Nilf(t, response, "response should be nil") assert.NotNilf(t, err, "should get error") }) + + t.Run("concurrent access to CheckAccess method", func(t *testing.T) { + validBody := `[{"accessDecision":"Allowed", + "actionId":"Microsoft.Kubernetes/connectedClusters/pods/delete", + "isDataAction":true,"roleAssignment":null,"denyAssignment":null,"timeToLiveInMs":300000}]` + + ts, u := getAPIServerAndAccessInfo(http.StatusOK, validBody, "aks", "aks-managed-cluster") + defer ts.Close() + + requestTimes := 5 + requests := []*authzv1.SubjectAccessReviewSpec{} + for i := 0; i < requestTimes; i++ { + requests = append( + requests, + &authzv1.SubjectAccessReviewSpec{ + User: fmt.Sprintf("user%d@bing.com", i), + ResourceAttributes: &authzv1.ResourceAttributes{ + Namespace: "dev", Group: "", Resource: "pods", + Subresource: "status", Version: "v1", Name: "test", Verb: "delete", + }, Extra: map[string]authzv1.ExtraValue{"oid": {"00000000-0000-0000-0000-000000000000"}}, + }, + ) + } + + wg := new(sync.WaitGroup) + for _, request := range requests { + wg.Add(1) + go func(request *authzv1.SubjectAccessReviewSpec) { + defer wg.Done() + response, err := u.CheckAccess(request) + assert.NoError(t, err) + assert.NotNil(t, response) + assert.True(t, response.Allowed) + assert.False(t, response.Denied) + }(request) + } + + wg.Wait() + }) } func getAuthServerAndAccessInfo(returnCode int, body, clientID, clientSecret string) (*httptest.Server, *AccessInfo) {