Skip to content

Commit

Permalink
Parameterize max goroutines
Browse files Browse the repository at this point in the history
Signed-off-by: Raghav Kaul <raghavkaul@google.com>
  • Loading branch information
raghavkaul committed Jan 4, 2024
1 parent 00e8917 commit 778e4ff
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
6 changes: 4 additions & 2 deletions cmd/allstar/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ func main() {
specificPolicyArg := flag.String("policy", "", fmt.Sprintf("Run a specific policy check. Supported policies: %s", supportedPoliciesMsg))
specificRepoArg := flag.String("repo", "", "Run on a specific \"owner/repo\". For example \"ossf/allstar\"")

numWorkersArg := flag.Int("workers", 5, "maximum number of active goroutines for Allstar scans")

flag.Parse()

if *specificPolicyArg != "" {
Expand All @@ -81,7 +83,7 @@ func main() {
}

if runOnce {
_, err := enforce.EnforceAll(ctx, ghc, *specificPolicyArg, *specificRepoArg)
_, err := enforce.EnforceAll(ctx, ghc, *specificPolicyArg, *specificRepoArg, *numWorkersArg)
if err != nil {
log.Fatal().
Err(err).
Expand All @@ -94,7 +96,7 @@ func main() {
go func() {
defer wg.Done()
log.Info().
Err(enforce.EnforceJob(ctx, ghc, (5 * time.Minute), *specificPolicyArg, *specificRepoArg)).
Err(enforce.EnforceJob(ctx, ghc, (5 * time.Minute), *specificPolicyArg, *specificRepoArg, *numWorkersArg)).
Msg("Enforce job shutting down.")
}()
sigs := make(chan os.Signal, 1)
Expand Down
8 changes: 4 additions & 4 deletions pkg/enforce/enforce.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func init() {
//
// TBD: determine if this should remain exported, or if it will only be called
// from EnforceJob.
func EnforceAll(ctx context.Context, ghc ghclients.GhClientsInterface, specificPolicyArg string, specificRepoArg string) (EnforceAllResults, error) {
func EnforceAll(ctx context.Context, ghc ghclients.GhClientsInterface, specificPolicyArg string, specificRepoArg string, numWorkersArg int) (EnforceAllResults, error) {
var repoCount int
var enforceAllResults = make(EnforceAllResults)
ac, err := ghc.Get(0)
Expand All @@ -85,7 +85,7 @@ func EnforceAll(ctx context.Context, ghc ghclients.GhClientsInterface, specificP
Msg("Enforcing policies on installations.")

g, ctx := errgroup.WithContext(ctx)
g.SetLimit(5)
g.SetLimit(numWorkersArg)
var mu sync.Mutex

for _, i := range insts {
Expand Down Expand Up @@ -302,9 +302,9 @@ func getAppInstallationReposReal(ctx context.Context, ic *github.Client) ([]*git

// EnforceJob is a reconciliation job that enforces policies on all repos every
// d duration. It runs forever until the context is done.
func EnforceJob(ctx context.Context, ghc *ghclients.GHClients, d time.Duration, specificPolicyArg string, specificRepoArg string) error {
func EnforceJob(ctx context.Context, ghc *ghclients.GHClients, d time.Duration, specificPolicyArg string, specificRepoArg string, numWorkersArg int) error {
for {
_, err := EnforceAll(ctx, ghc, specificPolicyArg, specificRepoArg)
_, err := EnforceAll(ctx, ghc, specificPolicyArg, specificRepoArg, numWorkersArg)
if err != nil {
log.Error().
Err(err).
Expand Down
8 changes: 5 additions & 3 deletions pkg/enforce/enforce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,8 @@ func TestEnforceAll(t *testing.T) {
policy1Results = test.Policy1Results
policy2Results = test.Policy2Results

enforceAllResults, err := EnforceAll(context.Background(), mockGhc, "", "")
numWorkers := 1
enforceAllResults, err := EnforceAll(context.Background(), mockGhc, "", "", numWorkers)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
Expand Down Expand Up @@ -581,15 +582,16 @@ func TestSuspendedEnforce(t *testing.T) {
}
suspended = false
gaicalled = false
if _, err := EnforceAll(context.Background(), &MockGhClients{}, "", ""); err != nil {
numWorkers := 1
if _, err := EnforceAll(context.Background(), &MockGhClients{}, "", "", numWorkers); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if !gaicalled {
t.Errorf("Expected getAppInstallationRepos() to be called, but wasn't")
}
suspended = true
gaicalled = false
if _, err := EnforceAll(context.Background(), &MockGhClients{}, "", ""); err != nil {
if _, err := EnforceAll(context.Background(), &MockGhClients{}, "", "", numWorkers); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if gaicalled {
Expand Down

0 comments on commit 778e4ff

Please sign in to comment.