Skip to content

Commit

Permalink
Add support for macros (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyhb authored May 21, 2024
1 parent 91339d9 commit 6f549d6
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 58 deletions.
87 changes: 41 additions & 46 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,36 +262,30 @@ func (a *aggregator) Add(ctx context.Context, eval Evaluable) (bool, error) {
return false, err
}

if eval.GetExpression() == "" {
if eval.GetExpression() == "" || parsed.HasMacros {
// This is an empty expression which always matches.
a.lock.Lock()
a.constants = append(a.constants, parsed.EvaluableID)
a.lock.Unlock()
return false, nil
}

aggregateable := true
for _, g := range parsed.RootGroups() {
ok, err := a.iterGroup(ctx, g, parsed, a.addNode)
if err != nil {
return false, err
}

if !ok && aggregateable {
if err != nil || !ok {
// This is the first time we're seeing a non-aggregateable
// group, so add it to the constants list.
// group, so add it to the constants list and don't do anything else.
a.lock.Lock()
a.constants = append(a.constants, parsed.EvaluableID)
a.lock.Unlock()
aggregateable = false
return false, err
}
}

// Track the number of added expressions correctly.
if aggregateable {
atomic.AddInt32(&a.len, 1)
}
return aggregateable, nil
atomic.AddInt32(&a.len, 1)
return true, nil
}

func (a *aggregator) Remove(ctx context.Context, eval Evaluable) error {
Expand Down Expand Up @@ -349,14 +343,14 @@ func (a *aggregator) removeConstantEvaluable(ctx context.Context, eval Evaluable
}

func (a *aggregator) iterGroup(ctx context.Context, node *Node, parsed *ParsedExpression, op nodeOp) (bool, error) {
// if len(node.Ors) > 0 {
// // If there are additional branches, don't bother to add this to the aggregate tree.
// // Mark this as a non-exhaustive addition and skip immediately.
// //
// // TODO: Allow ORs _only if_ the ORs are not nested, eg. the ORs are basic predicate
// // groups that themselves have no branches.
// return false, nil
// }
if len(node.Ors) > 0 {
// If there are additional branches, don't bother to add this to the aggregate tree.
// Mark this as a non-exhaustive addition and skip immediately.
//
// TODO: Allow ORs _only if_ the ORs are not nested, eg. the ORs are basic predicate
// groups that themselves have no branches.
return false, nil
}

if len(node.Ands) > 0 {
for _, n := range node.Ands {
Expand Down Expand Up @@ -436,56 +430,54 @@ func (a *aggregator) addNode(ctx context.Context, n *Node, parsed *ParsedExpress
if n.Predicate == nil {
return nil
}
e := a.engine(n)
if e == nil {
return errEngineUnimplemented
}

// Don't allow anything to update in parallel. This ensures that Add() can be called
// concurrently.
a.lock.Lock()
defer a.lock.Unlock()

requiredEngine := engineType(*n.Predicate)

if requiredEngine == EngineTypeNone {
return errEngineUnimplemented
}

for _, engine := range a.engines {
if engine.Type() != requiredEngine {
continue
}
return engine.Add(ctx, ExpressionPart{
GroupID: n.GroupID,
Predicate: n.Predicate,
Parsed: parsed,
})
}
return errEngineUnimplemented
return e.Add(ctx, ExpressionPart{
GroupID: n.GroupID,
Predicate: n.Predicate,
Parsed: parsed,
})
}

func (a *aggregator) removeNode(ctx context.Context, n *Node, parsed *ParsedExpression) error {
if n.Predicate == nil {
return nil
}
e := a.engine(n)
if e == nil {
return errEngineUnimplemented
}

// Don't allow anything to update in parallel. This enrues that Add() can be called
// concurrently.
a.lock.Lock()
defer a.lock.Unlock()
return e.Remove(ctx, ExpressionPart{
GroupID: n.GroupID,
Predicate: n.Predicate,
Parsed: parsed,
})
}

func (a *aggregator) engine(n *Node) MatchingEngine {
requiredEngine := engineType(*n.Predicate)
if requiredEngine == EngineTypeNone {
return errEngineUnimplemented
return nil
}
for _, engine := range a.engines {
if engine.Type() != requiredEngine {
continue
}
return engine.Remove(ctx, ExpressionPart{
GroupID: n.GroupID,
Predicate: n.Predicate,
Parsed: parsed,
})
return engine
}
return errEngineUnimplemented
return nil
}

func isAggregateable(n *Node) bool {
Expand All @@ -499,6 +491,10 @@ func isAggregateable(n *Node) bool {
return false
}

if n.Predicate.Operator == "comprehension" {
return false
}

switch v := n.Predicate.Literal.(type) {
case string:
if len(v) == 0 {
Expand All @@ -511,7 +507,6 @@ func isAggregateable(n *Node) bool {
return false
}
// Right now, we only support equality checking.
//
// TODO: Add GT(e)/LT(e) matching with tree iteration.
return n.Predicate.Operator == operators.Equals
case int, int64, float64:
Expand Down
62 changes: 62 additions & 0 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ func TestEvaluate_Compound(t *testing.T) {
require.EqualValues(t, []Evaluable{expected}, evals)
})

// Note: we do not use group IDs for optimization right now.
//
// t.Run("It skips if less than the group length is found", func(t *testing.T) {
// evals, matched, err := e.Evaluate(ctx, map[string]any{
// "event": map[string]any{
Expand Down Expand Up @@ -455,6 +457,66 @@ func TestAggregateMatch(t *testing.T) {
})
}

func TestMacros(t *testing.T) {
ctx := context.Background()
parser, err := newParser()
require.NoError(t, err)

loader := newEvalLoader()
e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load)
eval := tex(`event.data.ok == "true" || event.data.ids.exists(id, id == 'c')`)
loader.AddEval(eval)
ok, err := e.Add(ctx, eval)
require.NoError(t, err)
require.False(t, ok)

t.Run("It doesn't evaluate macros", func(t *testing.T) {

input := map[string]any{
"event": map[string]any{
"data": map[string]any{
"ok": nil,
"ids": []string{"a", "b", "c"},
},
},
}
evals, matched, err := e.Evaluate(ctx, input)
require.NoError(t, err)
require.EqualValues(t, 1, len(evals))
require.EqualValues(t, 1, matched)

t.Run("Failing match", func(t *testing.T) {
input = map[string]any{
"event": map[string]any{
"data": map[string]any{
"ok": nil,
"ids": []string{"nope"},
},
},
}
evals, matched, err = e.Evaluate(ctx, input)
require.NoError(t, err)
require.EqualValues(t, 0, len(evals))
require.EqualValues(t, 1, matched)
})

t.Run("Partial macro", func(t *testing.T) {
input = map[string]any{
"event": map[string]any{
"data": map[string]any{
"ok": "true",
"ids": []string{"nope"},
},
},
}
evals, matched, err = e.Evaluate(ctx, input)
require.NoError(t, err)
require.EqualValues(t, 1, len(evals), evals)
require.EqualValues(t, 1, matched)
})
})
}

func TestAddRemove(t *testing.T) {
ctx := context.Background()
parser, err := newParser()
Expand Down
37 changes: 25 additions & 12 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (p *parser) Parse(ctx context.Context, eval Evaluable) (*ParsedExpression,
}

node := newNode()
_, err := navigateAST(
_, hasMacros, err := navigateAST(
expr{
ast: ast.NativeRep().Expr(),
},
Expand All @@ -118,6 +118,7 @@ func (p *parser) Parse(ctx context.Context, eval Evaluable) (*ParsedExpression,
Root: *node,
Vars: vars,
EvaluableID: eval.GetID(),
HasMacros: hasMacros,
}, nil
}

Expand All @@ -139,6 +140,8 @@ type ParsedExpression struct {

// Evaluable stores the original evaluable interface that was parsed.
EvaluableID uuid.UUID

HasMacros bool
}

// RootGroups returns the top-level matching groups within an expression. This is a small
Expand Down Expand Up @@ -359,28 +362,37 @@ type expr struct {
// It does this by iterating through the expression, amending the current `group` until
// an or expression is found. When an or expression is found, we create another group which
// is mutated by the iteration.
func navigateAST(nav expr, parent *Node, vars LiftedArgs, rand RandomReader) ([]*Node, error) {
func navigateAST(nav expr, parent *Node, vars LiftedArgs, rand RandomReader) ([]*Node, bool, error) {
// on the very first call to navigateAST, ensure that we set the first node
// inside the nodemap.
result := []*Node{}

hasMacros := false

// Iterate through the stack, recursing down into each function call (eg. && branches).
stack := []expr{nav}
for len(stack) > 0 {
item := stack[0]
stack = stack[1:]

switch item.ast.Kind() {
case celast.ComprehensionKind:
// These are not supported. A comprehension is eg. `.exists` and must
// awlays run naively right now.
c := item.ast.AsComprehension()
child := &Node{
Predicate: &Predicate{
Ident: c.IterVar(),
Operator: "comprehension",
},
}
child.normalize()
result = append(result, child)
hasMacros = true
case celast.LiteralKind:
// This is a literal. Do nothing, as this is always true.
case celast.IdentKind:
// This is a variable. DO nothing.
// predicate := Predicate{
// Literal: true,
// Ident: item.AsIdent(),
// Operator: operators.Equals,
// }
// current.Predicates = append(current.Predicates, predicate)
// This is a variable. Do nothing.
case celast.CallKind:
// Call kinds are the actual comparator operators, eg. >=, or &&. These are specifically
// what we're trying to parse, by taking the LHS and RHS of each opeartor then bringing
Expand All @@ -403,14 +415,15 @@ func navigateAST(nav expr, parent *Node, vars LiftedArgs, rand RandomReader) ([]

if fn == operators.LogicalOr {
for _, or := range peek(item, operators.LogicalOr) {
var err error
// Ors modify new nodes. Assign a new Node to each
// Or entry.
newParent := newNode()

// For each item in the stack, recurse into that AST.
_, err := navigateAST(or, newParent, vars, rand)
_, hasMacros, err = navigateAST(or, newParent, vars, rand)
if err != nil {
return nil, err
return nil, hasMacros, err
}

// Ensure that we remove any redundant parents generated.
Expand Down Expand Up @@ -480,7 +493,7 @@ func navigateAST(nav expr, parent *Node, vars LiftedArgs, rand RandomReader) ([]
}
}

return result, nil
return result, hasMacros, nil
}

// peek recurses through nested operators (eg. a && b && c), grouping all operators
Expand Down

0 comments on commit 6f549d6

Please sign in to comment.