diff --git a/mocks/services/oauth/mock_oauth.go b/mocks/services/oauth/mock_oauth.go index af94d96549..145d1bf5a8 100644 --- a/mocks/services/oauth/mock_oauth.go +++ b/mocks/services/oauth/mock_oauth.go @@ -8,7 +8,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - backendconfig "github.com/rudderlabs/rudder-server/backend-config" oauth "github.com/rudderlabs/rudder-server/services/oauth" ) @@ -35,19 +34,19 @@ func (m *MockAuthorizer) EXPECT() *MockAuthorizerMockRecorder { return m.recorder } -// DisableDestination mocks base method. -func (m *MockAuthorizer) DisableDestination(arg0 *backendconfig.DestinationT, arg1, arg2 string) (int, string) { +// AuthStatusToggle mocks base method. +func (m *MockAuthorizer) AuthStatusToggle(arg0 *oauth.AuthStatusToggleParams) (int, string) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DisableDestination", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "AuthStatusToggle", arg0) ret0, _ := ret[0].(int) ret1, _ := ret[1].(string) return ret0, ret1 } -// DisableDestination indicates an expected call of DisableDestination. -func (mr *MockAuthorizerMockRecorder) DisableDestination(arg0, arg1, arg2 interface{}) *gomock.Call { +// AuthStatusToggle indicates an expected call of AuthStatusToggle. +func (mr *MockAuthorizerMockRecorder) AuthStatusToggle(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisableDestination", reflect.TypeOf((*MockAuthorizer)(nil).DisableDestination), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthStatusToggle", reflect.TypeOf((*MockAuthorizer)(nil).AuthStatusToggle), arg0) } // FetchToken mocks base method. diff --git a/regulation-worker/internal/delete/api/api.go b/regulation-worker/internal/delete/api/api.go index bf00480bf3..bc9c14c06f 100644 --- a/regulation-worker/internal/delete/api/api.go +++ b/regulation-worker/internal/delete/api/api.go @@ -15,8 +15,11 @@ import ( "os" "strings" + "github.com/samber/lo" + "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/stats" + backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/regulation-worker/internal/model" "github.com/rudderlabs/rudder-server/services/oauth" "github.com/rudderlabs/rudder-server/utils/httputil" @@ -109,17 +112,23 @@ func (api *APIManager) deleteWithRetry(ctx context.Context, job model.Job, desti jobStatus := getJobStatus(resp.StatusCode, jobResp) pkgLogger.Debugf("[%v] Job: %v, JobStatus: %v", destination.Name, job.ID, jobStatus) - if isOAuthEnabled && isTokenExpired(jobResp) && currentOauthRetryAttempt < api.MaxOAuthRefreshRetryAttempts { - err = api.refreshOAuthToken(destination.Name, job.WorkspaceID, oAuthDetail) - if err != nil { - pkgLogger.Error(err) - return model.JobStatus{Status: model.JobStatusFailed, Error: err} + oauthErrJob, oauthErrJobFound := getOAuthErrorJob(jobResp) + + if oauthErrJobFound && isOAuthEnabled { + if oauthErrJob.AuthErrorCategory == oauth.AUTH_STATUS_INACTIVE { + return api.inactivateAuthStatus(&destination, job, oAuthDetail) + } + if oauthErrJob.AuthErrorCategory == oauth.REFRESH_TOKEN && currentOauthRetryAttempt < api.MaxOAuthRefreshRetryAttempts { + err = api.refreshOAuthToken(&destination, job, oAuthDetail) + if err != nil { + pkgLogger.Error(err) + return model.JobStatus{Status: model.JobStatusFailed, Error: err} + } + // retry the request + pkgLogger.Infof("[%v] Retrying deleteRequest job(id: %v) for the whole batch, RetryAttempt: %v", destination.Name, job.ID, currentOauthRetryAttempt+1) + return api.deleteWithRetry(ctx, job, destination, currentOauthRetryAttempt+1) } - // retry the request - pkgLogger.Infof("[%v] Retrying deleteRequest job(id: %v) for the whole batch, RetryAttempt: %v", destination.Name, job.ID, currentOauthRetryAttempt+1) - return api.deleteWithRetry(ctx, job, destination, currentOauthRetryAttempt+1) } - return jobStatus } @@ -160,13 +169,10 @@ func mapJobToPayload(job model.Job, destName string, destConfig map[string]inter } } -func isTokenExpired(jobResponses []JobRespSchema) bool { - for _, jobResponse := range jobResponses { - if jobResponse.AuthErrorCategory == oauth.REFRESH_TOKEN { - return true - } - } - return false +func getOAuthErrorJob(jobResponses []JobRespSchema) (JobRespSchema, bool) { + return lo.Find(jobResponses, func(item JobRespSchema) bool { + return lo.Contains([]string{oauth.AUTH_STATUS_INACTIVE, oauth.REFRESH_TOKEN}, item.AuthErrorCategory) + }) } func setOAuthHeader(secretToken *oauth.AuthResponse, req *http.Request) error { @@ -192,7 +198,7 @@ func (api *APIManager) getOAuthDetail(destDetail *model.Destination, workspaceId EventNamePrefix: "fetch_token", }) if tokenStatusCode != http.StatusOK { - return oauthDetail{}, fmt.Errorf("[%s][FetchToken] Error in Token Fetch statusCode: %d\t error: %s", destDetail.Name, tokenStatusCode, secretToken.Err) + return oauthDetail{}, fmt.Errorf("[%s][FetchToken] Error in Token Fetch statusCode: %d\t error: %s", destDetail.Name, tokenStatusCode, secretToken.ErrorMessage) } return oauthDetail{ id: id, @@ -200,21 +206,47 @@ func (api *APIManager) getOAuthDetail(destDetail *model.Destination, workspaceId }, nil } -func (api *APIManager) refreshOAuthToken(destName, workspaceId string, oAuthDetail oauthDetail) error { +func (api *APIManager) inactivateAuthStatus(destination *model.Destination, job model.Job, oAuthDetail oauthDetail) (jobStatus model.JobStatus) { + dest := &backendconfig.DestinationT{ + ID: destination.DestinationID, + Config: destination.Config, + DestinationDefinition: backendconfig.DestinationDefinitionT{ + Name: destination.Name, + Config: destination.DestDefConfig, + }, + } + _, resp := api.OAuth.AuthStatusToggle(&oauth.AuthStatusToggleParams{ + Destination: dest, + WorkspaceId: job.WorkspaceID, + RudderAccountId: oAuthDetail.id, + AuthStatus: oauth.AuthStatusInactive, + }) + jobStatus.Status = model.JobStatusAborted + jobStatus.Error = fmt.Errorf(resp) + return jobStatus +} + +func (api *APIManager) refreshOAuthToken(destination *model.Destination, job model.Job, oAuthDetail oauthDetail) error { refTokenParams := &oauth.RefreshTokenParams{ Secret: oAuthDetail.secretToken.Account.Secret, - WorkspaceId: workspaceId, + WorkspaceId: job.WorkspaceID, AccountId: oAuthDetail.id, - DestDefName: destName, + DestDefName: destination.Name, EventNamePrefix: "refresh_token", } statusCode, refreshResponse := api.OAuth.RefreshToken(refTokenParams) if statusCode != http.StatusOK { + if refreshResponse.Err == oauth.REF_TOKEN_INVALID_GRANT { + // authStatus should be made inactive + api.inactivateAuthStatus(destination, job, oAuthDetail) + return fmt.Errorf(refreshResponse.ErrorMessage) + } + var refreshRespErr string if refreshResponse != nil { - refreshRespErr = refreshResponse.Err + refreshRespErr = refreshResponse.ErrorMessage } - return fmt.Errorf("[%v] Failed to refresh token for destination in workspace(%v) & account(%v) with %v", destName, workspaceId, oAuthDetail.id, refreshRespErr) + return fmt.Errorf("[%v] Failed to refresh token for destination in workspace(%v) & account(%v) with %v", destination.Name, job.WorkspaceID, oAuthDetail.id, refreshRespErr) } return nil } diff --git a/regulation-worker/internal/delete/api/api_test.go b/regulation-worker/internal/delete/api/api_test.go index 6046c05f28..b4e5686cb1 100644 --- a/regulation-worker/internal/delete/api/api_test.go +++ b/regulation-worker/internal/delete/api/api_test.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -314,7 +315,6 @@ func TestOAuth(t *testing.T) { jobResponse: `[{"status":"successful"}]`, }, }, - cpResponses: []cpResponseParams{ { code: 200, @@ -325,8 +325,7 @@ func TestOAuth(t *testing.T) { response: `{"secret": {"access_token": "refreshed_access_token","refresh_token":"valid_refresh_token"}}`, }, }, - - expectedDeleteStatus: model.JobStatus{Status: model.JobStatusFailed, Error: fmt.Errorf("error: code: 500, body: [{failed [GA] invalid credentials REFRESH_TOKEN}]")}, + expectedDeleteStatus: model.JobStatus{Status: model.JobStatusComplete}, expectedPayload: `[{"jobId":"2","destType":"ga","config":{"rudderDeleteAccountId":"xyz"},"userAttributes":[{"email":"dorowane8n285680461479465450293438@gmail.com","phone":"6463633841","randomKey":"randomValue","userId":"Jermaine1473336609491897794707338"},{"email":"dshirilad8536019424659691213279982@gmail.com","userId":"Mercie8221821544021583104106123"}]}]`, }, { @@ -365,15 +364,13 @@ func TestOAuth(t *testing.T) { }, }, }, - cpResponses: []cpResponseParams{ { code: 500, response: `Internal Server Error`, }, }, - deleteResponses: []deleteResponseParams{{}}, - + deleteResponses: []deleteResponseParams{{}}, expectedDeleteStatus: model.JobStatus{Status: model.JobStatusFailed, Error: fmt.Errorf("[GA][FetchToken] Error in Token Fetch statusCode: 500\t error: Unmarshal of response unsuccessful: Internal Server Error")}, expectedPayload: "", // since request has not gone to transformer at all! }, @@ -420,12 +417,10 @@ func TestOAuth(t *testing.T) { timeout: 2 * time.Second, }, }, - deleteResponses: []deleteResponseParams{{}}, - + deleteResponses: []deleteResponseParams{{}}, oauthHttpClientTimeout: 1 * time.Second, - - expectedDeleteStatus: model.JobStatus{Status: model.JobStatusFailed, Error: fmt.Errorf("Client.Timeout exceeded while awaiting headers")}, - expectedPayload: "", // since request has not gone to transformer at all! + expectedDeleteStatus: model.JobStatus{Status: model.JobStatusFailed, Error: fmt.Errorf("Client.Timeout exceeded while awaiting headers")}, + expectedPayload: "", // since request has not gone to transformer at all! }, { // In this case the request will not even reach transformer, as OAuth is required but we don't have "rudderDeleteAccountId" @@ -576,10 +571,169 @@ func TestOAuth(t *testing.T) { jobResponse: `[{"status":"failed","authErrorCategory":"REFRESH_TOKEN","error":"[GA] invalid credentials"}]`, }, }, - - expectedDeleteStatus: model.JobStatus{Status: model.JobStatusFailed, Error: fmt.Errorf("[GA] invalid credentials")}, + expectedDeleteStatus: model.JobStatus{Status: model.JobStatusFailed, Error: fmt.Errorf("[GA] Failed to refresh token for destination in workspace(1001) & account(xyz) with Unmarshal of response unsuccessful: Post \"__cfgBE_server__/destination/workspaces/1001/accounts/xyz/token\": context deadline exceeded (Client.Timeout exceeded while awaiting headers)")}, expectedPayload: `[{"jobId":"9","destType":"ga","config":{"rudderDeleteAccountId":"xyz"},"userAttributes":[{"email":"dorowane9@gmail.com","phone":"6463633841","randomKey":"randomValue","userId":"Jermaine9"},{"email":"dshirilad9@gmail.com","userId":"Mercie9"}]}]`, }, + + { + name: "when AUTH_STATUS_INACTIVE error happens & authStatus/toggle success, fail the job with Failed status", + job: model.Job{ + ID: 15, + WorkspaceID: "1001", + DestinationID: "1234", + Status: model.JobStatus{Status: model.JobStatusPending}, + Users: []model.User{ + { + ID: "203984798475", + Attributes: map[string]string{ + "phone": "7463633841", + "email": "dreymore@gmail.com", + }, + }, + }, + }, + dest: model.Destination{ + DestinationID: "1234", + Config: map[string]interface{}{ + "rudderDeleteAccountId": "xyz", + "authStatus": "active", + }, + Name: "GA", + DestDefConfig: map[string]interface{}{ + "auth": map[string]interface{}{ + "type": "OAuth", + }, + }, + }, + deleteResponses: []deleteResponseParams{ + { + status: 400, + jobResponse: fmt.Sprintf(`[{"status":"failed","authErrorCategory": "%v", "error": "User does not have sufficient permissions"}]`, oauth.AUTH_STATUS_INACTIVE), + }, + }, + cpResponses: []cpResponseParams{ + // fetch token http request + { + code: 200, + response: `{"secret": {"access_token": "invalid_grant_access_token","refresh_token":"invalid_grant_refresh_token"}}`, + }, + // authStatus inactive http request + { + code: 200, + }, + }, + expectedDeleteStatus: model.JobStatus{Status: model.JobStatusAborted, Error: fmt.Errorf("Problem with user permission or access/refresh token have been revoked")}, + expectedPayload: `[{"jobId":"15","destType":"ga","config":{"authStatus":"active","rudderDeleteAccountId":"xyz"},"userAttributes":[{"email":"dreymore@gmail.com","phone":"7463633841","userId":"203984798475"}]}]`, + }, + { + name: "when AUTH_STATUS_INACTIVE error happens but authStatus/toggle failed, fail the job with Failed status", + job: model.Job{ + ID: 16, + WorkspaceID: "1001", + DestinationID: "1234", + Status: model.JobStatus{Status: model.JobStatusPending}, + Users: []model.User{ + { + ID: "203984798476", + Attributes: map[string]string{ + "phone": "8463633841", + "email": "greymore@gmail.com", + }, + }, + }, + }, + dest: model.Destination{ + DestinationID: "1234", + Config: map[string]interface{}{ + "rudderDeleteAccountId": "xyz", + "authStatus": "active", + }, + Name: "GA", + DestDefConfig: map[string]interface{}{ + "auth": map[string]interface{}{ + "type": "OAuth", + }, + }, + }, + deleteResponses: []deleteResponseParams{ + { + status: 400, + jobResponse: fmt.Sprintf(`[{"status":"failed","authErrorCategory": "%v", "error": "User does not have sufficient permissions"}]`, oauth.AUTH_STATUS_INACTIVE), + }, + }, + cpResponses: []cpResponseParams{ + // fetch token http request + { + code: 200, + response: `{"secret": {"access_token": "invalid_grant_access_token","refresh_token":"invalid_grant_refresh_token"}}`, + }, + // authStatus inactive http request + { + code: 400, + response: `{"message": "AuthStatus toggle skipped as already request in-progress: (1234, 1001)"}`, + }, + }, + expectedDeleteStatus: model.JobStatus{Status: model.JobStatusAborted, Error: fmt.Errorf("Problem with user permission or access/refresh token have been revoked")}, + expectedPayload: `[{"jobId":"16","destType":"ga","config":{"authStatus":"active","rudderDeleteAccountId":"xyz"},"userAttributes":[{"email":"greymore@gmail.com","phone":"8463633841","userId":"203984798476"}]}]`, + }, + + { + name: "when REFRESH_TOKEN error happens but refreshing token fails due to token revocation, fail the job with Failed status", + job: model.Job{ + ID: 17, + WorkspaceID: "1001", + DestinationID: "1234", + Status: model.JobStatus{Status: model.JobStatusPending}, + Users: []model.User{ + { + ID: "203984798477", + Attributes: map[string]string{ + "phone": "8463633841", + "email": "greymore@gmail.com", + }, + }, + }, + }, + dest: model.Destination{ + DestinationID: "1234", + Config: map[string]interface{}{ + "rudderDeleteAccountId": "xyz", + "authStatus": "active", + }, + Name: "GA", + DestDefConfig: map[string]interface{}{ + "auth": map[string]interface{}{ + "type": "OAuth", + }, + }, + }, + deleteResponses: []deleteResponseParams{ + { + status: 500, + jobResponse: `[{"status":"failed","authErrorCategory":"REFRESH_TOKEN", "error": "[GA] invalid credentials"}]`, + }, + }, + + cpResponses: []cpResponseParams{ + // fetch token http request + { + code: 200, + response: `{"secret": {"access_token": "invalid_grant_access_token","refresh_token":"invalid_grant_refresh_token"}}`, + }, + // refresh token http request + { + code: 403, + response: `{"status":403,"body":{"message":"[google_analytics] \"invalid_grant\" error, refresh token has been revoked","status":403,"code":"ref_token_invalid_grant"},"code":"ref_token_invalid_grant","access_token":"invalid_grant_access_token","refresh_token":"invalid_grant_refresh_token","developer_token":"dev_token"}`, + }, + // authStatus inactive http request + { + code: 200, + }, + }, + + expectedDeleteStatus: model.JobStatus{Status: model.JobStatusFailed, Error: fmt.Errorf("[google_analytics] \"invalid_grant\" error, refresh token has been revoked")}, + expectedPayload: `[{"jobId":"17","destType":"ga","config":{"authStatus":"active","rudderDeleteAccountId":"xyz"},"userAttributes":[{"email":"greymore@gmail.com","phone":"8463633841","userId":"203984798477"}]}]`, + }, } for _, tt := range tests { @@ -606,15 +760,17 @@ func TestOAuth(t *testing.T) { oauth.Init() OAuth := oauth.NewOAuthErrorHandler(mockBackendConfig, oauth.WithRudderFlow(oauth.RudderFlow_Delete), oauth.WithOAuthClientTimeout(tt.oauthHttpClientTimeout)) api := api.APIManager{ - Client: &http.Client{}, - DestTransformURL: svr.URL, - OAuth: OAuth, + Client: &http.Client{}, + DestTransformURL: svr.URL, + OAuth: OAuth, + MaxOAuthRefreshRetryAttempts: 1, } status := api.Delete(ctx, tt.job, tt.dest) require.Equal(t, tt.expectedDeleteStatus.Status, status.Status) if tt.expectedDeleteStatus.Status != model.JobStatusComplete { - require.Contains(t, status.Error.Error(), tt.expectedDeleteStatus.Error.Error()) + jobError := strings.Replace(tt.expectedDeleteStatus.Error.Error(), "__cfgBE_server__", cfgBeSrv.URL, 1) + require.Contains(t, status.Error.Error(), jobError) } // require.Equal(t, tt.expectedDeleteStatus, status) // TODO: Compare input payload for all "/deleteUsers" requests @@ -650,7 +806,37 @@ func (cpRespProducer *cpResponseProducer) mockCpRequests() *chi.Mux { param := chi.URLParam(req, reqParam) if param == "" { // This case wouldn't occur I guess - http.Error(w, fmt.Sprintf("Wrong url being sent: %v", reqParam), http.StatusInternalServerError) + http.Error(w, fmt.Sprintf("Wrong url being sent: %v", reqParam), http.StatusBadRequest) + return + } + } + + cpResp := cpRespProducer.GetNext() + // sleep is being used to mimic the waiting in actual transformer response + if cpResp.timeout > 0 { + time.Sleep(cpResp.timeout) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(cpResp.code) + // Lint error fix + _, err := w.Write([]byte(cpResp.response)) + if err != nil { + http.Error(w, fmt.Sprintf("Provided response is faulty, please check it. Err: %v", err.Error()), http.StatusInternalServerError) + return + } + }) + + srvMux.HandleFunc("/workspaces/{workspaceId}/destinations/{destinationId}/authStatus/toggle", func(w http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPut { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + // iterating over request parameters + for _, reqParam := range []string{"workspaceId", "destinationId"} { + param := chi.URLParam(req, reqParam) + if param == "" { + // This case wouldn't occur I guess + http.Error(w, fmt.Sprintf("Wrong url being sent: %v", reqParam), http.StatusNotFound) return } } @@ -665,7 +851,6 @@ func (cpRespProducer *cpResponseProducer) mockCpRequests() *chi.Mux { // Lint error fix _, err := w.Write([]byte(cpResp.response)) if err != nil { - fmt.Printf("I'm here!!!! Some shitty response!!") http.Error(w, fmt.Sprintf("Provided response is faulty, please check it. Err: %v", err.Error()), http.StatusInternalServerError) return } @@ -708,13 +893,13 @@ func (delRespProducer *deleteResponseProducer) mockDeleteRequests() *chi.Mux { buf := new(bytes.Buffer) _, bufErr := buf.ReadFrom(req.Body) if bufErr != nil { - http.Error(w, bufErr.Error(), http.StatusInternalServerError) + http.Error(w, bufErr.Error(), http.StatusBadRequest) return } + delResp := delRespProducer.GetNext() + // useful in validating the payload(sent in request body to transformer) delRespProducer.GetCurrent().actualPayload = buf.String() - - delResp := delRespProducer.GetNext() // sleep is being used to mimic the waiting in actual transformer response if delResp.timeout > 0 { time.Sleep(delResp.timeout) @@ -724,7 +909,6 @@ func (delRespProducer *deleteResponseProducer) mockDeleteRequests() *chi.Mux { // Lint error fix _, err := w.Write([]byte(delResp.jobResponse)) if err != nil { - fmt.Printf("I'm here!!!! Some shitty response!!") http.Error(w, fmt.Sprintf("Provided response is faulty, please check it. Err: %v", err.Error()), http.StatusInternalServerError) return } diff --git a/router/handle.go b/router/handle.go index 3380d4b29d..2970456bcb 100644 --- a/router/handle.go +++ b/router/handle.go @@ -614,8 +614,10 @@ func (rt *Handle) handleOAuthDestResponse(params *HandleDestOAuthRespParams) (in return trRespStatusCode, trRespBody } switch destErrOutput.AuthErrorCategory { - case oauth.DISABLE_DEST: - return rt.execDisableDestination(&destinationJob.Destination, workspaceID, trRespBody, rudderAccountID) + case oauth.AUTH_STATUS_INACTIVE: + authStatusStCd := rt.updateAuthStatusToInactive(&destinationJob.Destination, workspaceID, rudderAccountID) + authStatusMsg := gjson.Get(trRespBody, "message").Raw + return authStatusStCd, authStatusMsg case oauth.REFRESH_TOKEN: var refSecret *oauth.AuthResponse refTokenParams := &oauth.RefreshTokenParams{ @@ -628,21 +630,21 @@ func (rt *Handle) handleOAuthDestResponse(params *HandleDestOAuthRespParams) (in } errCatStatusCode, refSecret = rt.oauth.RefreshToken(refTokenParams) refSec := *refSecret - if routerutils.IsNotEmptyString(refSec.Err) && refSec.Err == oauth.INVALID_REFRESH_TOKEN_GRANT { + if routerutils.IsNotEmptyString(refSec.Err) && refSec.Err == oauth.REF_TOKEN_INVALID_GRANT { // In-case the refresh token has been revoked, this error comes in // Even trying to refresh the token also doesn't work here. Hence, this would be more ideal to Abort Events // As well as to disable destination as well. // Alert the user in this error as well, to check if the refresh token also has been revoked & fix it - disableStCd, _ := rt.execDisableDestination(&destinationJob.Destination, workspaceID, trRespBody, rudderAccountID) - stats.Default.NewTaggedStat(oauth.INVALID_REFRESH_TOKEN_GRANT, stats.CountType, stats.Tags{ + authStatusInactiveStCode := rt.updateAuthStatusToInactive(&destinationJob.Destination, workspaceID, rudderAccountID) + stats.Default.NewTaggedStat(oauth.REF_TOKEN_INVALID_GRANT, stats.CountType, stats.Tags{ "destinationId": destinationJob.Destination.ID, "workspaceId": refTokenParams.WorkspaceId, "accountId": refTokenParams.AccountId, "destType": refTokenParams.DestDefName, "flowType": string(oauth.RudderFlow_Delivery), }).Increment() - rt.logger.Errorf(`[OAuth request] Aborting the event as %v`, oauth.INVALID_REFRESH_TOKEN_GRANT) - return disableStCd, refSec.Err + rt.logger.Errorf(`[OAuth request] Aborting the event as %v`, oauth.REF_TOKEN_INVALID_GRANT) + return authStatusInactiveStCode, refSecret.ErrorMessage } // Error while refreshing the token or Has an error while refreshing or sending empty access token if errCatStatusCode != http.StatusOK || routerutils.IsNotEmptyString(refSec.Err) { @@ -656,24 +658,25 @@ func (rt *Handle) handleOAuthDestResponse(params *HandleDestOAuthRespParams) (in return trRespStatusCode, trRespBody } -func (rt *Handle) execDisableDestination(destination *backendconfig.DestinationT, workspaceID, destResBody, rudderAccountId string) (int, string) { - disableDestStatTags := stats.Tags{ +func (rt *Handle) updateAuthStatusToInactive(destination *backendconfig.DestinationT, workspaceID, rudderAccountId string) int { + inactiveAuthStatusStatTags := stats.Tags{ "id": destination.ID, "destType": destination.DestinationDefinition.Name, "workspaceId": workspaceID, "success": "true", "flowType": string(oauth.RudderFlow_Delivery), } - errCatStatusCode, errCatResponse := rt.oauth.DisableDestination(destination, workspaceID, rudderAccountId) + errCatStatusCode, _ := rt.oauth.AuthStatusToggle(&oauth.AuthStatusToggleParams{ + Destination: destination, + WorkspaceId: workspaceID, + RudderAccountId: rudderAccountId, + AuthStatus: oauth.AuthStatusInactive, + }) if errCatStatusCode != http.StatusOK { - // Error while disabling a destination - // High-Priority notification to rudderstack needs to be sent - disableDestStatTags["success"] = "false" - stats.Default.NewTaggedStat("disable_destination_category_count", stats.CountType, disableDestStatTags).Increment() - return http.StatusBadRequest, errCatResponse - } - // High-Priority notification to workspace(& rudderstack) needs to be sent - stats.Default.NewTaggedStat("disable_destination_category_count", stats.CountType, disableDestStatTags).Increment() + // Error while inactivating authStatus + inactiveAuthStatusStatTags["success"] = "false" + } + stats.Default.NewTaggedStat("auth_status_inactive_category_count", stats.CountType, inactiveAuthStatusStatTags).Increment() // Abort the jobs as the destination is disabled - return http.StatusBadRequest, destResBody + return http.StatusBadRequest } diff --git a/services/oauth/oauth.go b/services/oauth/oauth.go index 58e000e99a..ce53270fa1 100644 --- a/services/oauth/oauth.go +++ b/services/oauth/oauth.go @@ -4,6 +4,7 @@ package oauth import ( "bytes" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -38,6 +39,8 @@ const ( DeleteAccountIdKey = "rudderDeleteAccountId" DeliveryAccountIdKey = "rudderAccountId" + + AuthStatusInactive = "inactive" ) type AccountSecret struct { @@ -45,8 +48,9 @@ type AccountSecret struct { Secret json.RawMessage `json:"secret"` } type AuthResponse struct { - Account AccountSecret - Err string + Account AccountSecret + Err string + ErrorMessage string } type OAuthStats struct { @@ -67,6 +71,17 @@ type DisableDestinationResponse struct { DestinationId string `json:"id"` } +type AuthStatusToggleResponse struct { + Message string `json:"message,omitempty"` +} + +type AuthStatusToggleParams struct { + Destination *backendconfig.DestinationT + WorkspaceId string + RudderAccountId string + AuthStatus string +} + type RefreshTokenParams struct { AccountId string WorkspaceId string @@ -78,21 +93,21 @@ type RefreshTokenParams struct { // OAuthErrResHandler is the handle for this class type OAuthErrResHandler struct { - tr *http.Transport - client *http.Client - logger logger.Logger - destLockMap map[string]*sync.RWMutex // This mutex map is used for disable destination locking - accountLockMap map[string]*sync.RWMutex // This mutex map is used for refresh token locking - lockMapWMutex *sync.RWMutex // This mutex is used to prevent concurrent writes in lockMap(s) mentioned in the struct - destAuthInfoMap map[string]*AuthResponse - refreshActiveMap map[string]bool // Used to check if a refresh request for an account is already InProgress - disableDestActiveMap map[string]bool // Used to check if a disable destination request for a destination is already InProgress - tokenProvider tokenProvider - rudderFlowType RudderFlow + tr *http.Transport + client *http.Client + logger logger.Logger + destLockMap map[string]*sync.RWMutex // This mutex map is used for disable destination locking + accountLockMap map[string]*sync.RWMutex // This mutex map is used for refresh token locking + lockMapWMutex *sync.RWMutex // This mutex is used to prevent concurrent writes in lockMap(s) mentioned in the struct + destAuthInfoMap map[string]*AuthResponse + refreshActiveMap map[string]bool // Used to check if a refresh request for an account is already InProgress + authStatusUpdateActiveMap map[string]bool // Used to check if a authStatusInactive request for a destination is already InProgress + tokenProvider tokenProvider + rudderFlowType RudderFlow } type Authorizer interface { - DisableDestination(destination *backendconfig.DestinationT, workspaceId, rudderAccountId string) (statusCode int, resBody string) + AuthStatusToggle(*AuthStatusToggleParams) (int, string) RefreshToken(refTokenParams *RefreshTokenParams) (int, *AuthResponse) FetchToken(fetchTokenParams *RefreshTokenParams) (int, *AuthResponse) } @@ -113,11 +128,16 @@ var ( ) const ( - DISABLE_DEST = "DISABLE_DESTINATION" - REFRESH_TOKEN = "REFRESH_TOKEN" - INVALID_REFRESH_TOKEN_GRANT = "refresh_token_invalid_grant" + REFRESH_TOKEN = "REFRESH_TOKEN" + // Identifier to be sent from destination(during transformation/delivery) + AUTH_STATUS_INACTIVE = "AUTH_STATUS_INACTIVE" + + // Identifier for invalid_grant or access_denied errors(during refreshing the token) + REF_TOKEN_INVALID_GRANT = "ref_token_invalid_grant" ) +var ErrPermissionOrTokenRevoked = errors.New("Problem with user permission or access/refresh token have been revoked") + // This struct only exists for marshalling and sending payload to control-plane type RefreshTokenBodyParams struct { HasExpired bool `json:"hasExpired"` @@ -155,13 +175,13 @@ func NewOAuthErrorHandler(provider tokenProvider, options ...func(*OAuthErrResHa tr: &http.Transport{}, client: &http.Client{Timeout: config.GetDuration("HttpClient.oauth.timeout", 30, time.Second)}, // This timeout is kind of modifiable & it seemed like 10 mins for this is too much! - destLockMap: make(map[string]*sync.RWMutex), - accountLockMap: make(map[string]*sync.RWMutex), - lockMapWMutex: &sync.RWMutex{}, - destAuthInfoMap: make(map[string]*AuthResponse), - refreshActiveMap: make(map[string]bool), - disableDestActiveMap: make(map[string]bool), - rudderFlowType: RudderFlow_Delivery, + destLockMap: make(map[string]*sync.RWMutex), + accountLockMap: make(map[string]*sync.RWMutex), + lockMapWMutex: &sync.RWMutex{}, + destAuthInfoMap: make(map[string]*AuthResponse), + refreshActiveMap: make(map[string]bool), + authStatusUpdateActiveMap: make(map[string]bool), + rudderFlowType: RudderFlow_Delivery, } for _, opt := range options { opt(oAuthErrResHandler) @@ -339,18 +359,20 @@ func (authErrHandler *OAuthErrResHandler) fetchAccountInfoFromCp(refTokenParams return http.StatusInternalServerError } - if refErrMsg := getRefreshTokenErrResp(response, &accountSecret); router_utils.IsNotEmptyString(refErrMsg) { + if errType, refErrMsg := authErrHandler.getRefreshTokenErrResp(response, &accountSecret); router_utils.IsNotEmptyString(refErrMsg) { if _, ok := authErrHandler.destAuthInfoMap[refTokenParams.AccountId]; !ok { authErrHandler.destAuthInfoMap[refTokenParams.AccountId] = &AuthResponse{ - Err: refErrMsg, + Err: errType, + ErrorMessage: refErrMsg, } } else { - authErrHandler.destAuthInfoMap[refTokenParams.AccountId].Err = refErrMsg + authErrHandler.destAuthInfoMap[refTokenParams.AccountId].Err = errType + authErrHandler.destAuthInfoMap[refTokenParams.AccountId].ErrorMessage = refErrMsg } authStats.statName = fmt.Sprintf("%s_failure", refTokenParams.EventNamePrefix) authStats.errorMessage = refErrMsg authStats.SendCountStat() - if refErrMsg == INVALID_REFRESH_TOKEN_GRANT { + if refErrMsg == REF_TOKEN_INVALID_GRANT { // Should abort the event as refresh is not going to work // until we have new refresh token for the account return http.StatusBadRequest @@ -368,15 +390,24 @@ func (authErrHandler *OAuthErrResHandler) fetchAccountInfoFromCp(refTokenParams return http.StatusOK } -func getRefreshTokenErrResp(response string, accountSecret *AccountSecret) (message string) { +func (authErrHandler *OAuthErrResHandler) getRefreshTokenErrResp(response string, accountSecret *AccountSecret) (errorType, message string) { if err := json.Unmarshal([]byte(response), &accountSecret); err != nil { // Some problem with AccountSecret unmarshalling message = fmt.Sprintf("Unmarshal of response unsuccessful: %v", response) - } else if gjson.Get(response, "body.code").String() == INVALID_REFRESH_TOKEN_GRANT { + errorType = "unmarshallableResponse" + } else if gjson.Get(response, "body.code").String() == REF_TOKEN_INVALID_GRANT { // User (or) AccessToken (or) RefreshToken has been revoked - message = INVALID_REFRESH_TOKEN_GRANT + bodyMsg := gjson.Get(response, "body.message").String() + if bodyMsg == "" { + // Default message + authErrHandler.logger.Debugf("Failed with error response: %v\n", response) + message = ErrPermissionOrTokenRevoked.Error() + } else { + message = bodyMsg + } + errorType = REF_TOKEN_INVALID_GRANT } - return message + return errorType, message } func (authStats *OAuthStats) SendTimerStats(startTime time.Time) { @@ -406,92 +437,99 @@ func (refStats *OAuthStats) SendCountStat() { }).Increment() } -func (authErrHandler *OAuthErrResHandler) DisableDestination(destination *backendconfig.DestinationT, workspaceId, rudderAccountId string) (statusCode int, respBody string) { +func (authErrHandler *OAuthErrResHandler) AuthStatusToggle(params *AuthStatusToggleParams) (statusCode int, respBody string) { authErrHandlerTimeStart := time.Now() - destinationId := destination.ID - disableDestMutex := authErrHandler.getKeyMutex(authErrHandler.destLockMap, destinationId) + destinationId := params.Destination.ID + authStatusToggleMutex := authErrHandler.getKeyMutex(authErrHandler.destLockMap, destinationId) + + getStatName := func(statName string) string { + return fmt.Sprintf("auth_status_%v_%v", statName, params.AuthStatus) + } - disableDestStats := &OAuthStats{ + authStatusToggleStats := &OAuthStats{ id: destinationId, - workspaceId: workspaceId, + workspaceId: params.WorkspaceId, rudderCategory: "destination", statName: "", isCallToCpApi: false, - authErrCategory: DISABLE_DEST, + authErrCategory: AUTH_STATUS_INACTIVE, errorMessage: "", - destDefName: destination.DestinationDefinition.Name, + destDefName: params.Destination.DestinationDefinition.Name, flowType: authErrHandler.rudderFlowType, } defer func() { - disableDestStats.statName = "disable_destination_total_req_latency" - disableDestStats.isCallToCpApi = false - disableDestStats.SendTimerStats(authErrHandlerTimeStart) + authStatusToggleStats.statName = getStatName("total_req_latency") + authStatusToggleStats.isCallToCpApi = false + authStatusToggleStats.SendTimerStats(authErrHandlerTimeStart) }() - disableDestMutex.Lock() - isDisableDestActive, isDisableDestReqPresent := authErrHandler.disableDestActiveMap[destinationId] - disableActiveReq := strconv.FormatBool(isDisableDestReqPresent && isDisableDestActive) - if isDisableDestReqPresent && isDisableDestActive { - disableDestMutex.Unlock() - authErrHandler.logger.Debugf("[%s request] :: Disable Destination Active : %s\n", loggerNm, disableActiveReq) - return http.StatusOK, fmt.Sprintf(`{response: {isDisabled: %v, activeRequest: %v}`, false, disableActiveReq) + authStatusToggleMutex.Lock() + isAuthStatusUpdateActive, isAuthStatusUpdateReqPresent := authErrHandler.authStatusUpdateActiveMap[destinationId] + authStatusUpdateActiveReq := strconv.FormatBool(isAuthStatusUpdateReqPresent && isAuthStatusUpdateActive) + if isAuthStatusUpdateReqPresent && isAuthStatusUpdateActive { + authStatusToggleMutex.Unlock() + authErrHandler.logger.Debugf("[%s request] :: AuthStatusInactive request Active : %s\n", loggerNm, authStatusUpdateActiveReq) + return http.StatusConflict, ErrPermissionOrTokenRevoked.Error() } - authErrHandler.disableDestActiveMap[destinationId] = true - disableDestMutex.Unlock() + authErrHandler.authStatusUpdateActiveMap[destinationId] = true + authStatusToggleMutex.Unlock() defer func() { - disableDestMutex.Lock() - authErrHandler.disableDestActiveMap[destinationId] = false - authErrHandler.logger.Debugf("[%s request] :: Disable request is inactive!", loggerNm) - disableDestMutex.Unlock() + authStatusToggleMutex.Lock() + authErrHandler.authStatusUpdateActiveMap[destinationId] = false + authErrHandler.logger.Debugf("[%s request] :: AuthStatusInactive request is inactive!", loggerNm) + authStatusToggleMutex.Unlock() + // After trying to inactivate authStatus for destination, need to remove existing accessToken(from in-memory cache) + // This is being done to obtain new token after an update such as re-authorisation is done + accountMutex := authErrHandler.getKeyMutex(authErrHandler.accountLockMap, params.RudderAccountId) + accountMutex.Lock() + delete(authErrHandler.destAuthInfoMap, params.RudderAccountId) + accountMutex.Unlock() }() - disableURL := fmt.Sprintf("%s/workspaces/%s/destinations/%s/disable", configBEURL, workspaceId, destinationId) - disableCpReq := &ControlPlaneRequestT{ - Url: disableURL, - Method: http.MethodDelete, - destName: destination.DestinationDefinition.Name, - RequestType: "Disable destination", + authStatusToggleUrl := fmt.Sprintf("%s/workspaces/%s/destinations/%s/authStatus/toggle", configBEURL, params.WorkspaceId, destinationId) + + authStatusInactiveCpReq := &ControlPlaneRequestT{ + Url: authStatusToggleUrl, + Method: http.MethodPut, + Body: `{"authStatus": "inactive"}`, + ContentType: "application/json", + destName: params.Destination.DestinationDefinition.Name, + RequestType: "Auth Status inactive", } - disableDestStats.statName = "disable_destination_request_sent" - disableDestStats.isCallToCpApi = true - disableDestStats.SendCountStat() + authStatusToggleStats.statName = getStatName("request_sent") + authStatusToggleStats.isCallToCpApi = true + authStatusToggleStats.SendCountStat() cpiCallStartTime := time.Now() - statusCode, respBody = authErrHandler.cpApiCall(disableCpReq) - disableDestStats.statName = `disable_destination_request_latency` - defer disableDestStats.SendTimerStats(cpiCallStartTime) - authErrHandler.logger.Debugf(`Response from CP(stCd: %v) for disable dest req: %v`, statusCode, respBody) - - var disableDestRes *DisableDestinationResponse - if disableErr := json.Unmarshal([]byte(respBody), &disableDestRes); disableErr != nil || !router_utils.IsNotEmptyString(disableDestRes.DestinationId) { + statusCode, respBody = authErrHandler.cpApiCall(authStatusInactiveCpReq) + authStatusToggleStats.statName = getStatName("request_latency") + defer authStatusToggleStats.SendTimerStats(cpiCallStartTime) + authErrHandler.logger.Errorf(`Response from CP(stCd: %v) for auth status inactive req: %v`, statusCode, respBody) + + var authStatusToggleRes *AuthStatusToggleResponse + unmarshalErr := json.Unmarshal([]byte(respBody), &authStatusToggleRes) + if router_utils.IsNotEmptyString(respBody) && (unmarshalErr != nil || !router_utils.IsNotEmptyString(authStatusToggleRes.Message) || statusCode != http.StatusOK) { var msg string - if disableErr != nil { - msg = disableErr.Error() + if unmarshalErr != nil { + msg = unmarshalErr.Error() } else { - msg = "Could not disable the destination" + msg = fmt.Sprintf("Could not update authStatus to inactive for destination: %v", authStatusToggleRes.Message) } - disableDestStats.statName = "disable_destination_failure" - disableDestStats.errorMessage = msg - disableDestStats.SendCountStat() - return http.StatusBadRequest, msg + authStatusToggleStats.statName = getStatName("failure") + authStatusToggleStats.errorMessage = msg + authStatusToggleStats.SendCountStat() + return http.StatusBadRequest, ErrPermissionOrTokenRevoked.Error() } - authErrHandler.logger.Debugf("[%s request] :: (Write) Disable Response received : %s\n", loggerNm, respBody) - disableDestStats.statName = "disable_destination_success" - disableDestStats.errorMessage = "" - disableDestStats.SendCountStat() - - // After a successfully disabling the destination, need to remove existing accessToken(from in-memory cache) - // This is being done to obtain new token after re-enabling disabled destination - accountMutex := authErrHandler.getKeyMutex(authErrHandler.accountLockMap, rudderAccountId) - accountMutex.Lock() - defer accountMutex.Unlock() - delete(authErrHandler.destAuthInfoMap, rudderAccountId) + authErrHandler.logger.Errorf("[%s request] :: (Write) auth status inactive Response received : %s\n", loggerNm, respBody) + authStatusToggleStats.statName = getStatName("success") + authStatusToggleStats.errorMessage = "" + authStatusToggleStats.SendCountStat() - return statusCode, fmt.Sprintf(`{response: {isDisabled: %v, activeRequest: %v}`, !disableDestRes.Enabled, false) + return http.StatusBadRequest, ErrPermissionOrTokenRevoked.Error() } func processResponse(resp *http.Response) (statusCode int, respBody string) { diff --git a/services/oauth/oauth_test.go b/services/oauth/oauth_test.go new file mode 100644 index 0000000000..9f8dfec6d3 --- /dev/null +++ b/services/oauth/oauth_test.go @@ -0,0 +1,164 @@ +package oauth_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/golang/mock/gomock" + "github.com/samber/lo" + + "github.com/stretchr/testify/require" + + backendconfig "github.com/rudderlabs/rudder-server/backend-config" + mocksBackendConfig "github.com/rudderlabs/rudder-server/mocks/backend-config" + "github.com/rudderlabs/rudder-server/services/oauth" +) + +type cpResponseParams struct { + timeout time.Duration + code int + response string +} +type cpResponseProducer struct { + responses []cpResponseParams + callCount int +} + +func (s *cpResponseProducer) GetNext() cpResponseParams { + if s.callCount >= len(s.responses) { + panic("ran out of responses") + } + cpResp := s.responses[s.callCount] + s.callCount++ + return cpResp +} + +func (cpRespProducer *cpResponseProducer) mockCpRequests() *chi.Mux { + srvMux := chi.NewMux() + srvMux.HandleFunc("/destination/workspaces/{workspaceId}/accounts/{accountId}/token", func(w http.ResponseWriter, req *http.Request) { + // iterating over request parameters + for _, reqParam := range []string{"workspaceId", "accountId"} { + param := chi.URLParam(req, reqParam) + if param == "" { + // This case wouldn't occur I guess + http.Error(w, fmt.Sprintf("Wrong url being sent: %v", reqParam), http.StatusBadRequest) + return + } + } + + cpResp := cpRespProducer.GetNext() + // sleep is being used to mimic the waiting in actual transformer response + if cpResp.timeout > 0 { + time.Sleep(cpResp.timeout) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(cpResp.code) + // Lint error fix + _, err := w.Write([]byte(cpResp.response)) + if err != nil { + http.Error(w, fmt.Sprintf("Provided response is faulty, please check it. Err: %v", err.Error()), http.StatusInternalServerError) + return + } + }) + + srvMux.HandleFunc("/workspaces/{workspaceId}/destinations/{destinationId}/authStatus/toggle", func(w http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPut { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + // iterating over request parameters + for _, reqParam := range []string{"workspaceId", "destinationId"} { + param := chi.URLParam(req, reqParam) + if param == "" { + // This case wouldn't occur I guess + http.Error(w, fmt.Sprintf("Wrong url being sent: %v", reqParam), http.StatusNotFound) + return + } + } + + cpResp := cpRespProducer.GetNext() + // sleep is being used to mimic the waiting in actual transformer response + if cpResp.timeout > 0 { + time.Sleep(cpResp.timeout) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(cpResp.code) + // Lint error fix + _, err := w.Write([]byte(cpResp.response)) + if err != nil { + http.Error(w, fmt.Sprintf("Provided response is faulty, please check it. Err: %v", err.Error()), http.StatusInternalServerError) + return + } + }) + return srvMux +} + +func TestMultipleRequestsForOAuth(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockBackendConfig := mocksBackendConfig.NewMockBackendConfig(mockCtrl) + mockBackendConfig.EXPECT().AccessToken().AnyTimes() + + t.Run("multiple authStatusInactive requests", func(t *testing.T) { + cpRespProducer := &cpResponseProducer{ + responses: []cpResponseParams{ + { + timeout: 1 * time.Second, + code: 200, + }, + }, + } + cfgBeSrv := httptest.NewServer(cpRespProducer.mockCpRequests()) + + defer cfgBeSrv.Close() + + t.Setenv("CONFIG_BACKEND_URL", cfgBeSrv.URL) + t.Setenv("CONFIG_BACKEND_TOKEN", "config_backend_token") + + backendconfig.Init() + oauth.Init() + OAuth := oauth.NewOAuthErrorHandler(mockBackendConfig, oauth.WithRudderFlow(oauth.RudderFlow_Delete)) + + totalGoRoutines := 5 + var wg sync.WaitGroup + var allJobStatus []int + + dest := &backendconfig.DestinationT{ + ID: "dId", + Config: map[string]interface{}{ + "rudderDeleteAccountId": "accountId", + }, + DestinationDefinition: backendconfig.DestinationDefinitionT{ + Name: "GA", + Config: map[string]interface{}{ + "auth": map[string]interface{}{ + "type": "OAuth", + }, + }, + }, + } + + for i := 0; i < totalGoRoutines; i++ { + wg.Add(1) + go func() { + status, _ := OAuth.AuthStatusToggle(&oauth.AuthStatusToggleParams{ + Destination: dest, + WorkspaceId: "wspId", + RudderAccountId: "accountId", + AuthStatus: oauth.AuthStatusInactive, + }) + allJobStatus = append(allJobStatus, status) + wg.Done() + }() + } + wg.Wait() + countMap := lo.CountValues(allJobStatus) + + require.Equal(t, countMap[http.StatusConflict], totalGoRoutines-1) + require.Equal(t, countMap[http.StatusBadRequest], 1) + }) +}