Skip to content

Commit

Permalink
Merge pull request #8 from inngest/chore/refactor-group-ids
Browse files Browse the repository at this point in the history
Refactor group IDs and utilize group IDs to filter invalid expressions prior to evaluation
  • Loading branch information
tonyhb authored Jan 6, 2024
2 parents 921f6fc + 3014841 commit fbd7922
Show file tree
Hide file tree
Showing 7 changed files with 328 additions and 48 deletions.
76 changes: 59 additions & 17 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ func NewAggregateEvaluator(
}

// Evaluable is an expression that can be stored and evaluated by the
// aggregate evaluator. Implementations pair a raw expression string with a
// per-instance identifier so that duplicate expressions remain
// distinguishable (eg. two pauses sharing the same expression).
type Evaluable interface {
	// Identifier returns a unique identifier for the evaluable item. If there are
	// two instances of the same expression, the identifier should return a unique
	// string for each instance of the expression (eg. for two pauses).
	Identifier() string

	// Expression returns an expression as a raw string.
	Expression() string
}
Expand Down Expand Up @@ -137,12 +142,21 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu
// are added (eg. >= operators on strings), ensure that we find the correct number of matches
// for each group ID and then skip evaluating expressions if the number of matches is <= the group
// ID's length.
seen := map[groupID]struct{}{}

for _, match := range matches {
if _, ok := seen[match.GroupID]; ok {
continue
}

atomic.AddInt32(&matched, 1)
// NOTE: We don't need to add lifted expression variables,
// because match.Parsed.Evaluable() returns the original expression
// string.
ok, evalerr := a.eval(ctx, match.Parsed.Evaluable, data)

seen[match.GroupID] = struct{}{}

if evalerr != nil {
err = errors.Join(err, evalerr)
continue
Expand All @@ -161,6 +175,12 @@ func (a *aggregator) AggregateMatch(ctx context.Context, data map[string]any) ([
a.lock.RLock()
defer a.lock.RUnlock()

// Store the number of times each GroupID has found a match. We need at least
// as many matches as stored in the group ID to consider the match.
counts := map[groupID]int{}
// Store all expression parts per group ID for returning.
found := map[groupID][]ExpressionPart{}

// Iterate through all known variables/idents in the aggregate tree to see if
// the data has those keys set. If so, we can immediately evaluate the data with
// the tree.
Expand All @@ -179,16 +199,32 @@ func (a *aggregator) AggregateMatch(ctx context.Context, data map[string]any) ([

switch cast := res[0].(type) {
case string:
found, ok := tree.Search(ctx, cast)
all, ok := tree.Search(ctx, cast)
if !ok {
continue
}
result = append(result, found.Evals...)

for _, eval := range all.Evals {
counts[eval.GroupID] += 1
if _, ok := found[eval.GroupID]; !ok {
found[eval.GroupID] = []ExpressionPart{}
}
found[eval.GroupID] = append(found[eval.GroupID], eval)
}
default:
continue
}
}

for k, count := range counts {
if int(k.Size()) > count {
// The GroupID required more comparisons to equate to true than
// we had, so this could never evaluate to true. Skip this.
continue
}
result = append(result, found[k]...)
}

return result, nil
}

Expand Down Expand Up @@ -238,16 +274,26 @@ func (a *aggregator) addGroup(ctx context.Context, node *Node, parsed *ParsedExp
return false, nil
}

// Merge all of the nodes together and check whether each node is aggregateable.
all := append(node.Ands, node)
for _, n := range all {
if !n.HasPredicate() || len(n.Ors) > 0 {
// Don't handle sub-branching for now.
return false, nil
if len(node.Ands) > 0 {
for _, n := range node.Ands {
if !n.HasPredicate() || len(n.Ors) > 0 {
// Don't handle sub-branching for now.
return false, nil
}
if !isAggregateable(n) {
return false, nil
}
}
if !isAggregateable(n) {
}

all := node.Ands

if node.Predicate != nil {
if !isAggregateable(node) {
return false, nil
}
// Merge all of the nodes together and check whether each node is aggregateable.
all = append(node.Ands, node)
}

// Create a new group ID which tracks the number of expressions that must match
Expand All @@ -258,9 +304,8 @@ func (a *aggregator) addGroup(ctx context.Context, node *Node, parsed *ParsedExp
// When checking an incoming event, we match the event against each node's
// ident/variable. Using the group ID, we can see if we've matched N necessary
// items from the same identifier. If so, the evaluation is true.
groupID := newGroupID(uint16(len(all)))
for _, n := range all {
err := a.addNode(ctx, n, groupID, parsed)
err := a.addNode(ctx, n, parsed)
if err == errTreeUnimplemented {
return false, nil
}
Expand All @@ -272,7 +317,7 @@ func (a *aggregator) addGroup(ctx context.Context, node *Node, parsed *ParsedExp
return true, nil
}

func (a *aggregator) addNode(ctx context.Context, n *Node, gid groupID, parsed *ParsedExpression) error {
func (a *aggregator) addNode(ctx context.Context, n *Node, parsed *ParsedExpression) error {
// Don't allow anything to update in parallel. This enrues that Add() can be called
// concurrently.
a.lock.Lock()
Expand All @@ -286,7 +331,7 @@ func (a *aggregator) addNode(ctx context.Context, n *Node, gid groupID, parsed *
tree = newArtTree()
}
err := tree.Add(ctx, ExpressionPart{
GroupID: gid,
GroupID: n.GroupID,
Predicate: *n.Predicate,
Parsed: parsed,
})
Expand All @@ -302,14 +347,11 @@ func (a *aggregator) addNode(ctx context.Context, n *Node, gid groupID, parsed *
// Remove deletes an evaluable expression from the aggregate evaluator.
//
// NOTE(review): removal is not yet implemented.  The expression is still
// parsed so that malformed input surfaces a parse error before the
// not-implemented error is returned.
func (a *aggregator) Remove(ctx context.Context, eval Evaluable) error {
	// Parse the expression using our tree parser to validate it.
	if _, err := a.parser.Parse(ctx, eval); err != nil {
		return err
	}
	return fmt.Errorf("not implemented")
}

Expand Down
70 changes: 48 additions & 22 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,25 +137,7 @@ func TestEvaluate(t *testing.T) {

require.NoError(t, err)
require.EqualValues(t, 0, len(evals))
require.EqualValues(t, 1, matched) // We still ran one expression
})

t.Run("It handles matching on arrays of data", func(t *testing.T) {
pre := time.Now()
evals, matched, err := e.Evaluate(ctx, map[string]any{
"event": map[string]any{
"data": map[string]any{
"ids": []string{"a", "b", "c"},
},
},
})
total := time.Since(pre)
fmt.Printf("Matched in %v ns\n", total.Nanoseconds())
fmt.Printf("Matched in %v ms\n", total.Milliseconds())

require.NoError(t, err)
require.EqualValues(t, 0, len(evals))
require.EqualValues(t, 1, matched) // We still ran one expression
require.EqualValues(t, 0, matched) // We still ran one expression
})
}

Expand All @@ -170,7 +152,7 @@ func TestEvaluate_Concurrently(t *testing.T) {
require.NoError(t, err)

go func() {
for i := 0; i < 1_000; i++ {
for i := 0; i < 100_000; i++ {
//nolint:all
go func() {
byt := make([]byte, 8)
Expand Down Expand Up @@ -213,7 +195,7 @@ func TestEvaluate_ArrayIndexes(t *testing.T) {
require.NoError(t, err)
e := NewAggregateEvaluator(parser, testBoolEvaluator)

expected := tex(`event.data.ids[2] == "id-b"`)
expected := tex(`event.data.ids[1] == "id-b" && event.data.ids[2] == "id-c"`)
_, err = e.Add(ctx, expected)
require.NoError(t, err)

Expand Down Expand Up @@ -258,7 +240,7 @@ func TestEvaluate_ArrayIndexes(t *testing.T) {
evals, matched, err := e.Evaluate(ctx, map[string]any{
"event": map[string]any{
"data": map[string]any{
"ids": []string{"a", "yes", "id-b"},
"ids": []string{"id-a", "id-b", "id-c"},
},
},
})
Expand All @@ -272,6 +254,49 @@ func TestEvaluate_ArrayIndexes(t *testing.T) {
})
}

// TestEvaluate_Compound verifies that a three-predicate "&&" expression only
// matches when every predicate in its group is satisfied, and is skipped
// entirely (zero evaluations) when any predicate fails to match.
func TestEvaluate_Compound(t *testing.T) {
	ctx := context.Background()
	parser, err := NewTreeParser(NewCachingParser(newEnv(), nil))
	require.NoError(t, err)
	e := NewAggregateEvaluator(parser, testBoolEvaluator)

	expected := tex(`event.data.a == "ok" && event.data.b == "yes" && event.data.c == "please"`)
	ok, err := e.Add(ctx, expected)
	require.True(t, ok)
	require.NoError(t, err)

	t.Run("It matches items", func(t *testing.T) {
		// All three predicates are satisfied by this event.
		input := map[string]any{
			"event": map[string]any{
				"data": map[string]any{
					"a": "ok",
					"b": "yes",
					"c": "please",
				},
			},
		}
		evals, matched, err := e.Evaluate(ctx, input)
		require.NoError(t, err)
		require.EqualValues(t, 1, matched) // We only perform one eval
		require.EqualValues(t, []Evaluable{expected}, evals)
	})

	t.Run("It skips if less than the group length is found", func(t *testing.T) {
		// Only two of three predicates match, so the expression must not
		// be evaluated at all.
		input := map[string]any{
			"event": map[string]any{
				"data": map[string]any{
					"a": "ok",
					"b": "yes",
					"c": "no - no match",
				},
			},
		}
		evals, matched, err := e.Evaluate(ctx, input)
		require.NoError(t, err)
		require.EqualValues(t, 0, matched)
		require.EqualValues(t, []Evaluable{}, evals)
	})
}

func TestAggregateMatch(t *testing.T) {
ctx := context.Background()
parser, err := newParser()
Expand Down Expand Up @@ -402,6 +427,7 @@ func TestAdd(t *testing.T) {
// tex is a bare-string test helper: the string acts as both the raw
// expression and its identifier.
type tex string

// Expression returns the raw expression string.
func (e tex) Expression() string {
	return string(e)
}

// Identifier returns a unique identifier for this evaluable; the test
// helper reuses the expression itself.
func (e tex) Identifier() string {
	return string(e)
}

func testBoolEvaluator(ctx context.Context, e Evaluable, input map[string]any) (bool, error) {
env, _ := cel.NewEnv(
Expand Down
15 changes: 14 additions & 1 deletion groupid.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,33 @@ package expr
import (
"crypto/rand"
"encoding/binary"
"encoding/hex"
)

// groupID represents a group ID. The first 2 bytes are a uint16 size of the
// expression group, representing the number of predicates within the
// expression. The last 6 bytes are a random ID for the predicate group.
//
// NOTE(review): the size is encoded with the machine's native byte order, so
// IDs are only meaningful in-process — confirm they are never persisted or
// sent across hosts.
type groupID [8]byte

// rander is the source of randomness for new group IDs; tests may substitute
// a deterministic RandomReader via newGroupIDWithReader.
var rander = rand.Read

// RandomReader fills p with random bytes, returning the number of bytes read
// and any error. It matches the signature of crypto/rand.Read.
type RandomReader func(p []byte) (n int, err error)

// String returns the group ID as a 16-character hex string.
func (g groupID) String() string {
	return hex.EncodeToString(g[:])
}

// Size returns the number of predicates in the expression group, decoded
// from the first two bytes of the ID.
func (g groupID) Size() uint16 {
	return binary.NativeEndian.Uint16(g[0:2])
}

// newGroupID returns a new random group ID for a group of size predicates.
func newGroupID(size uint16) groupID {
	return newGroupIDWithReader(size, rander)
}

// newGroupIDWithReader returns a new group ID for a group of size predicates,
// using the given reader as the source of the random suffix. This allows
// deterministic IDs in tests.
func newGroupIDWithReader(size uint16, rander RandomReader) groupID {
	// Write directly into the array: no intermediate slice allocation.
	var id groupID
	binary.NativeEndian.PutUint16(id[0:2], size)
	// A failed read only reduces the randomness of the suffix; the encoded
	// size stays intact, so the error is deliberately ignored.
	_, _ = rander(id[2:])
	return id
}
Loading

0 comments on commit fbd7922

Please sign in to comment.