From 6f549d6745fd59e4d7cffc2e6f1deeb125750707 Mon Sep 17 00:00:00 2001 From: Tony Holdstock-Brown Date: Tue, 21 May 2024 08:33:20 -0700 Subject: [PATCH] Add support for macros (#19) --- expr.go | 87 +++++++++++++++++++++++++--------------------------- expr_test.go | 62 +++++++++++++++++++++++++++++++++++++ parser.go | 37 ++++++++++++++-------- 3 files changed, 128 insertions(+), 58 deletions(-) diff --git a/expr.go b/expr.go index 99becd3..2540488 100644 --- a/expr.go +++ b/expr.go @@ -262,7 +262,7 @@ func (a *aggregator) Add(ctx context.Context, eval Evaluable) (bool, error) { return false, err } - if eval.GetExpression() == "" { + if eval.GetExpression() == "" || parsed.HasMacros { // This is an empty expression which always matches. a.lock.Lock() a.constants = append(a.constants, parsed.EvaluableID) @@ -270,28 +270,22 @@ func (a *aggregator) Add(ctx context.Context, eval Evaluable) (bool, error) { return false, nil } - aggregateable := true for _, g := range parsed.RootGroups() { ok, err := a.iterGroup(ctx, g, parsed, a.addNode) - if err != nil { - return false, err - } - if !ok && aggregateable { + if err != nil || !ok { // This is the first time we're seeing a non-aggregateable - // group, so add it to the constants list. + // group, so add it to the constants list and don't do anything else. a.lock.Lock() a.constants = append(a.constants, parsed.EvaluableID) a.lock.Unlock() - aggregateable = false + return false, err } } // Track the number of added expressions correctly. - if aggregateable { - atomic.AddInt32(&a.len, 1) - } - return aggregateable, nil + atomic.AddInt32(&a.len, 1) + return true, nil } func (a *aggregator) Remove(ctx context.Context, eval Evaluable) error { @@ -349,14 +343,14 @@ func (a *aggregator) removeConstantEvaluable(ctx context.Context, eval Evaluable } func (a *aggregator) iterGroup(ctx context.Context, node *Node, parsed *ParsedExpression, op nodeOp) (bool, error) { - // if len(node.Ors) > 0 { - // // If there are additional branches, don't bother to add this to the aggregate tree. - // // Mark this as a non-exhaustive addition and skip immediately. - // // - // // TODO: Allow ORs _only if_ the ORs are not nested, eg. the ORs are basic predicate - // // groups that themselves have no branches. - // return false, nil - // } + if len(node.Ors) > 0 { + // If there are additional branches, don't bother to add this to the aggregate tree. + // Mark this as a non-exhaustive addition and skip immediately. + // + // TODO: Allow ORs _only if_ the ORs are not nested, eg. the ORs are basic predicate + // groups that themselves have no branches. + return false, nil + } if len(node.Ands) > 0 { for _, n := range node.Ands { @@ -436,56 +430,54 @@ func (a *aggregator) addNode(ctx context.Context, n *Node, parsed *ParsedExpress if n.Predicate == nil { return nil } + e := a.engine(n) + if e == nil { + return errEngineUnimplemented + } // Don't allow anything to update in parallel. This ensures that Add() can be called // concurrently. a.lock.Lock() defer a.lock.Unlock() - - requiredEngine := engineType(*n.Predicate) - - if requiredEngine == EngineTypeNone { - return errEngineUnimplemented - } - - for _, engine := range a.engines { - if engine.Type() != requiredEngine { - continue - } - return engine.Add(ctx, ExpressionPart{ - GroupID: n.GroupID, - Predicate: n.Predicate, - Parsed: parsed, - }) - } - return errEngineUnimplemented + return e.Add(ctx, ExpressionPart{ + GroupID: n.GroupID, + Predicate: n.Predicate, + Parsed: parsed, + }) } func (a *aggregator) removeNode(ctx context.Context, n *Node, parsed *ParsedExpression) error { if n.Predicate == nil { return nil } + e := a.engine(n) + if e == nil { + return errEngineUnimplemented + } // Don't allow anything to update in parallel. This enrues that Add() can be called // concurrently. a.lock.Lock() defer a.lock.Unlock() + return e.Remove(ctx, ExpressionPart{ + GroupID: n.GroupID, + Predicate: n.Predicate, + Parsed: parsed, + }) +} +func (a *aggregator) engine(n *Node) MatchingEngine { requiredEngine := engineType(*n.Predicate) if requiredEngine == EngineTypeNone { - return errEngineUnimplemented + return nil } for _, engine := range a.engines { if engine.Type() != requiredEngine { continue } - return engine.Remove(ctx, ExpressionPart{ - GroupID: n.GroupID, - Predicate: n.Predicate, - Parsed: parsed, - }) + return engine } - return errEngineUnimplemented + return nil } func isAggregateable(n *Node) bool { @@ -499,6 +491,10 @@ func isAggregateable(n *Node) bool { return false } + if n.Predicate.Operator == "comprehension" { + return false + } + switch v := n.Predicate.Literal.(type) { case string: if len(v) == 0 { @@ -511,7 +507,6 @@ func isAggregateable(n *Node) bool { return false } // Right now, we only support equality checking. - // // TODO: Add GT(e)/LT(e) matching with tree iteration. return n.Predicate.Operator == operators.Equals case int, int64, float64: diff --git a/expr_test.go b/expr_test.go index 6ec4c1a..62f5c0e 100644 --- a/expr_test.go +++ b/expr_test.go @@ -358,6 +358,8 @@ func TestEvaluate_Compound(t *testing.T) { require.EqualValues(t, []Evaluable{expected}, evals) }) + // Note: we do not use group IDs for optimization right now. + // // t.Run("It skips if less than the group length is found", func(t *testing.T) { // evals, matched, err := e.Evaluate(ctx, map[string]any{ // "event": map[string]any{ @@ -455,6 +457,66 @@ func TestAggregateMatch(t *testing.T) { }) } +func TestMacros(t *testing.T) { + ctx := context.Background() + parser, err := newParser() + require.NoError(t, err) + + loader := newEvalLoader() + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load) + eval := tex(`event.data.ok == "true" || event.data.ids.exists(id, id == 'c')`) + loader.AddEval(eval) + ok, err := e.Add(ctx, eval) + require.NoError(t, err) + require.False(t, ok) + + t.Run("It doesn't evaluate macros", func(t *testing.T) { + + input := map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "ok": nil, + "ids": []string{"a", "b", "c"}, + }, + }, + } + evals, matched, err := e.Evaluate(ctx, input) + require.NoError(t, err) + require.EqualValues(t, 1, len(evals)) + require.EqualValues(t, 1, matched) + + t.Run("Failing match", func(t *testing.T) { + input = map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "ok": nil, + "ids": []string{"nope"}, + }, + }, + } + evals, matched, err = e.Evaluate(ctx, input) + require.NoError(t, err) + require.EqualValues(t, 0, len(evals)) + require.EqualValues(t, 1, matched) + }) + + t.Run("Partial macro", func(t *testing.T) { + input = map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "ok": "true", + "ids": []string{"nope"}, + }, + }, + } + evals, matched, err = e.Evaluate(ctx, input) + require.NoError(t, err) + require.EqualValues(t, 1, len(evals), evals) + require.EqualValues(t, 1, matched) + }) + }) +} + func TestAddRemove(t *testing.T) { ctx := context.Background() parser, err := newParser() diff --git a/parser.go b/parser.go index 733bb74..299fa81 100644 --- a/parser.go +++ b/parser.go @@ -101,7 +101,7 @@ func (p *parser) Parse(ctx context.Context, eval Evaluable) (*ParsedExpression, } node := newNode() - _, err := navigateAST( + _, hasMacros, err := navigateAST( expr{ ast: ast.NativeRep().Expr(), }, @@ -118,6 +118,7 @@ func (p *parser) Parse(ctx context.Context, eval Evaluable) (*ParsedExpression, Root: *node, Vars: vars, EvaluableID: eval.GetID(), + HasMacros: hasMacros, }, nil } @@ -139,6 +140,8 @@ type ParsedExpression struct { // Evaluable stores the original evaluable interface that was parsed. EvaluableID uuid.UUID + + HasMacros bool } // RootGroups returns the top-level matching groups within an expression. This is a small @@ -359,11 +362,13 @@ type expr struct { // It does this by iterating through the expression, amending the current `group` until // an or expression is found. When an or expression is found, we create another group which // is mutated by the iteration. -func navigateAST(nav expr, parent *Node, vars LiftedArgs, rand RandomReader) ([]*Node, error) { +func navigateAST(nav expr, parent *Node, vars LiftedArgs, rand RandomReader) ([]*Node, bool, error) { // on the very first call to navigateAST, ensure that we set the first node // inside the nodemap. result := []*Node{} + hasMacros := false + // Iterate through the stack, recursing down into each function call (eg. && branches). stack := []expr{nav} for len(stack) > 0 { @@ -371,16 +376,23 @@ func navigateAST(nav expr, parent *Node, vars LiftedArgs, rand RandomReader) ([] stack = stack[1:] switch item.ast.Kind() { + case celast.ComprehensionKind: + // These are not supported. A comprehension is eg. `.exists` and must + // awlays run naively right now. + c := item.ast.AsComprehension() + child := &Node{ + Predicate: &Predicate{ + Ident: c.IterVar(), + Operator: "comprehension", + }, + } + child.normalize() + result = append(result, child) + hasMacros = true case celast.LiteralKind: // This is a literal. Do nothing, as this is always true. case celast.IdentKind: - // This is a variable. DO nothing. - // predicate := Predicate{ - // Literal: true, - // Ident: item.AsIdent(), - // Operator: operators.Equals, - // } - // current.Predicates = append(current.Predicates, predicate) + // This is a variable. Do nothing. case celast.CallKind: // Call kinds are the actual comparator operators, eg. >=, or &&. These are specifically // what we're trying to parse, by taking the LHS and RHS of each opeartor then bringing @@ -403,14 +415,15 @@ func navigateAST(nav expr, parent *Node, vars LiftedArgs, rand RandomReader) ([] if fn == operators.LogicalOr { for _, or := range peek(item, operators.LogicalOr) { + var err error // Ors modify new nodes. Assign a new Node to each // Or entry. newParent := newNode() // For each item in the stack, recurse into that AST. - _, err := navigateAST(or, newParent, vars, rand) + _, hasMacros, err = navigateAST(or, newParent, vars, rand) if err != nil { - return nil, err + return nil, hasMacros, err } // Ensure that we remove any redundant parents generated. @@ -480,7 +493,7 @@ func navigateAST(nav expr, parent *Node, vars LiftedArgs, rand RandomReader) ([] } } - return result, nil + return result, hasMacros, nil } // peek recurses through nested operators (eg. a && b && c), grouping all operators