Skip to content

Commit

Permalink
fix: ensure cloning workingdir before calling plan api (#3584)
Browse files Browse the repository at this point in the history
Signed-off-by: Hajime Terasawa <terako.studio@gmail.com>
Co-authored-by: PePe Amengual <2208324+jamengual@users.noreply.github.com>
  • Loading branch information
terakoya76 and jamengual authored Feb 5, 2025
1 parent 38ba386 commit 766dac9
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 31 deletions.
39 changes: 37 additions & 2 deletions server/controllers/api_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ type APIController struct {
RepoAllowlistChecker *events.RepoAllowlistChecker
Scope tally.Scope
VCSClient vcs.Client
WorkingDir events.WorkingDir
WorkingDirLocker events.WorkingDirLocker
CommitStatusUpdater events.CommitStatusUpdater
}

Expand Down Expand Up @@ -91,12 +93,18 @@ func (a *APIController) Plan(w http.ResponseWriter, r *http.Request) {
return
}

err = a.apiSetup(ctx)
if err != nil {
a.apiReportError(w, http.StatusInternalServerError, err)
return
}

result, err := a.apiPlan(request, ctx)
if err != nil {
a.apiReportError(w, http.StatusInternalServerError, err)
return
}
defer a.Locker.UnlockByPull(ctx.HeadRepo.FullName, 0) // nolint: errcheck
defer a.Locker.UnlockByPull(ctx.HeadRepo.FullName, ctx.Pull.Num) // nolint: errcheck
if result.HasErrors() {
code = http.StatusInternalServerError
}
Expand All @@ -119,13 +127,19 @@ func (a *APIController) Apply(w http.ResponseWriter, r *http.Request) {
return
}

err = a.apiSetup(ctx)
if err != nil {
a.apiReportError(w, http.StatusInternalServerError, err)
return
}

// We must first make the plan for all projects
_, err = a.apiPlan(request, ctx)
if err != nil {
a.apiReportError(w, http.StatusInternalServerError, err)
return
}
defer a.Locker.UnlockByPull(ctx.HeadRepo.FullName, 0) // nolint: errcheck
defer a.Locker.UnlockByPull(ctx.HeadRepo.FullName, ctx.Pull.Num) // nolint: errcheck

// We can now prepare and run the apply step
result, err := a.apiApply(request, ctx)
Expand All @@ -145,6 +159,27 @@ func (a *APIController) Apply(w http.ResponseWriter, r *http.Request) {
a.respond(w, logging.Warn, code, "%s", string(response))
}

func (a *APIController) apiSetup(ctx *command.Context) error {
pull := ctx.Pull
baseRepo := ctx.Pull.BaseRepo
headRepo := ctx.HeadRepo

unlockFn, err := a.WorkingDirLocker.TryLock(baseRepo.FullName, pull.Num, events.DefaultWorkspace, events.DefaultRepoRelDir)
if err != nil {
return err
}
ctx.Log.Debug("got workspace lock")
defer unlockFn()

// ensure workingDir is present
_, _, err = a.WorkingDir.Clone(ctx.Log, headRepo, pull, events.DefaultWorkspace)
if err != nil {
return err
}

return nil
}

func (a *APIController) apiPlan(request *APIRequest, ctx *command.Context) (*command.Result, error) {
cmds, cc, err := request.getCommands(ctx, a.ProjectCommandBuilder.BuildPlanCommands)
if err != nil {
Expand Down
205 changes: 176 additions & 29 deletions server/controllers/api_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,49 +25,194 @@ const atlantisToken = "token"

func TestAPIController_Plan(t *testing.T) {
ac, projectCommandBuilder, projectCommandRunner := setup(t)
body, _ := json.Marshal(controllers.APIRequest{
Repository: "Repo",
Ref: "main",
Type: "Gitlab",
Projects: []string{"default"},
})
req, _ := http.NewRequest("POST", "", bytes.NewBuffer(body))
req.Header.Set(atlantisTokenHeader, atlantisToken)
w := httptest.NewRecorder()
ac.Plan(w, req)
ResponseContains(t, w, http.StatusOK, "")
projectCommandBuilder.VerifyWasCalledOnce().BuildPlanCommands(Any[*command.Context](), Any[*events.CommentCommand]())
projectCommandRunner.VerifyWasCalledOnce().Plan(Any[command.ProjectContext]())

cases := []struct {
repository string
ref string
vcsType string
pr int
projects []string
paths []struct {
Directory string
Workspace string
}
}{
{
repository: "Repo",
ref: "main",
vcsType: "Gitlab",
projects: []string{"default"},
},
{
repository: "Repo",
ref: "main",
vcsType: "Gitlab",
pr: 1,
},
{
repository: "Repo",
ref: "main",
vcsType: "Gitlab",
paths: []struct {
Directory string
Workspace string
}{
{
Directory: ".",
Workspace: "myworkspace",
},
{
Directory: "./myworkspace2",
Workspace: "myworkspace2",
},
},
},
{
repository: "Repo",
ref: "main",
vcsType: "Gitlab",
pr: 1,
projects: []string{"test"},
paths: []struct {
Directory string
Workspace string
}{
{
Directory: ".",
Workspace: "myworkspace",
},
},
},
}

expectedCalls := 0
for _, c := range cases {
body, _ := json.Marshal(controllers.APIRequest{
Repository: c.repository,
Ref: c.ref,
Type: c.vcsType,
PR: c.pr,
Projects: c.projects,
Paths: c.paths,
})

req, _ := http.NewRequest("POST", "", bytes.NewBuffer(body))
req.Header.Set(atlantisTokenHeader, atlantisToken)
w := httptest.NewRecorder()
ac.Plan(w, req)
ResponseContains(t, w, http.StatusOK, "")

expectedCalls += len(c.projects)
expectedCalls += len(c.paths)
}

projectCommandBuilder.VerifyWasCalled(Times(expectedCalls)).BuildPlanCommands(Any[*command.Context](), Any[*events.CommentCommand]())
projectCommandRunner.VerifyWasCalled(Times(expectedCalls)).Plan(Any[command.ProjectContext]())
}

func TestAPIController_Apply(t *testing.T) {
ac, projectCommandBuilder, projectCommandRunner := setup(t)
body, _ := json.Marshal(controllers.APIRequest{
Repository: "Repo",
Ref: "main",
Type: "Gitlab",
Projects: []string{"default"},
})
req, _ := http.NewRequest("POST", "", bytes.NewBuffer(body))
req.Header.Set(atlantisTokenHeader, atlantisToken)
w := httptest.NewRecorder()
ac.Apply(w, req)
ResponseContains(t, w, http.StatusOK, "")
projectCommandBuilder.VerifyWasCalledOnce().BuildApplyCommands(Any[*command.Context](), Any[*events.CommentCommand]())
projectCommandRunner.VerifyWasCalledOnce().Plan(Any[command.ProjectContext]())
projectCommandRunner.VerifyWasCalledOnce().Apply(Any[command.ProjectContext]())

cases := []struct {
repository string
ref string
vcsType string
pr int
projects []string
paths []struct {
Directory string
Workspace string
}
}{
{
repository: "Repo",
ref: "main",
vcsType: "Gitlab",
projects: []string{"default"},
},
{
repository: "Repo",
ref: "main",
vcsType: "Gitlab",
pr: 1,
},
{
repository: "Repo",
ref: "main",
vcsType: "Gitlab",
paths: []struct {
Directory string
Workspace string
}{
{
Directory: ".",
Workspace: "myworkspace",
},
{
Directory: "./myworkspace2",
Workspace: "myworkspace2",
},
},
},
{
repository: "Repo",
ref: "main",
vcsType: "Gitlab",
pr: 1,
projects: []string{"test"},
paths: []struct {
Directory string
Workspace string
}{
{
Directory: ".",
Workspace: "myworkspace",
},
},
},
}

expectedCalls := 0
for _, c := range cases {
body, _ := json.Marshal(controllers.APIRequest{
Repository: c.repository,
Ref: c.ref,
Type: c.vcsType,
PR: c.pr,
Projects: c.projects,
Paths: c.paths,
})

req, _ := http.NewRequest("POST", "", bytes.NewBuffer(body))
req.Header.Set(atlantisTokenHeader, atlantisToken)
w := httptest.NewRecorder()
ac.Apply(w, req)
ResponseContains(t, w, http.StatusOK, "")

expectedCalls += len(c.projects)
expectedCalls += len(c.paths)
}

projectCommandBuilder.VerifyWasCalled(Times(expectedCalls)).BuildApplyCommands(Any[*command.Context](), Any[*events.CommentCommand]())
projectCommandRunner.VerifyWasCalled(Times(expectedCalls)).Plan(Any[command.ProjectContext]())
projectCommandRunner.VerifyWasCalled(Times(expectedCalls)).Apply(Any[command.ProjectContext]())
}

func setup(t *testing.T) (controllers.APIController, *MockProjectCommandBuilder, *MockProjectCommandRunner) {
RegisterMockTestingT(t)
locker := NewMockLocker()
logger := logging.NewNoopLogger(t)
scope, _, _ := metrics.NewLoggingScope(logger, "null")
parser := NewMockEventParsing()
vcsClient := NewMockClient()
repoAllowlistChecker, err := events.NewRepoAllowlistChecker("*")
scope, _, _ := metrics.NewLoggingScope(logger, "null")
vcsClient := NewMockClient()
workingDir := NewMockWorkingDir()
Ok(t, err)

workingDirLocker := NewMockWorkingDirLocker()
When(workingDirLocker.TryLock(Any[string](), Any[int](), Eq(events.DefaultWorkspace), Eq(events.DefaultRepoRelDir))).
ThenReturn(func() {}, nil)

projectCommandBuilder := NewMockProjectCommandBuilder()
When(projectCommandBuilder.BuildPlanCommands(Any[*command.Context](), Any[*events.CommentCommand]())).
ThenReturn([]command.ProjectContext{{
Expand Down Expand Up @@ -111,6 +256,8 @@ func setup(t *testing.T) (controllers.APIController, *MockProjectCommandBuilder,
PostWorkflowHooksCommandRunner: postWorkflowHooksCommandRunner,
VCSClient: vcsClient,
RepoAllowlistChecker: repoAllowlistChecker,
WorkingDir: workingDir,
WorkingDirLocker: workingDirLocker,
CommitStatusUpdater: commitStatusUpdater,
}
return ac, projectCommandBuilder, projectCommandRunner
Expand Down
2 changes: 2 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,8 @@ func NewServer(userConfig UserConfig, config Config) (*Server, error) {
RepoAllowlistChecker: repoAllowlist,
Scope: statsScope.SubScope("api"),
VCSClient: vcsClient,
WorkingDir: workingDir,
WorkingDirLocker: workingDirLocker,
}

eventsController := &events_controllers.VCSEventsController{
Expand Down

0 comments on commit 766dac9

Please sign in to comment.