From db406123c6bc27190e94fa471b3f8d3b261815ed Mon Sep 17 00:00:00 2001 From: Tony Holdstock-Brown Date: Thu, 4 Jan 2024 21:22:03 -0800 Subject: [PATCH 1/2] Refactor group IDs to set at the parse level This ensures group IDs are static for each parsed node. We also ensure that we check for at least GroupID.Size() matching items from trees when matching on incoming events. --- expr.go | 71 +++++++++++++++++++++++-------- expr_test.go | 69 ++++++++++++++++++++---------- groupid.go | 9 +++- parser.go | 27 ++++++++++++ parser_test.go | 112 ++++++++++++++++++++++++++++++++++++++++++++++++- tree.go | 11 +++++ tree_art.go | 33 ++++++++++++++- 7 files changed, 290 insertions(+), 42 deletions(-) diff --git a/expr.go b/expr.go index 7d17f61..18d7150 100644 --- a/expr.go +++ b/expr.go @@ -137,12 +137,21 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu // are added (eg. >= operators on strings), ensure that we find the correct number of matches // for each group ID and then skip evaluating expressions if the number of matches is <= the group // ID's length. + seen := map[groupID]struct{}{} + for _, match := range matches { + if _, ok := seen[match.GroupID]; ok { + continue + } + atomic.AddInt32(&matched, 1) // NOTE: We don't need to add lifted expression variables, // because match.Parsed.Evaluable() returns the original expression // string. ok, evalerr := a.eval(ctx, match.Parsed.Evaluable, data) + + seen[match.GroupID] = struct{}{} + if evalerr != nil { err = errors.Join(err, evalerr) continue @@ -161,6 +170,12 @@ func (a *aggregator) AggregateMatch(ctx context.Context, data map[string]any) ([ a.lock.RLock() defer a.lock.RUnlock() + // Store the number of times each GroupID has found a match. We need at least + // as many matches as stored in the group ID to consider the match. + counts := map[groupID]int{} + // Store all expression parts per group ID for returning. + found := map[groupID][]ExpressionPart{} + // Iterate through all known variables/idents in the aggregate tree to see if // the data has those keys set. If so, we can immediately evaluate the data with // the tree. @@ -179,16 +194,32 @@ func (a *aggregator) AggregateMatch(ctx context.Context, data map[string]any) ([ switch cast := res[0].(type) { case string: - found, ok := tree.Search(ctx, cast) + all, ok := tree.Search(ctx, cast) if !ok { continue } - result = append(result, found.Evals...) + + for _, eval := range all.Evals { + counts[eval.GroupID] += 1 + if _, ok := found[eval.GroupID]; !ok { + found[eval.GroupID] = []ExpressionPart{} + } + found[eval.GroupID] = append(found[eval.GroupID], eval) + } default: continue } } + for k, count := range counts { + if int(k.Size()) > count { + // The GroupID required more comparisons to equate to true than + // we had, so this could never evaluate to true. Skip this. + continue + } + result = append(result, found[k]...) + } + return result, nil } @@ -238,16 +269,26 @@ func (a *aggregator) addGroup(ctx context.Context, node *Node, parsed *ParsedExp return false, nil } - // Merge all of the nodes together and check whether each node is aggregateable. - all := append(node.Ands, node) - for _, n := range all { - if !n.HasPredicate() || len(n.Ors) > 0 { - // Don't handle sub-branching for now. - return false, nil + if len(node.Ands) > 0 { + for _, n := range node.Ands { + if !n.HasPredicate() || len(n.Ors) > 0 { + // Don't handle sub-branching for now. + return false, nil + } + if !isAggregateable(n) { + return false, nil + } } - if !isAggregateable(n) { + } + + all := node.Ands + + if node.Predicate != nil { + if !isAggregateable(node) { return false, nil } + // Merge all of the nodes together and check whether each node is aggregateable. + all = append(node.Ands, node) } // Create a new group ID which tracks the number of expressions that must match @@ -258,9 +299,8 @@ func (a *aggregator) addGroup(ctx context.Context, node *Node, parsed *ParsedExp // When checking an incoming event, we match the event against each node's // ident/variable. Using the group ID, we can see if we've matched N necessary // items from the same identifier. If so, the evaluation is true. - groupID := newGroupID(uint16(len(all))) for _, n := range all { - err := a.addNode(ctx, n, groupID, parsed) + err := a.addNode(ctx, n, parsed) if err == errTreeUnimplemented { return false, nil } @@ -272,7 +312,7 @@ func (a *aggregator) addGroup(ctx context.Context, node *Node, parsed *ParsedExp return true, nil } -func (a *aggregator) addNode(ctx context.Context, n *Node, gid groupID, parsed *ParsedExpression) error { +func (a *aggregator) addNode(ctx context.Context, n *Node, parsed *ParsedExpression) error { // Don't allow anything to update in parallel. This enrues that Add() can be called // concurrently. a.lock.Lock() @@ -286,7 +326,7 @@ func (a *aggregator) addNode(ctx context.Context, n *Node, gid groupID, parsed * tree = newArtTree() } err := tree.Add(ctx, ExpressionPart{ - GroupID: gid, + GroupID: n.GroupID, Predicate: *n.Predicate, Parsed: parsed, }) @@ -302,14 +342,11 @@ func (a *aggregator) addNode(ctx context.Context, n *Node, gid groupID, parsed * func (a *aggregator) Remove(ctx context.Context, eval Evaluable) error { // parse the expression using our tree parser. parsed, err := a.parser.Parse(ctx, eval) + _ = parsed if err != nil { return err } - for _, g := range parsed.RootGroups() { - _ = g - } - return fmt.Errorf("not implemented") } diff --git a/expr_test.go b/expr_test.go index 3bea931..67fa282 100644 --- a/expr_test.go +++ b/expr_test.go @@ -137,25 +137,7 @@ func TestEvaluate(t *testing.T) { require.NoError(t, err) require.EqualValues(t, 0, len(evals)) - require.EqualValues(t, 1, matched) // We still ran one expression - }) - - t.Run("It handles matching on arrays of data", func(t *testing.T) { - pre := time.Now() - evals, matched, err := e.Evaluate(ctx, map[string]any{ - "event": map[string]any{ - "data": map[string]any{ - "ids": []string{"a", "b", "c"}, - }, - }, - }) - total := time.Since(pre) - fmt.Printf("Matched in %v ns\n", total.Nanoseconds()) - fmt.Printf("Matched in %v ms\n", total.Milliseconds()) - - require.NoError(t, err) - require.EqualValues(t, 0, len(evals)) - require.EqualValues(t, 1, matched) // We still ran one expression + require.EqualValues(t, 0, matched) // We still ran one expression }) } @@ -170,7 +152,7 @@ func TestEvaluate_Concurrently(t *testing.T) { require.NoError(t, err) go func() { - for i := 0; i < 1_000; i++ { + for i := 0; i < 100_000; i++ { //nolint:all go func() { byt := make([]byte, 8) @@ -213,7 +195,7 @@ func TestEvaluate_ArrayIndexes(t *testing.T) { require.NoError(t, err) e := NewAggregateEvaluator(parser, testBoolEvaluator) - expected := tex(`event.data.ids[2] == "id-b"`) + expected := tex(`event.data.ids[1] == "id-b" && event.data.ids[2] == "id-c"`) _, err = e.Add(ctx, expected) require.NoError(t, err) @@ -258,7 +240,7 @@ func TestEvaluate_ArrayIndexes(t *testing.T) { evals, matched, err := e.Evaluate(ctx, map[string]any{ "event": map[string]any{ "data": map[string]any{ - "ids": []string{"a", "yes", "id-b"}, + "ids": []string{"id-a", "id-b", "id-c"}, }, }, }) @@ -272,6 +254,49 @@ func TestEvaluate_ArrayIndexes(t *testing.T) { }) } +func TestEvaluate_Compound(t *testing.T) { + ctx := context.Background() + parser, err := NewTreeParser(NewCachingParser(newEnv(), nil)) + require.NoError(t, err) + e := NewAggregateEvaluator(parser, testBoolEvaluator) + + expected := tex(`event.data.a == "ok" && event.data.b == "yes" && event.data.c == "please"`) + ok, err := e.Add(ctx, expected) + require.True(t, ok) + require.NoError(t, err) + + t.Run("It matches items", func(t *testing.T) { + evals, matched, err := e.Evaluate(ctx, map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "a": "ok", + "b": "yes", + "c": "please", + }, + }, + }) + require.NoError(t, err) + require.EqualValues(t, 1, matched) // We only perform one eval + require.EqualValues(t, []Evaluable{expected}, evals) + }) + + t.Run("It skips if less than the group length is found", func(t *testing.T) { + evals, matched, err := e.Evaluate(ctx, map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "a": "ok", + "b": "yes", + "c": "no - no match", + }, + }, + }) + require.NoError(t, err) + require.EqualValues(t, 0, matched) + require.EqualValues(t, []Evaluable{}, evals) + }) + +} + func TestAggregateMatch(t *testing.T) { ctx := context.Background() parser, err := newParser() diff --git a/groupid.go b/groupid.go index ea33e26..b4591bd 100644 --- a/groupid.go +++ b/groupid.go @@ -3,6 +3,7 @@ package expr import ( "crypto/rand" "encoding/binary" + "encoding/hex" ) // groupID represents a group ID. The first 2 byets are an int16 size of the expression group, @@ -10,6 +11,12 @@ import ( // ID for the predicate group. type groupID [8]byte +var rander = rand.Read + +func (g groupID) String() string { + return hex.EncodeToString(g[:]) +} + func (g groupID) Size() uint16 { return binary.NativeEndian.Uint16(g[0:2]) } @@ -17,6 +24,6 @@ func (g groupID) Size() uint16 { func newGroupID(size uint16) groupID { id := make([]byte, 8) binary.NativeEndian.PutUint16(id, size) - _, _ = rand.Read(id[2:]) + _, _ = rander(id[2:]) return [8]byte(id[0:8]) } diff --git a/parser.go b/parser.go index 1220601..286b308 100644 --- a/parser.go +++ b/parser.go @@ -136,6 +136,8 @@ func (p ParsedExpression) RootGroups() []*Node { // This requres A *and* either B or C, and so we require all ANDs plus at least one node // from OR to evaluate to true type Node struct { + GroupID groupID + // Ands contains predicates at this level of the expression that are joined together // with an && operator. All nodes in this set must evaluate to true in order for this // node in the expression to be truthy. @@ -416,6 +418,31 @@ func navigateAST(nav expr, parent *Node, vars LiftedArgs) ([]*Node, error) { } parent.Ands = result + + // Add a group ID to the parent. + total := len(parent.Ands) + if parent.Predicate != nil { + total += 1 + } + if len(parent.Ors) >= 1 { + total += 1 + } + + parent.GroupID = newGroupID(uint16(total)) + // For each sub-group, add the same group IDs to children if there's no nesting. + for n, item := range parent.Ands { + if len(item.Ands) == 0 && len(item.Ors) == 0 && item.Predicate != nil { + item.GroupID = parent.GroupID + parent.Ands[n] = item + } + } + for n, item := range parent.Ors { + if len(item.Ands) == 0 && len(item.Ors) == 0 && item.Predicate != nil { + item.GroupID = parent.GroupID + parent.Ors[n] = item + } + } + return result, nil } diff --git a/parser_test.go b/parser_test.go index cf9fd76..c320044 100644 --- a/parser_test.go +++ b/parser_test.go @@ -32,6 +32,14 @@ type parseTestInput struct { func TestParse(t *testing.T) { ctx := context.Background() + origRander := rander + rander = func(b []byte) (n int, err error) { + return 0, nil + } + t.Cleanup(func() { + rander = origRander + }) + // helper function to assert each case. assert := func(t *testing.T, tests []parseTestInput) { t.Helper() @@ -43,6 +51,8 @@ func TestParse(t *testing.T) { eval := tex(test.input) actual, err := parser.Parse(ctx, eval) + require.NotNil(t, actual.Root.GroupID) + // Shortcut to ensure the evaluable instance matches if test.expected.Evaluable == nil { test.expected.Evaluable = eval @@ -76,6 +86,7 @@ func TestParse(t *testing.T) { output: `event.data.ids[2] == "a"`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Ident: "event.data.ids[2]", Literal: "a", @@ -89,6 +100,7 @@ func TestParse(t *testing.T) { output: `event.data.ids[2].id == "a"`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Ident: "event.data.ids[2].id", Literal: "a", @@ -111,6 +123,7 @@ func TestParse(t *testing.T) { output: `event == vars.a`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Ident: "event", LiteralIdent: &ident, @@ -131,6 +144,7 @@ func TestParse(t *testing.T) { output: `event == "foo"`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "foo", Ident: "event", @@ -144,6 +158,7 @@ func TestParse(t *testing.T) { output: `event.data.run_id == "xyz"`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "xyz", Ident: "event.data.run_id", @@ -158,8 +173,10 @@ func TestParse(t *testing.T) { output: `event.data.id == "foo" && event.data.value > 100`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(2), Ands: []*Node{ { + GroupID: newGroupID(2), Predicate: &Predicate{ Literal: "foo", Ident: "event.data.id", @@ -167,6 +184,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(2), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -182,8 +200,10 @@ func TestParse(t *testing.T) { output: `event.data.float <= 3.141 && event.data.id == "foo" && event.data.value > 100`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(3), Ands: []*Node{ { + GroupID: newGroupID(3), Predicate: &Predicate{ Literal: 3.141, Ident: "event.data.float", @@ -191,6 +211,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(3), Predicate: &Predicate{ Literal: "foo", Ident: "event.data.id", @@ -198,6 +219,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(3), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -220,6 +242,7 @@ func TestParse(t *testing.T) { output: `event.data.a != "a"`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "a", Ident: "event.data.a", @@ -233,6 +256,7 @@ func TestParse(t *testing.T) { output: `event.data.a == "a"`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "a", Ident: "event.data.a", @@ -253,6 +277,7 @@ func TestParse(t *testing.T) { output: `event.data.id >= "ulid"`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "ulid", Ident: "event.data.id", @@ -266,6 +291,7 @@ func TestParse(t *testing.T) { output: `event.data.id < "ulid"`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "ulid", Ident: "event.data.id", @@ -279,6 +305,7 @@ func TestParse(t *testing.T) { output: `event.data.a != "a"`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "a", Ident: "event.data.a", @@ -299,8 +326,10 @@ func TestParse(t *testing.T) { output: `event == "foo" || event == "bar"`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Ors: []*Node{ { + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "foo", Ident: "event", @@ -308,6 +337,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "bar", Ident: "event", @@ -323,8 +353,10 @@ func TestParse(t *testing.T) { output: `event == "foo" || event == "bar"`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Ors: []*Node{ { + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "foo", Ident: "event", @@ -332,6 +364,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "bar", Ident: "event", @@ -347,9 +380,11 @@ func TestParse(t *testing.T) { output: `a == 1 || (b == 2 && b != 3)`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Ors: []*Node{ // Either { + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(1), Ident: "a", @@ -357,8 +392,10 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(2), Ands: []*Node{ { + GroupID: newGroupID(2), Predicate: &Predicate{ Literal: int64(2), Ident: "b", @@ -366,6 +403,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(2), Predicate: &Predicate{ Literal: int64(3), Ident: "b", @@ -383,8 +421,10 @@ func TestParse(t *testing.T) { output: `event == "baz" || event == "foo" || event == "bar"`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Ors: []*Node{ { + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "baz", Ident: "event", @@ -392,6 +432,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "foo", Ident: "event", @@ -399,6 +440,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "bar", Ident: "event", @@ -415,10 +457,13 @@ func TestParse(t *testing.T) { expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Ors: []*Node{ { + GroupID: newGroupID(2), Ands: []*Node{ { + GroupID: newGroupID(2), Predicate: &Predicate{ Literal: "order", Ident: "event.data.type", @@ -426,6 +471,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(2), Predicate: &Predicate{ Literal: int64(500), Ident: "event.data.value", @@ -435,6 +481,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "preorder", Ident: "event.data.type", @@ -458,6 +505,7 @@ func TestParse(t *testing.T) { output: "event.data.value > 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -471,6 +519,7 @@ func TestParse(t *testing.T) { output: "event.data.value >= 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -484,6 +533,7 @@ func TestParse(t *testing.T) { output: "event.data.value < 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -497,6 +547,7 @@ func TestParse(t *testing.T) { output: "event.data.value <= 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -511,6 +562,7 @@ func TestParse(t *testing.T) { output: "event.data.value < 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -524,6 +576,7 @@ func TestParse(t *testing.T) { output: "event.data.value <= 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -537,6 +590,7 @@ func TestParse(t *testing.T) { output: "event.data.value > 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -550,6 +604,7 @@ func TestParse(t *testing.T) { output: "event.data.value >= 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -571,6 +626,7 @@ func TestParse(t *testing.T) { output: "event.data.value <= 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -584,6 +640,7 @@ func TestParse(t *testing.T) { output: "event.data.value < 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -597,6 +654,7 @@ func TestParse(t *testing.T) { output: "event.data.value >= 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -610,6 +668,7 @@ func TestParse(t *testing.T) { output: "event.data.value > 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -624,6 +683,7 @@ func TestParse(t *testing.T) { output: "event.data.value > 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -637,6 +697,7 @@ func TestParse(t *testing.T) { output: "event.data.value <= 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -650,6 +711,7 @@ func TestParse(t *testing.T) { output: "event.data.value < 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -664,6 +726,7 @@ func TestParse(t *testing.T) { output: "event.data.value < 100", expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(100), Ident: "event.data.value", @@ -684,8 +747,10 @@ func TestParse(t *testing.T) { output: `c == 3 || a == 1 || b == 2`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Ors: []*Node{ { + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(3), Ident: "c", @@ -693,6 +758,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(1), Ident: "a", @@ -700,6 +766,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(2), Ident: "b", @@ -716,10 +783,13 @@ func TestParse(t *testing.T) { output: `(a == 1 && b == 2) || c == 3`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Ors: []*Node{ { + GroupID: newGroupID(2), Ands: []*Node{ { + GroupID: newGroupID(2), Predicate: &Predicate{ Literal: int64(1), Ident: "a", @@ -727,6 +797,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(2), Predicate: &Predicate{ Literal: int64(2), Ident: "b", @@ -736,6 +807,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(3), Ident: "c", @@ -752,8 +824,10 @@ func TestParse(t *testing.T) { output: `a == 1 || (b == 2 && c == 3)`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Ors: []*Node{ { + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(1), Ident: "a", @@ -761,8 +835,10 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(2), Ands: []*Node{ { + GroupID: newGroupID(2), Predicate: &Predicate{ Literal: int64(2), Ident: "b", @@ -770,6 +846,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(2), Predicate: &Predicate{ Literal: int64(3), Ident: "c", @@ -788,8 +865,10 @@ func TestParse(t *testing.T) { output: `c == 3 && (a == 1 || b == 2)`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(2), Ands: []*Node{ { + GroupID: newGroupID(2), Predicate: &Predicate{ Literal: int64(3), Ident: "c", @@ -799,6 +878,7 @@ func TestParse(t *testing.T) { }, Ors: []*Node{ { + GroupID: newGroupID(2), Predicate: &Predicate{ Literal: int64(1), Ident: "a", @@ -806,6 +886,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(2), Predicate: &Predicate{ Literal: int64(2), Ident: "b", @@ -822,8 +903,10 @@ func TestParse(t *testing.T) { output: `a == 1 && b == 2 && (c == 3 || d == 4)`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(3), Ands: []*Node{ { + GroupID: newGroupID(3), Predicate: &Predicate{ Literal: int64(1), Ident: "a", @@ -831,6 +914,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(3), Predicate: &Predicate{ Literal: int64(2), Ident: "b", @@ -840,6 +924,7 @@ func TestParse(t *testing.T) { }, Ors: []*Node{ { + GroupID: newGroupID(3), Predicate: &Predicate{ Literal: int64(3), Ident: "c", @@ -847,6 +932,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(3), Predicate: &Predicate{ Literal: int64(4), Ident: "d", @@ -868,8 +954,10 @@ func TestParse(t *testing.T) { output: `zz == 4 || (a == 1 && b == 2 && (c == 3 || d == 4)) || (z == 3 && e == 5 && (f == 6 || g == 7))`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Ors: []*Node{ { + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: int64(4), Ident: "zz", @@ -877,8 +965,10 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(3), Ands: []*Node{ { + GroupID: newGroupID(3), Predicate: &Predicate{ Literal: int64(1), Ident: "a", @@ -886,6 +976,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(3), Predicate: &Predicate{ Literal: int64(2), Ident: "b", @@ -895,6 +986,7 @@ func TestParse(t *testing.T) { }, Ors: []*Node{ { + GroupID: newGroupID(3), Predicate: &Predicate{ Literal: int64(3), Ident: "c", @@ -902,6 +994,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(3), Predicate: &Predicate{ Literal: int64(4), Ident: "d", @@ -911,8 +1004,10 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(3), Ands: []*Node{ { + GroupID: newGroupID(3), Predicate: &Predicate{ Literal: int64(3), Ident: "z", @@ -920,6 +1015,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(3), Predicate: &Predicate{ Literal: int64(5), Ident: "e", @@ -929,6 +1025,7 @@ func TestParse(t *testing.T) { }, Ors: []*Node{ { + GroupID: newGroupID(3), Predicate: &Predicate{ Literal: int64(6), Ident: "f", @@ -936,6 +1033,7 @@ func TestParse(t *testing.T) { }, }, { + GroupID: newGroupID(3), Predicate: &Predicate{ Literal: int64(7), Ident: "g", @@ -951,6 +1049,7 @@ func TestParse(t *testing.T) { } assert(t, tests) + }) // TODO @@ -973,12 +1072,20 @@ func TestParse(t *testing.T) { assert(t, tests) }) */ - } func TestParse_LiftedVars(t *testing.T) { ctx := context.Background() + origRander := rander + // In tests, don't add any random data to group IDs. + rander = func(b []byte) (n int, err error) { + return 0, nil + } + t.Cleanup(func() { + rander = origRander + }) + cachingCelParser := NewCachingParser(newEnv(), nil) assert := func(t *testing.T, tests []parseTestInput) { @@ -1026,6 +1133,7 @@ func TestParse_LiftedVars(t *testing.T) { output: `event == "foo"`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "foo", Ident: "event", @@ -1042,6 +1150,7 @@ func TestParse_LiftedVars(t *testing.T) { output: `event == "bar"`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "bar", Ident: "event", @@ -1058,6 +1167,7 @@ func TestParse_LiftedVars(t *testing.T) { output: `event == "bar"`, expected: ParsedExpression{ Root: Node{ + GroupID: newGroupID(1), Predicate: &Predicate{ Literal: "bar", Ident: "event", diff --git a/tree.go b/tree.go index 2858eb1..f6f99e1 100644 --- a/tree.go +++ b/tree.go @@ -20,6 +20,7 @@ const ( // ART tree, while LTE operations may check against a b+-tree. type PredicateTree interface { Add(ctx context.Context, p ExpressionPart) error + Remove(ctx context.Context, p ExpressionPart) error Search(ctx context.Context, input any) (*Leaf, bool) } @@ -51,3 +52,13 @@ type ExpressionPart struct { Predicate Predicate Parsed *ParsedExpression } + +func (p ExpressionPart) Equals(n ExpressionPart) bool { + if p.GroupID != n.GroupID { + return false + } + if p.Predicate.String() != n.Predicate.String() { + return false + } + return p.Parsed.Evaluable.Expression() == n.Parsed.Evaluable.Expression() +} diff --git a/tree_art.go b/tree_art.go index ccaf752..9a2b31e 100644 --- a/tree_art.go +++ b/tree_art.go @@ -10,7 +10,8 @@ import ( ) var ( - ErrInvalidType = fmt.Errorf("invalid type for tree") + ErrInvalidType = fmt.Errorf("invalid type for tree") + ErrExpressionPartNotFound = fmt.Errorf("expression part not found") ) func newArtTree() PredicateTree { @@ -48,6 +49,36 @@ func (a *artTree) Search(ctx context.Context, input any) (*Leaf, bool) { return val.(*Leaf), true } +func (a *artTree) Remove(ctx context.Context, p ExpressionPart) error { + str, ok := p.Predicate.Literal.(string) + if !ok { + return ErrInvalidType + } + + key := artKeyFromString(str) + + // Don't allow multiple gorutines to modify the tree simultaneously. + a.lock.Lock() + defer a.lock.Unlock() + + val, ok := a.Tree.Search(key) + if !ok { + return ErrExpressionPartNotFound + } + + next := val.(*Leaf) + // Remove the expression part from the leaf. + for n, eval := range next.Evals { + if p.Equals(eval) { + next.Evals = append(next.Evals[:n], next.Evals[n+1:]...) + a.Insert(key, next) + return nil + } + } + + return ErrExpressionPartNotFound +} + func (a *artTree) Add(ctx context.Context, p ExpressionPart) error { str, ok := p.Predicate.Literal.(string) if !ok { From 3014841f23e4743f1bcfb66141fdfb715bc47e59 Mon Sep 17 00:00:00 2001 From: Tony Holdstock-Brown Date: Thu, 4 Jan 2024 21:43:27 -0800 Subject: [PATCH 2/2] Make GroupIDs deterministic based off of the Evaluable Identifier --- expr.go | 5 +++++ expr_test.go | 1 + groupid.go | 6 ++++++ parser.go | 23 ++++++++++++++++++++--- parser_test.go | 11 +++++++---- 5 files changed, 39 insertions(+), 7 deletions(-) diff --git a/expr.go b/expr.go index 18d7150..ded3f27 100644 --- a/expr.go +++ b/expr.go @@ -70,6 +70,11 @@ func NewAggregateEvaluator( } type Evaluable interface { + // Identifier returns a unique identifier for the evaluable item. If there are + // two instances of the same expression, the identifier should return a unique + // string for each instance of the expression (eg. for two pauses). + Identifier() string + // Expression returns an expression as a raw string. Expression() string } diff --git a/expr_test.go b/expr_test.go index 67fa282..a62a6f5 100644 --- a/expr_test.go +++ b/expr_test.go @@ -427,6 +427,7 @@ func TestAdd(t *testing.T) { type tex string func (e tex) Expression() string { return string(e) } +func (e tex) Identifier() string { return string(e) } func testBoolEvaluator(ctx context.Context, e Evaluable, input map[string]any) (bool, error) { env, _ := cel.NewEnv( diff --git a/groupid.go b/groupid.go index b4591bd..86a586b 100644 --- a/groupid.go +++ b/groupid.go @@ -13,6 +13,8 @@ type groupID [8]byte var rander = rand.Read +type RandomReader func(p []byte) (n int, err error) + func (g groupID) String() string { return hex.EncodeToString(g[:]) } @@ -22,6 +24,10 @@ func (g groupID) Size() uint16 { } func newGroupID(size uint16) groupID { + return newGroupIDWithReader(size, rander) +} + +func newGroupIDWithReader(size uint16, rander RandomReader) groupID { id := make([]byte, 8) binary.NativeEndian.PutUint16(id, size) _, _ = rander(id[2:]) diff --git a/parser.go b/parser.go index 286b308..e5f0ff6 100644 --- a/parser.go +++ b/parser.go @@ -2,7 +2,10 @@ package expr import ( "context" + "crypto/sha256" + "encoding/binary" "fmt" + "math/rand" "strconv" "strings" @@ -52,6 +55,8 @@ func NewTreeParser(ep CELParser) (TreeParser, error) { type parser struct { ep CELParser + + rander RandomReader } func (p *parser) Parse(ctx context.Context, eval Evaluable) (*ParsedExpression, error) { @@ -60,6 +65,17 @@ func (p *parser) Parse(ctx context.Context, eval Evaluable) (*ParsedExpression, return nil, issues.Err() } + if p.rander == nil { + // Create a new deterministic random reader based off of the evaluable's identifier. + // This means that every time we parse an expression with the given identifier, the + // group IDs will be deterministic as the randomness is sourced from the ID. + // + // We only overwrite this if rander is not nil so that we can inject rander during tests. + digest := sha256.Sum256([]byte(eval.Identifier())) + seed := int64(binary.NativeEndian.Uint64(digest[:8])) + p.rander = rand.New(rand.NewSource(seed)).Read + } + node := newNode() _, err := navigateAST( expr{ @@ -67,6 +83,7 @@ func (p *parser) Parse(ctx context.Context, eval Evaluable) (*ParsedExpression, }, node, vars, + p.rander, ) if err != nil { return nil, err @@ -327,7 +344,7 @@ type expr struct { // It does this by iterating through the expression, amending the current `group` until // an or expression is found. When an or expression is found, we create another group which // is mutated by the iteration. -func navigateAST(nav expr, parent *Node, vars LiftedArgs) ([]*Node, error) { +func navigateAST(nav expr, parent *Node, vars LiftedArgs, rand RandomReader) ([]*Node, error) { // on the very first call to navigateAST, ensure that we set the first node // inside the nodemap. result := []*Node{} @@ -376,7 +393,7 @@ func navigateAST(nav expr, parent *Node, vars LiftedArgs) ([]*Node, error) { newParent := newNode() // For each item in the stack, recurse into that AST. - _, err := navigateAST(or, newParent, vars) + _, err := navigateAST(or, newParent, vars, rand) if err != nil { return nil, err } @@ -428,7 +445,7 @@ func navigateAST(nav expr, parent *Node, vars LiftedArgs) ([]*Node, error) { total += 1 } - parent.GroupID = newGroupID(uint16(total)) + parent.GroupID = newGroupIDWithReader(uint16(total), rand) // For each sub-group, add the same group IDs to children if there's no nesting. for n, item := range parent.Ands { if len(item.Ands) == 0 && len(item.Ors) == 0 && item.Predicate != nil { diff --git a/parser_test.go b/parser_test.go index c320044..9a7b66a 100644 --- a/parser_test.go +++ b/parser_test.go @@ -45,11 +45,12 @@ func TestParse(t *testing.T) { t.Helper() for _, test := range tests { - parser, err := newParser() + p, err := newParser() + p.(*parser).rander = rander require.NoError(t, err) eval := tex(test.input) - actual, err := parser.Parse(ctx, eval) + actual, err := p.Parse(ctx, eval) require.NotNil(t, actual.Root.GroupID) @@ -1092,10 +1093,12 @@ func TestParse_LiftedVars(t *testing.T) { t.Helper() for _, test := range tests { - parser, err := NewTreeParser(cachingCelParser) + p, err := NewTreeParser(cachingCelParser) + // overwrite rander so that the parser uses the same nil bytes + p.(*parser).rander = rander require.NoError(t, err) eval := tex(test.input) - actual, err := parser.Parse(ctx, eval) + actual, err := p.Parse(ctx, eval) // Shortcut to ensure the evaluable instance matches if test.expected.Evaluable == nil {