Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Goroutine pooling #31

Merged
merged 9 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/go.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
go-version-file: ./go.mod
- name: Lint
run: |
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s v1.55.1
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s v1.61.0
./bin/golangci-lint run --verbose

test-linux-race:
Expand Down
21 changes: 12 additions & 9 deletions engine_null.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ import (

"github.com/google/cel-go/common/operators"
"github.com/ohler55/ojg/jp"
"golang.org/x/sync/errgroup"
)

func newNullMatcher() MatchingEngine {
func newNullMatcher(concurrency int64) MatchingEngine {
return &nullLookup{
lock: &sync.RWMutex{},
paths: map[string]struct{}{},
null: map[string][]*StoredExpressionPart{},
not: map[string][]*StoredExpressionPart{},
lock: &sync.RWMutex{},
paths: map[string]struct{}{},
null: map[string][]*StoredExpressionPart{},
not: map[string][]*StoredExpressionPart{},
concurrency: concurrency,
}
}

Expand All @@ -26,6 +26,8 @@ type nullLookup struct {

null map[string][]*StoredExpressionPart
not map[string][]*StoredExpressionPart

concurrency int64
}

func (n *nullLookup) Type() EngineType {
Expand All @@ -35,11 +37,12 @@ func (n *nullLookup) Type() EngineType {
func (n *nullLookup) Match(ctx context.Context, data map[string]any) (matched []*StoredExpressionPart, err error) {
l := &sync.Mutex{}
matched = []*StoredExpressionPart{}
eg := errgroup.Group{}

pool := newErrPool(errPoolOpts{concurrency: n.concurrency})

for item := range n.paths {
path := item
eg.Go(func() error {
pool.Go(func() error {
x, err := jp.ParseString(path)
if err != nil {
return err
Expand All @@ -65,7 +68,7 @@ func (n *nullLookup) Match(ctx context.Context, data map[string]any) (matched []
})
}

return matched, eg.Wait()
return matched, pool.Wait()
}

func (n *nullLookup) Search(ctx context.Context, variable string, input any) (matched []*StoredExpressionPart) {
Expand Down
15 changes: 9 additions & 6 deletions engine_number.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ import (
"github.com/google/cel-go/common/operators"
"github.com/ohler55/ojg/jp"
"github.com/tidwall/btree"
"golang.org/x/sync/errgroup"
)

func newNumberMatcher() MatchingEngine {
func newNumberMatcher(concurrency int64) MatchingEngine {
return &numbers{
lock: &sync.RWMutex{},

paths: map[string]struct{}{},
paths: map[string]struct{}{},
concurrency: concurrency,

exact: btree.NewMap[float64, []*StoredExpressionPart](64),
gt: btree.NewMap[float64, []*StoredExpressionPart](64),
Expand All @@ -33,6 +33,8 @@ type numbers struct {
exact *btree.Map[float64, []*StoredExpressionPart]
gt *btree.Map[float64, []*StoredExpressionPart]
lt *btree.Map[float64, []*StoredExpressionPart]

concurrency int64
}

func (n numbers) Type() EngineType {
Expand All @@ -42,11 +44,12 @@ func (n numbers) Type() EngineType {
func (n *numbers) Match(ctx context.Context, input map[string]any) (matched []*StoredExpressionPart, err error) {
l := &sync.Mutex{}
matched = []*StoredExpressionPart{}
eg := errgroup.Group{}

pool := newErrPool(errPoolOpts{concurrency: n.concurrency})

for item := range n.paths {
path := item
eg.Go(func() error {
pool.Go(func() error {
x, err := jp.ParseString(path)
if err != nil {
return err
Expand Down Expand Up @@ -80,7 +83,7 @@ func (n *numbers) Match(ctx context.Context, input map[string]any) (matched []*S
})
}

return matched, eg.Wait()
return matched, pool.Wait()
}

// Search returns all ExpressionParts which match the given input, ignoring the variable name
Expand Down
4 changes: 3 additions & 1 deletion engine_number_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import (
"github.com/stretchr/testify/require"
)

const testConcurrency = 100

func TestEngineNumber(t *testing.T) {
ctx := context.Background()
n := newNumberMatcher().(*numbers)
n := newNumberMatcher(testConcurrency).(*numbers)

// int64
a := ExpressionPart{
Expand Down
23 changes: 13 additions & 10 deletions engine_stringmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import (
"github.com/cespare/xxhash/v2"
"github.com/google/cel-go/common/operators"
"github.com/ohler55/ojg/jp"
"golang.org/x/sync/errgroup"
)

func newStringEqualityMatcher() MatchingEngine {
func newStringEqualityMatcher(concurrency int64) MatchingEngine {
return &stringLookup{
lock: &sync.RWMutex{},
vars: map[string]struct{}{},
equality: variableMap{},
inequality: inequalityMap{},
lock: &sync.RWMutex{},
vars: map[string]struct{}{},
equality: variableMap{},
inequality: inequalityMap{},
concurrency: concurrency,
}
}

Expand Down Expand Up @@ -51,6 +51,8 @@ type stringLookup struct {
//
// this lets us quickly map neq in a fast manner
inequality inequalityMap

concurrency int64
}

func (s stringLookup) Type() EngineType {
Expand All @@ -61,12 +63,13 @@ func (n *stringLookup) Match(ctx context.Context, input map[string]any) ([]*Stor
l := &sync.Mutex{}

matched := []*StoredExpressionPart{}
eg := errgroup.Group{}

pool := newErrPool(errPoolOpts{concurrency: n.concurrency})

// First, handle equality matching.
for item := range n.vars {
path := item
eg.Go(func() error {
pool.Go(func() error {
x, err := jp.ParseString(path)
if err != nil {
return err
Expand All @@ -92,7 +95,7 @@ func (n *stringLookup) Match(ctx context.Context, input map[string]any) ([]*Stor
// Then, iterate through the inequality matches.
for item := range n.inequality {
path := item
eg.Go(func() error {
pool.Go(func() error {
x, err := jp.ParseString(path)
if err != nil {
return err
Expand All @@ -115,7 +118,7 @@ func (n *stringLookup) Match(ctx context.Context, input map[string]any) ([]*Stor
})
}

return matched, eg.Wait()
return matched, pool.Wait()
}

// Search returns all ExpressionParts which match the given input, ignoring the variable name
Expand Down
6 changes: 3 additions & 3 deletions engine_stringmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

func TestEngineStringmap(t *testing.T) {
ctx := context.Background()
s := newStringEqualityMatcher().(*stringLookup)
s := newStringEqualityMatcher(testConcurrency).(*stringLookup)

a := ExpressionPart{
Predicate: &Predicate{
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestEngineStringmap(t *testing.T) {

func TestEngineStringmap_DuplicateValues(t *testing.T) {
ctx := context.Background()
s := newStringEqualityMatcher().(*stringLookup)
s := newStringEqualityMatcher(testConcurrency).(*stringLookup)
a := ExpressionPart{
Predicate: &Predicate{
Ident: "async.data.var_a",
Expand Down Expand Up @@ -182,7 +182,7 @@ func TestEngineStringmap_DuplicateValues(t *testing.T) {

func TestEngineStringmap_DuplicateNeq(t *testing.T) {
ctx := context.Background()
s := newStringEqualityMatcher().(*stringLookup)
s := newStringEqualityMatcher(testConcurrency).(*stringLookup)
a := ExpressionPart{
Predicate: &Predicate{
Ident: "async.data.var_a",
Expand Down
52 changes: 22 additions & 30 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import (

"github.com/google/cel-go/common/operators"
"github.com/google/uuid"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
)

var (
Expand All @@ -19,6 +17,10 @@ var (
ErrExpressionPartNotFound = fmt.Errorf("expression part not found")
)

const (
defaultConcurrency = 1000
)

// errEngineUnimplemented is used while we develop the aggregate tree library when trees
// are not yet implemented.
var errEngineUnimplemented = fmt.Errorf("tree type unimplemented")
Expand Down Expand Up @@ -82,23 +84,23 @@ func NewAggregateEvaluator(
concurrency int64,
) AggregateEvaluator {
if concurrency <= 0 {
concurrency = 1
concurrency = defaultConcurrency
}

return &aggregator{
eval: eval,
parser: parser,
loader: evalLoader,
sem: semaphore.NewWeighted(concurrency),
engines: map[EngineType]MatchingEngine{
EngineTypeStringHash: newStringEqualityMatcher(),
EngineTypeNullMatch: newNullMatcher(),
EngineTypeBTree: newNumberMatcher(),
EngineTypeStringHash: newStringEqualityMatcher(concurrency),
EngineTypeNullMatch: newNullMatcher(concurrency),
EngineTypeBTree: newNumberMatcher(concurrency),
},
lock: &sync.RWMutex{},
evals: map[uuid.UUID]Evaluable{},
constants: map[uuid.UUID]struct{}{},
mixed: map[uuid.UUID]struct{}{},
lock: &sync.RWMutex{},
evals: map[uuid.UUID]Evaluable{},
constants: map[uuid.UUID]struct{}{},
mixed: map[uuid.UUID]struct{}{},
concurrency: concurrency,
}
}

Expand All @@ -110,8 +112,6 @@ type aggregator struct {
// engines records all engines
engines map[EngineType]MatchingEngine

sem *semaphore.Weighted

// lock prevents concurrent updates of data
lock *sync.RWMutex

Expand All @@ -131,6 +131,8 @@ type aggregator struct {
// constants tracks evaluable IDs that must always be evaluated, due to
// the expression containing non-aggregateable clauses.
constants map[uuid.UUID]struct{}

concurrency int64
}

// Len returns the total number of aggregateable and constantly matched expressions
Expand Down Expand Up @@ -171,7 +173,7 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu
s sync.Mutex
)

eg := errgroup.Group{}
napool := newErrPool(errPoolOpts{concurrency: a.concurrency})

a.lock.RLock()
for uuid := range a.constants {
Expand All @@ -180,14 +182,8 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu
continue
}

if err := a.sem.Acquire(ctx, 1); err != nil {
a.lock.RUnlock()
return result, matched, err
}

expr := item
eg.Go(func() error {
defer a.sem.Release(1)
napool.Go(func() error {
defer func() {
if r := recover(); r != nil {
s.Lock()
Expand Down Expand Up @@ -222,7 +218,7 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu
}
a.lock.RUnlock()

if werr := eg.Wait(); werr != nil {
if werr := napool.Wait(); werr != nil {
err = errors.Join(err, werr)
}

Expand All @@ -238,20 +234,16 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu
seenMu := &sync.Mutex{}
seen := map[uuid.UUID]struct{}{}

mpool := newErrPool(errPoolOpts{concurrency: a.concurrency})

a.lock.RLock()
for _, expr := range matches {
eval, ok := a.evals[expr.Parsed.EvaluableID]
if !ok || eval == nil {
continue
}

if err := a.sem.Acquire(ctx, 1); err != nil {
a.lock.RUnlock()
return result, matched, err
}

eg.Go(func() error {
defer a.sem.Release(1)
mpool.Go(func() error {
defer func() {
if r := recover(); r != nil {
s.Lock()
Expand Down Expand Up @@ -289,7 +281,7 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu
}
a.lock.RUnlock()

if werr := eg.Wait(); werr != nil {
if werr := mpool.Wait(); werr != nil {
err = errors.Join(err, werr)
}

Expand Down
7 changes: 5 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
module github.com/inngest/expr

go 1.21.0
go 1.23.2

require (
github.com/cespare/xxhash/v2 v2.2.0
github.com/google/cel-go v0.18.2
github.com/google/uuid v1.6.0
github.com/karlseguin/ccache/v2 v2.0.8
github.com/ohler55/ojg v1.21.0
github.com/sourcegraph/conc v0.3.0
github.com/stretchr/testify v1.8.4
github.com/tidwall/btree v1.7.0
golang.org/x/sync v0.6.0
google.golang.org/protobuf v1.33.0
)

require (
github.com/antlr4-go/antlr/v4 v4.13.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stoewer/go-strcase v1.2.0 // indirect
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc // indirect
golang.org/x/text v0.9.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5 // indirect
Expand Down
Loading