diff --git a/server/controllers/api_controller.go b/server/controllers/api_controller.go index 29037ec9a0..0694e72b22 100644 --- a/server/controllers/api_controller.go +++ b/server/controllers/api_controller.go @@ -33,6 +33,8 @@ type APIController struct { RepoAllowlistChecker *events.RepoAllowlistChecker Scope tally.Scope VCSClient vcs.Client + WorkingDir events.WorkingDir + WorkingDirLocker events.WorkingDirLocker CommitStatusUpdater events.CommitStatusUpdater } @@ -91,12 +93,18 @@ func (a *APIController) Plan(w http.ResponseWriter, r *http.Request) { return } + err = a.apiSetup(ctx) + if err != nil { + a.apiReportError(w, http.StatusInternalServerError, err) + return + } + result, err := a.apiPlan(request, ctx) if err != nil { a.apiReportError(w, http.StatusInternalServerError, err) return } - defer a.Locker.UnlockByPull(ctx.HeadRepo.FullName, 0) // nolint: errcheck + defer a.Locker.UnlockByPull(ctx.HeadRepo.FullName, ctx.Pull.Num) // nolint: errcheck if result.HasErrors() { code = http.StatusInternalServerError } @@ -119,13 +127,19 @@ func (a *APIController) Apply(w http.ResponseWriter, r *http.Request) { return } + err = a.apiSetup(ctx) + if err != nil { + a.apiReportError(w, http.StatusInternalServerError, err) + return + } + // We must first make the plan for all projects _, err = a.apiPlan(request, ctx) if err != nil { a.apiReportError(w, http.StatusInternalServerError, err) return } - defer a.Locker.UnlockByPull(ctx.HeadRepo.FullName, 0) // nolint: errcheck + defer a.Locker.UnlockByPull(ctx.HeadRepo.FullName, ctx.Pull.Num) // nolint: errcheck // We can now prepare and run the apply step result, err := a.apiApply(request, ctx) @@ -145,6 +159,27 @@ func (a *APIController) Apply(w http.ResponseWriter, r *http.Request) { a.respond(w, logging.Warn, code, "%s", string(response)) } +func (a *APIController) apiSetup(ctx *command.Context) error { + pull := ctx.Pull + baseRepo := ctx.Pull.BaseRepo + headRepo := ctx.HeadRepo + + unlockFn, err := a.WorkingDirLocker.TryLock(baseRepo.FullName, pull.Num, events.DefaultWorkspace, events.DefaultRepoRelDir) + if err != nil { + return err + } + ctx.Log.Debug("got workspace lock") + defer unlockFn() + + // ensure workingDir is present + _, _, err = a.WorkingDir.Clone(ctx.Log, headRepo, pull, events.DefaultWorkspace) + if err != nil { + return err + } + + return nil +} + func (a *APIController) apiPlan(request *APIRequest, ctx *command.Context) (*command.Result, error) { cmds, cc, err := request.getCommands(ctx, a.ProjectCommandBuilder.BuildPlanCommands) if err != nil { diff --git a/server/controllers/api_controller_test.go b/server/controllers/api_controller_test.go index 778c41bee2..02a35dbcbd 100644 --- a/server/controllers/api_controller_test.go +++ b/server/controllers/api_controller_test.go @@ -25,49 +25,194 @@ const atlantisToken = "token" func TestAPIController_Plan(t *testing.T) { ac, projectCommandBuilder, projectCommandRunner := setup(t) - body, _ := json.Marshal(controllers.APIRequest{ - Repository: "Repo", - Ref: "main", - Type: "Gitlab", - Projects: []string{"default"}, - }) - req, _ := http.NewRequest("POST", "", bytes.NewBuffer(body)) - req.Header.Set(atlantisTokenHeader, atlantisToken) - w := httptest.NewRecorder() - ac.Plan(w, req) - ResponseContains(t, w, http.StatusOK, "") - projectCommandBuilder.VerifyWasCalledOnce().BuildPlanCommands(Any[*command.Context](), Any[*events.CommentCommand]()) - projectCommandRunner.VerifyWasCalledOnce().Plan(Any[command.ProjectContext]()) + + cases := []struct { + repository string + ref string + vcsType string + pr int + projects []string + paths []struct { + Directory string + Workspace string + } + }{ + { + repository: "Repo", + ref: "main", + vcsType: "Gitlab", + projects: []string{"default"}, + }, + { + repository: "Repo", + ref: "main", + vcsType: "Gitlab", + pr: 1, + }, + { + repository: "Repo", + ref: "main", + vcsType: "Gitlab", + paths: []struct { + Directory string + Workspace string + }{ + { + Directory: ".", + Workspace: "myworkspace", + }, + { + Directory: "./myworkspace2", + Workspace: "myworkspace2", + }, + }, + }, + { + repository: "Repo", + ref: "main", + vcsType: "Gitlab", + pr: 1, + projects: []string{"test"}, + paths: []struct { + Directory string + Workspace string + }{ + { + Directory: ".", + Workspace: "myworkspace", + }, + }, + }, + } + + expectedCalls := 0 + for _, c := range cases { + body, _ := json.Marshal(controllers.APIRequest{ + Repository: c.repository, + Ref: c.ref, + Type: c.vcsType, + PR: c.pr, + Projects: c.projects, + Paths: c.paths, + }) + + req, _ := http.NewRequest("POST", "", bytes.NewBuffer(body)) + req.Header.Set(atlantisTokenHeader, atlantisToken) + w := httptest.NewRecorder() + ac.Plan(w, req) + ResponseContains(t, w, http.StatusOK, "") + + expectedCalls += len(c.projects) + expectedCalls += len(c.paths) + } + + projectCommandBuilder.VerifyWasCalled(Times(expectedCalls)).BuildPlanCommands(Any[*command.Context](), Any[*events.CommentCommand]()) + projectCommandRunner.VerifyWasCalled(Times(expectedCalls)).Plan(Any[command.ProjectContext]()) } func TestAPIController_Apply(t *testing.T) { ac, projectCommandBuilder, projectCommandRunner := setup(t) - body, _ := json.Marshal(controllers.APIRequest{ - Repository: "Repo", - Ref: "main", - Type: "Gitlab", - Projects: []string{"default"}, - }) - req, _ := http.NewRequest("POST", "", bytes.NewBuffer(body)) - req.Header.Set(atlantisTokenHeader, atlantisToken) - w := httptest.NewRecorder() - ac.Apply(w, req) - ResponseContains(t, w, http.StatusOK, "") - projectCommandBuilder.VerifyWasCalledOnce().BuildApplyCommands(Any[*command.Context](), Any[*events.CommentCommand]()) - projectCommandRunner.VerifyWasCalledOnce().Plan(Any[command.ProjectContext]()) - projectCommandRunner.VerifyWasCalledOnce().Apply(Any[command.ProjectContext]()) + + cases := []struct { + repository string + ref string + vcsType string + pr int + projects []string + paths []struct { + Directory string + Workspace string + } + }{ + { + repository: "Repo", + ref: "main", + vcsType: "Gitlab", + projects: []string{"default"}, + }, + { + repository: "Repo", + ref: "main", + vcsType: "Gitlab", + pr: 1, + }, + { + repository: "Repo", + ref: "main", + vcsType: "Gitlab", + paths: []struct { + Directory string + Workspace string + }{ + { + Directory: ".", + Workspace: "myworkspace", + }, + { + Directory: "./myworkspace2", + Workspace: "myworkspace2", + }, + }, + }, + { + repository: "Repo", + ref: "main", + vcsType: "Gitlab", + pr: 1, + projects: []string{"test"}, + paths: []struct { + Directory string + Workspace string + }{ + { + Directory: ".", + Workspace: "myworkspace", + }, + }, + }, + } + + expectedCalls := 0 + for _, c := range cases { + body, _ := json.Marshal(controllers.APIRequest{ + Repository: c.repository, + Ref: c.ref, + Type: c.vcsType, + PR: c.pr, + Projects: c.projects, + Paths: c.paths, + }) + + req, _ := http.NewRequest("POST", "", bytes.NewBuffer(body)) + req.Header.Set(atlantisTokenHeader, atlantisToken) + w := httptest.NewRecorder() + ac.Apply(w, req) + ResponseContains(t, w, http.StatusOK, "") + + expectedCalls += len(c.projects) + expectedCalls += len(c.paths) + } + + projectCommandBuilder.VerifyWasCalled(Times(expectedCalls)).BuildApplyCommands(Any[*command.Context](), Any[*events.CommentCommand]()) + projectCommandRunner.VerifyWasCalled(Times(expectedCalls)).Plan(Any[command.ProjectContext]()) + projectCommandRunner.VerifyWasCalled(Times(expectedCalls)).Apply(Any[command.ProjectContext]()) } func setup(t *testing.T) (controllers.APIController, *MockProjectCommandBuilder, *MockProjectCommandRunner) { RegisterMockTestingT(t) locker := NewMockLocker() logger := logging.NewNoopLogger(t) - scope, _, _ := metrics.NewLoggingScope(logger, "null") parser := NewMockEventParsing() - vcsClient := NewMockClient() repoAllowlistChecker, err := events.NewRepoAllowlistChecker("*") + scope, _, _ := metrics.NewLoggingScope(logger, "null") + vcsClient := NewMockClient() + workingDir := NewMockWorkingDir() Ok(t, err) + workingDirLocker := NewMockWorkingDirLocker() + When(workingDirLocker.TryLock(Any[string](), Any[int](), Eq(events.DefaultWorkspace), Eq(events.DefaultRepoRelDir))). + ThenReturn(func() {}, nil) + projectCommandBuilder := NewMockProjectCommandBuilder() When(projectCommandBuilder.BuildPlanCommands(Any[*command.Context](), Any[*events.CommentCommand]())). ThenReturn([]command.ProjectContext{{ @@ -111,6 +256,8 @@ func setup(t *testing.T) (controllers.APIController, *MockProjectCommandBuilder, PostWorkflowHooksCommandRunner: postWorkflowHooksCommandRunner, VCSClient: vcsClient, RepoAllowlistChecker: repoAllowlistChecker, + WorkingDir: workingDir, + WorkingDirLocker: workingDirLocker, CommitStatusUpdater: commitStatusUpdater, } return ac, projectCommandBuilder, projectCommandRunner diff --git a/server/server.go b/server/server.go index 97363fcdf3..966e1cfa39 100644 --- a/server/server.go +++ b/server/server.go @@ -956,6 +956,8 @@ func NewServer(userConfig UserConfig, config Config) (*Server, error) { RepoAllowlistChecker: repoAllowlist, Scope: statsScope.SubScope("api"), VCSClient: vcsClient, + WorkingDir: workingDir, + WorkingDirLocker: workingDirLocker, } eventsController := &events_controllers.VCSEventsController{