-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
⚡ fetch github org repositories in parallel (#4970)
Introducing an internal go package called `workerpool` that we can use to send parallel requests when needed. :zap: For this change, I am making the fetching of repositories for an organization faster. Tested this code with an **organization that has around 3k repositories** ### Before (~2 Minutes) ``` TRC logger.FuncDur> func=provider.github.repositories took=102803.621667 ``` ### After (~5 seconds) ``` TRC logger.FuncDur> func=provider.github.repositories took=4567.576542 ``` * :zap: fetch org repositories in parallel * ⚙️ add a collector to the workerpool This will help us submit as many requests as we want without knowing about the workers. * :rotating_light: fix race conditions --------- Signed-off-by: Salim Afiune Maya <afiune@mondoo.com>
- Loading branch information
Showing
15 changed files
with
469 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
// Copyright (c) Mondoo, Inc. | ||
// SPDX-License-Identifier: BUSL-1.1 | ||
|
||
package workerpool | ||
|
||
import ( | ||
"sync" | ||
"sync/atomic" | ||
) | ||
|
||
type collector[R any] struct { | ||
resultsCh <-chan R | ||
results []R | ||
read sync.Mutex | ||
|
||
errorsCh <-chan error | ||
errors []error | ||
|
||
requestsRead int64 | ||
} | ||
|
||
func (c *collector[R]) start() { | ||
go func() { | ||
for { | ||
select { | ||
case result := <-c.resultsCh: | ||
c.read.Lock() | ||
c.results = append(c.results, result) | ||
c.read.Unlock() | ||
|
||
case err := <-c.errorsCh: | ||
c.read.Lock() | ||
c.errors = append(c.errors, err) | ||
c.read.Unlock() | ||
} | ||
|
||
atomic.AddInt64(&c.requestsRead, 1) | ||
} | ||
}() | ||
} | ||
func (c *collector[R]) GetResults() []R { | ||
c.read.Lock() | ||
defer c.read.Unlock() | ||
return c.results | ||
} | ||
|
||
func (c *collector[R]) GetErrors() []error { | ||
c.read.Lock() | ||
defer c.read.Unlock() | ||
return c.errors | ||
} | ||
|
||
func (c *collector[R]) RequestsRead() int64 { | ||
return atomic.LoadInt64(&c.requestsRead) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
// Copyright (c) Mondoo, Inc. | ||
// SPDX-License-Identifier: BUSL-1.1 | ||
|
||
package workerpool | ||
|
||
import ( | ||
"sync" | ||
"sync/atomic" | ||
"time" | ||
|
||
"github.com/cockroachdb/errors" | ||
) | ||
|
||
type Task[R any] func() (result R, err error) | ||
|
||
// Pool is a generic pool of workers. | ||
type Pool[R any] struct { | ||
queueCh chan Task[R] | ||
resultsCh chan R | ||
errorsCh chan error | ||
|
||
requestsSent int64 | ||
once sync.Once | ||
|
||
workers []*worker[R] | ||
workerCount int | ||
|
||
collector[R] | ||
} | ||
|
||
// New initializes a new Pool with the provided number of workers. The pool is generic and can | ||
// accept any type of Task that returns the signature `func() (R, error)`. | ||
// | ||
// For example, a Pool[int] will accept Tasks similar to: | ||
// | ||
// task := func() (int, error) { | ||
// return 42, nil | ||
// } | ||
func New[R any](count int) *Pool[R] { | ||
resultsCh := make(chan R) | ||
errorsCh := make(chan error) | ||
return &Pool[R]{ | ||
queueCh: make(chan Task[R]), | ||
resultsCh: resultsCh, | ||
errorsCh: errorsCh, | ||
workerCount: count, | ||
collector: collector[R]{resultsCh: resultsCh, errorsCh: errorsCh}, | ||
} | ||
} | ||
|
||
// Start the pool workers and collector. Make sure call `Close()` to clear the pool. | ||
// | ||
// pool := workerpool.New[int](10) | ||
// pool.Start() | ||
// defer pool.Close() | ||
func (p *Pool[R]) Start() { | ||
p.once.Do(func() { | ||
for i := 0; i < p.workerCount; i++ { | ||
w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh, errorsCh: p.errorsCh} | ||
w.start() | ||
p.workers = append(p.workers, &w) | ||
} | ||
|
||
p.collector.start() | ||
}) | ||
} | ||
|
||
// Submit sends a task to the workers | ||
func (p *Pool[R]) Submit(t Task[R]) { | ||
p.queueCh <- t | ||
atomic.AddInt64(&p.requestsSent, 1) | ||
} | ||
|
||
// GetErrors returns any error from a processed task | ||
func (p *Pool[R]) GetErrors() error { | ||
return errors.Join(p.collector.GetErrors()...) | ||
} | ||
|
||
// GetResults returns the tasks results. | ||
// | ||
// It is recommended to call `Wait()` before reading the results. | ||
func (p *Pool[R]) GetResults() []R { | ||
return p.collector.GetResults() | ||
} | ||
|
||
// Close waits for workers and collector to process all the requests, and then closes | ||
// the task queue channel. After closing the pool, calling `Submit()` will panic. | ||
func (p *Pool[R]) Close() { | ||
p.Wait() | ||
close(p.queueCh) | ||
} | ||
|
||
// Wait waits until all tasks have been processed. | ||
func (p *Pool[R]) Wait() { | ||
ticker := time.NewTicker(100 * time.Millisecond) | ||
for { | ||
if !p.Processing() { | ||
return | ||
} | ||
<-ticker.C | ||
} | ||
} | ||
|
||
// PendingRequests returns the number of pending requests. | ||
func (p *Pool[R]) PendingRequests() int64 { | ||
return atomic.LoadInt64(&p.requestsSent) - p.collector.RequestsRead() | ||
} | ||
|
||
// Processing return true if tasks are being processed. | ||
func (p *Pool[R]) Processing() bool { | ||
return p.PendingRequests() != 0 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
// Copyright (c) Mondoo, Inc. | ||
// SPDX-License-Identifier: BUSL-1.1 | ||
|
||
package workerpool_test | ||
|
||
import ( | ||
"errors" | ||
"testing" | ||
"time" | ||
|
||
"math/rand" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"go.mondoo.com/cnquery/v11/internal/workerpool" | ||
) | ||
|
||
func TestPoolSubmitAndRetrieveResult(t *testing.T) { | ||
pool := workerpool.New[int](2) | ||
pool.Start() | ||
defer pool.Close() | ||
|
||
task := func() (int, error) { | ||
return 42, nil | ||
} | ||
|
||
// no results | ||
assert.Empty(t, pool.GetResults()) | ||
|
||
// submit a request | ||
pool.Submit(task) | ||
|
||
// wait for the request to process | ||
pool.Wait() | ||
|
||
// should have one result | ||
results := pool.GetResults() | ||
if assert.Len(t, results, 1) { | ||
assert.Equal(t, 42, results[0]) | ||
} | ||
|
||
// no errors | ||
assert.Nil(t, pool.GetErrors()) | ||
} | ||
|
||
func TestPoolHandleErrors(t *testing.T) { | ||
pool := workerpool.New[int](5) | ||
pool.Start() | ||
defer pool.Close() | ||
|
||
// submit a task that will return an error | ||
task := func() (int, error) { | ||
return 0, errors.New("task error") | ||
} | ||
pool.Submit(task) | ||
|
||
// Wait for error collector to process | ||
pool.Wait() | ||
|
||
err := pool.GetErrors() | ||
if assert.Error(t, err) { | ||
assert.Contains(t, err.Error(), "task error") | ||
} | ||
} | ||
|
||
func TestPoolMultipleTasksWithErrors(t *testing.T) { | ||
type test struct { | ||
data int | ||
} | ||
pool := workerpool.New[*test](5) | ||
pool.Start() | ||
defer pool.Close() | ||
|
||
tasks := []workerpool.Task[*test]{ | ||
func() (*test, error) { return &test{1}, nil }, | ||
func() (*test, error) { return &test{2}, nil }, | ||
func() (*test, error) { | ||
return nil, errors.New("task error") | ||
}, | ||
func() (*test, error) { return &test{3}, nil }, | ||
} | ||
|
||
for _, task := range tasks { | ||
pool.Submit(task) | ||
} | ||
|
||
// Wait for error collector to process | ||
pool.Wait() | ||
|
||
results := pool.GetResults() | ||
assert.ElementsMatch(t, []*test{&test{1}, &test{2}, &test{3}}, results) | ||
err := pool.GetErrors() | ||
if assert.Error(t, err) { | ||
assert.Contains(t, err.Error(), "task error") | ||
} | ||
} | ||
|
||
func TestPoolHandlesNilTasks(t *testing.T) { | ||
pool := workerpool.New[int](2) | ||
pool.Start() | ||
defer pool.Close() | ||
|
||
var nilTask workerpool.Task[int] | ||
pool.Submit(nilTask) | ||
|
||
pool.Wait() | ||
|
||
err := pool.GetErrors() | ||
assert.NoError(t, err) | ||
} | ||
|
||
func TestPoolProcessing(t *testing.T) { | ||
pool := workerpool.New[int](2) | ||
pool.Start() | ||
defer pool.Close() | ||
|
||
task := func() (int, error) { | ||
time.Sleep(50 * time.Millisecond) | ||
return 10, nil | ||
} | ||
|
||
pool.Submit(task) | ||
|
||
// should be processing | ||
assert.True(t, pool.Processing()) | ||
|
||
// wait | ||
pool.Wait() | ||
|
||
// read results | ||
result := pool.GetResults() | ||
assert.Equal(t, []int{10}, result) | ||
|
||
// should not longer be processing | ||
assert.False(t, pool.Processing()) | ||
} | ||
|
||
func TestPoolClosesGracefully(t *testing.T) { | ||
pool := workerpool.New[int](1) | ||
pool.Start() | ||
|
||
task := func() (int, error) { | ||
time.Sleep(100 * time.Millisecond) | ||
return 42, nil | ||
} | ||
|
||
pool.Submit(task) | ||
|
||
pool.Close() | ||
|
||
// Ensure no panic occurs and channels are closed | ||
assert.PanicsWithError(t, "send on closed channel", func() { | ||
pool.Submit(task) | ||
}) | ||
} | ||
|
||
func TestPoolWithManyTasks(t *testing.T) { | ||
// 30k requests with a pool of 100 workers | ||
// should be around 15 seconds | ||
requestCount := 30000 | ||
pool := workerpool.New[int](100) | ||
pool.Start() | ||
defer pool.Close() | ||
|
||
task := func() (int, error) { | ||
random := rand.Intn(100) | ||
time.Sleep(time.Duration(random) * time.Millisecond) | ||
return random, nil | ||
} | ||
|
||
for i := 0; i < requestCount; i++ { | ||
pool.Submit(task) | ||
} | ||
|
||
// should be processing | ||
assert.True(t, pool.Processing()) | ||
|
||
// wait | ||
pool.Wait() | ||
|
||
// read results | ||
assert.Equal(t, requestCount, len(pool.GetResults())) | ||
|
||
// should not longer be processing | ||
assert.False(t, pool.Processing()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
// Copyright (c) Mondoo, Inc. | ||
// SPDX-License-Identifier: BUSL-1.1 | ||
|
||
package workerpool | ||
|
||
type worker[R any] struct { | ||
id int | ||
queueCh <-chan Task[R] | ||
resultsCh chan<- R | ||
errorsCh chan<- error | ||
} | ||
|
||
func (w *worker[R]) start() { | ||
go func() { | ||
for task := range w.queueCh { | ||
if task == nil { | ||
// let the collector know we processed the request | ||
w.errorsCh <- nil | ||
continue | ||
} | ||
|
||
data, err := task() | ||
if err != nil { | ||
w.errorsCh <- err | ||
} else { | ||
w.resultsCh <- data | ||
} | ||
} | ||
}() | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Oops, something went wrong.