diff --git a/server/controllers/events/events_controller_e2e_test.go b/server/controllers/events/events_controller_e2e_test.go index cd844fb33..4a1320eb2 100644 --- a/server/controllers/events/events_controller_e2e_test.go +++ b/server/controllers/events/events_controller_e2e_test.go @@ -1059,6 +1059,7 @@ func setupE2E(t *testing.T, repoDir string) (events_controllers.VCSEventsControl GitlabMergeRequestGetter: e2eGitlabGetter, Logger: logger, StatsScope: statsScope, + GlobalCfg: globalCfg, AllowForkPRs: allowForkPRs, AllowForkPRsFlag: "allow-fork-prs", CommentCommandRunnerByCmd: commentCommandRunnerByCmd, diff --git a/server/events/command_runner.go b/server/events/command_runner.go index 8028a1cfd..66274bdd2 100644 --- a/server/events/command_runner.go +++ b/server/events/command_runner.go @@ -24,6 +24,7 @@ import ( "github.com/runatlantis/atlantis/server/events/metrics" "github.com/runatlantis/atlantis/server/events/models" "github.com/runatlantis/atlantis/server/events/vcs" + "github.com/runatlantis/atlantis/server/events/yaml/valid" "github.com/runatlantis/atlantis/server/feature" "github.com/runatlantis/atlantis/server/logging" "github.com/runatlantis/atlantis/server/recovery" @@ -99,6 +100,7 @@ type DefaultCommandRunner struct { EventParser EventParsing Logger logging.SimpleLogging StatsScope stats.Scope + GlobalCfg valid.GlobalCfg // AllowForkPRs controls whether we operate on pull requests from forks. AllowForkPRs bool // ParallelPoolSize controls the size of the wait group used to run @@ -339,6 +341,13 @@ func (c *DefaultCommandRunner) validateCtxAndComment(ctx *CommandContext) bool { } return false } + + repo := c.GlobalCfg.MatchingRepo(ctx.Pull.BaseRepo.ID()) + if !repo.BranchMatches(ctx.Pull.BaseBranch) { + ctx.Log.Info("command was run on a pull request which doesn't match base branches") + // just ignore it to allow us to use any git workflows without malicious intentions. + return false + } return true } diff --git a/server/events/command_runner_test.go b/server/events/command_runner_test.go index 1068ff1ac..f38b7c721 100644 --- a/server/events/command_runner_test.go +++ b/server/events/command_runner_test.go @@ -16,6 +16,7 @@ package events_test import ( "errors" "fmt" + "regexp" "strings" "testing" @@ -193,6 +194,7 @@ func setup(t *testing.T) *vcsmocks.MockClient { When(preWorkflowHooksCommandRunner.RunPreHooks(matchers.AnyPtrToEventsCommandContext())).ThenReturn(nil) scope := stats.NewDefaultStore() + globalCfg := valid.NewGlobalCfgFromArgs(valid.GlobalCfgArgs{}) ch = events.DefaultCommandRunner{ VCSClient: vcsClient, @@ -203,6 +205,7 @@ func setup(t *testing.T) *vcsmocks.MockClient { AzureDevopsPullGetter: azuredevopsGetter, Logger: logger, StatsScope: scope, + GlobalCfg: globalCfg, AllowForkPRs: false, AllowForkPRsFlag: "allow-fork-prs-flag", Drainer: drainer, @@ -457,6 +460,40 @@ func TestRunCommentCommand_ClosedPull(t *testing.T) { vcsClient.VerifyWasCalledOnce().CreateComment(fixtures.GithubRepo, modelPull.Num, "Atlantis commands can't be run on closed pull requests", "") } +func TestRunCommentCommand_MatchedBranch(t *testing.T) { + t.Log("if a command is run on a pull request which matches base branches run plan successfully") + vcsClient := setup(t) + + ch.GlobalCfg.Repos = append(ch.GlobalCfg.Repos, valid.Repo{ + IDRegex: regexp.MustCompile(".*"), + BranchRegex: regexp.MustCompile("^main$"), + }) + var pull github.PullRequest + modelPull := models.PullRequest{BaseRepo: fixtures.GithubRepo, BaseBranch: "main"} + When(githubGetter.GetPullRequest(fixtures.GithubRepo, fixtures.Pull.Num)).ThenReturn(&pull, nil) + When(eventParsing.ParseGithubPull(&pull)).ThenReturn(modelPull, modelPull.BaseRepo, fixtures.GithubRepo, nil) + + ch.RunCommentCommand(fixtures.GithubRepo, nil, nil, fixtures.User, fixtures.Pull.Num, &events.CommentCommand{Name: models.PlanCommand}) + vcsClient.VerifyWasCalledOnce().CreateComment(fixtures.GithubRepo, modelPull.Num, "Ran Plan for 0 projects:\n\n\n\n", "plan") +} + +func TestRunCommentCommand_UnmatchedBranch(t *testing.T) { + t.Log("if a command is run on a pull request which doesn't match base branches do not comment with error") + vcsClient := setup(t) + + ch.GlobalCfg.Repos = append(ch.GlobalCfg.Repos, valid.Repo{ + IDRegex: regexp.MustCompile(".*"), + BranchRegex: regexp.MustCompile("^main$"), + }) + var pull github.PullRequest + modelPull := models.PullRequest{BaseRepo: fixtures.GithubRepo, BaseBranch: "foo"} + When(githubGetter.GetPullRequest(fixtures.GithubRepo, fixtures.Pull.Num)).ThenReturn(&pull, nil) + When(eventParsing.ParseGithubPull(&pull)).ThenReturn(modelPull, modelPull.BaseRepo, fixtures.GithubRepo, nil) + + ch.RunCommentCommand(fixtures.GithubRepo, nil, nil, fixtures.User, fixtures.Pull.Num, &events.CommentCommand{Name: models.PlanCommand}) + vcsClient.VerifyWasCalled(Never()).CreateComment(matchers.AnyModelsRepo(), AnyInt(), AnyString(), AnyString()) +} + func TestRunUnlockCommand_VCSComment(t *testing.T) { t.Log("if unlock PR command is run, atlantis should" + " invoke the delete command and comment on PR accordingly") diff --git a/server/events/pre_workflow_hooks_command_runner.go b/server/events/pre_workflow_hooks_command_runner.go index 1f110db1d..7619fd9a8 100644 --- a/server/events/pre_workflow_hooks_command_runner.go +++ b/server/events/pre_workflow_hooks_command_runner.go @@ -34,7 +34,7 @@ func (w *DefaultPreWorkflowHooksCommandRunner) RunPreHooks( preWorkflowHooks := make([]*valid.PreWorkflowHook, 0) for _, repo := range w.GlobalCfg.Repos { - if repo.IDMatches(baseRepo.ID()) && repo.BranchMatches(pull.BaseBranch) && len(repo.PreWorkflowHooks) > 0 { + if repo.IDMatches(baseRepo.ID()) && len(repo.PreWorkflowHooks) > 0 { preWorkflowHooks = append(preWorkflowHooks, repo.PreWorkflowHooks...) } } diff --git a/server/events/yaml/valid/global_cfg.go b/server/events/yaml/valid/global_cfg.go index 6b887680a..2290a2765 100644 --- a/server/events/yaml/valid/global_cfg.go +++ b/server/events/yaml/valid/global_cfg.go @@ -472,3 +472,15 @@ func (g GlobalCfg) getMatchingCfg(log logging.SimpleLogging, repoID string) (app } return } + +// MatchingRepo returns an instance of Repo which matches a given repoID. +// If multiple repos match, return the last one for consistency with getMatchingCfg. +func (g GlobalCfg) MatchingRepo(repoID string) *Repo { + for i := len(g.Repos) - 1; i >= 0; i-- { + repo := g.Repos[i] + if repo.IDMatches(repoID) { + return &repo + } + } + return nil +} diff --git a/server/events/yaml/valid/global_cfg_test.go b/server/events/yaml/valid/global_cfg_test.go index 475cf7dfa..f317a8101 100644 --- a/server/events/yaml/valid/global_cfg_test.go +++ b/server/events/yaml/valid/global_cfg_test.go @@ -892,6 +892,69 @@ func TestRepo_BranchMatches(t *testing.T) { Equals(t, false, (valid.Repo{BranchRegex: regexp.MustCompile("release")}).BranchMatches("main")) } +func TestGlobalCfg_MatchingRepo(t *testing.T) { + defaultRepo := valid.Repo{ + IDRegex: regexp.MustCompile(".*"), + BranchRegex: regexp.MustCompile(".*"), + ApplyRequirements: []string{}, + } + repo1 := valid.Repo{ + IDRegex: regexp.MustCompile(".*"), + BranchRegex: regexp.MustCompile("^main$"), + ApplyRequirements: []string{"approved"}, + } + repo2 := valid.Repo{ + ID: "github.com/owner/repo", + BranchRegex: regexp.MustCompile("^master$"), + ApplyRequirements: []string{"approved", "mergeable"}, + } + + cases := map[string]struct { + gCfg valid.GlobalCfg + repoID string + exp *valid.Repo + }{ + "matches to default": { + gCfg: valid.GlobalCfg{ + Repos: []valid.Repo{ + defaultRepo, + repo2, + }, + }, + repoID: "foo", + exp: &defaultRepo, + }, + "matches to IDRegex": { + gCfg: valid.GlobalCfg{ + Repos: []valid.Repo{ + defaultRepo, + repo1, + repo2, + }, + }, + repoID: "foo", + exp: &repo1, + }, + "matches to ID": { + gCfg: valid.GlobalCfg{ + Repos: []valid.Repo{ + defaultRepo, + repo1, + repo2, + }, + }, + repoID: "github.com/owner/repo", + exp: &repo2, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + Equals(t, c.exp, c.gCfg.MatchingRepo(c.repoID)) + }) + } +} + // String is a helper routine that allocates a new string value // to store v and returns a pointer to it. func String(v string) *string { return &v } diff --git a/server/server.go b/server/server.go index 8185bc056..551ea4c4e 100644 --- a/server/server.go +++ b/server/server.go @@ -702,6 +702,7 @@ func NewServer(userConfig UserConfig, config Config) (*Server, error) { EventParser: eventParser, Logger: logger, StatsScope: statsScope.Scope("cmd"), + GlobalCfg: globalCfg, AllowForkPRs: userConfig.AllowForkPRs, AllowForkPRsFlag: config.AllowForkPRsFlag, SilenceForkPRErrors: userConfig.SilenceForkPRErrors,