diff --git a/engine.go b/engine.go index 0bfe367..c8e21f0 100644 --- a/engine.go +++ b/engine.go @@ -22,6 +22,7 @@ const ( type MatchingEngine interface { // Type returns the EngineType Type() EngineType + // Match takes an input event, containing key:value pairs of data, and // matches the given data to any ExpressionParts stored in the engine. // @@ -29,7 +30,8 @@ type MatchingEngine interface { // expression parts received. Some may return false positives, but // each MatchingEngine should NEVER omit ExpressionParts which match // the given input. - Match(ctx context.Context, input map[string]any) ([]*StoredExpressionPart, error) + Match(ctx context.Context, input map[string]any) (matched []*StoredExpressionPart, err error) + // Add adds a new expression part to the matching engine for future matches. Add(ctx context.Context, p ExpressionPart) error // Remove removes an expression part from the matching engine, ensuring that the @@ -44,7 +46,7 @@ type MatchingEngine interface { // ignoring the variable name. Note that each MatchingEngine should NEVER // omit ExpressionParts which match the given input; false positives are okay, // but not returning valid matches must be impossible. - Search(ctx context.Context, variable string, input any) []*StoredExpressionPart + Search(ctx context.Context, variable string, input any) (matched []*StoredExpressionPart) } // Leaf represents the leaf within a tree. This stores all expressions diff --git a/engine_null.go b/engine_null.go index b03f8ee..192fb84 100644 --- a/engine_null.go +++ b/engine_null.go @@ -32,10 +32,9 @@ func (n *nullLookup) Type() EngineType { return EngineTypeNullMatch } -func (n *nullLookup) Match(ctx context.Context, data map[string]any) ([]*StoredExpressionPart, error) { - +func (n *nullLookup) Match(ctx context.Context, data map[string]any) (matched []*StoredExpressionPart, err error) { l := &sync.Mutex{} - found := []*StoredExpressionPart{} + matched = []*StoredExpressionPart{} eg := errgroup.Group{} for item := range n.paths { @@ -55,17 +54,21 @@ func (n *nullLookup) Match(ctx context.Context, data map[string]any) ([]*StoredE // This matches null, nil (as null), and any non-null items. l.Lock() - found = append(found, n.Search(ctx, path, res[0])...) + + // XXX: This engine hasn't been updated with denied items for !=. It needs consideration + // in how to handle these cases appropriately. + found := n.Search(ctx, path, res[0]) + matched = append(matched, found...) l.Unlock() return nil }) } - return found, eg.Wait() + return matched, eg.Wait() } -func (n *nullLookup) Search(ctx context.Context, variable string, input any) []*StoredExpressionPart { +func (n *nullLookup) Search(ctx context.Context, variable string, input any) (matched []*StoredExpressionPart) { if input == nil { // The input data is null, so the only items that can match are equality // comparisons to null. diff --git a/engine_number.go b/engine_number.go index 0231f7d..b6a3547 100644 --- a/engine_number.go +++ b/engine_number.go @@ -39,9 +39,9 @@ func (n numbers) Type() EngineType { return EngineTypeBTree } -func (n *numbers) Match(ctx context.Context, input map[string]any) ([]*StoredExpressionPart, error) { +func (n *numbers) Match(ctx context.Context, input map[string]any) (matched []*StoredExpressionPart, err error) { l := &sync.Mutex{} - found := []*StoredExpressionPart{} + matched = []*StoredExpressionPart{} eg := errgroup.Group{} for item := range n.paths { @@ -72,28 +72,27 @@ func (n *numbers) Match(ctx context.Context, input map[string]any) ([]*StoredExp // This matches null, nil (as null), and any non-null items. l.Lock() - found = append(found, n.Search(ctx, path, val)...) + found := n.Search(ctx, path, val) + matched = append(matched, found...) l.Unlock() return nil }) } - err := eg.Wait() - - return found, err + return matched, eg.Wait() } // Search returns all ExpressionParts which match the given input, ignoring the variable name // entirely. -func (n *numbers) Search(ctx context.Context, variable string, input any) []*StoredExpressionPart { +func (n *numbers) Search(ctx context.Context, variable string, input any) (matched []*StoredExpressionPart) { n.lock.RLock() defer n.lock.RUnlock() - var ( - val float64 - found = []*StoredExpressionPart{} - ) + // initialize matched + matched = []*StoredExpressionPart{} + + var val float64 switch v := input.(type) { case int: @@ -114,7 +113,7 @@ func (n *numbers) Search(ctx context.Context, variable string, input any) []*Sto continue } // This is a candidatre. - found = append(found, m) + matched = append(matched, m) } } @@ -130,7 +129,7 @@ func (n *numbers) Search(ctx context.Context, variable string, input any) []*Sto continue } // This is a candidatre. - found = append(found, m) + matched = append(matched, m) } return true }) @@ -147,12 +146,12 @@ func (n *numbers) Search(ctx context.Context, variable string, input any) []*Sto continue } // This is a candidatre. - found = append(found, m) + matched = append(matched, m) } return true }) - return found + return matched } func (n *numbers) Add(ctx context.Context, p ExpressionPart) error { diff --git a/engine_stringmap.go b/engine_stringmap.go index cc65325..ed45229 100644 --- a/engine_stringmap.go +++ b/engine_stringmap.go @@ -14,12 +14,16 @@ import ( func newStringEqualityMatcher() MatchingEngine { return &stringLookup{ - lock: &sync.RWMutex{}, - vars: map[string]struct{}{}, - strings: map[string][]*StoredExpressionPart{}, + lock: &sync.RWMutex{}, + vars: map[string]struct{}{}, + equality: variableMap{}, + inequality: inequalityMap{}, } } +type variableMap map[string][]*StoredExpressionPart +type inequalityMap map[string]variableMap + // stringLookup represents a very dumb lookup for string equality matching within // expressions. // @@ -38,9 +42,15 @@ type stringLookup struct { // vars stores variable names seen within expressions. vars map[string]struct{} - // strings stores all strings referenced within expressions, mapped to the expression part. + // equality stores all strings referenced within expressions, mapped to the expression part. // this performs string equality lookups. - strings map[string][]*StoredExpressionPart + equality variableMap + + // inequality stores all variables referenced within inequality checks mapped to the value, + // which is then mapped to expression parts. + // + // this lets us quickly map neq in a fast manner + inequality inequalityMap } func (s stringLookup) Type() EngineType { @@ -49,9 +59,11 @@ func (s stringLookup) Type() EngineType { func (n *stringLookup) Match(ctx context.Context, input map[string]any) ([]*StoredExpressionPart, error) { l := &sync.Mutex{} - found := []*StoredExpressionPart{} + + matched := []*StoredExpressionPart{} eg := errgroup.Group{} + // First, handle equality matching. for item := range n.vars { path := item eg.Go(func() error { @@ -60,36 +72,103 @@ func (n *stringLookup) Match(ctx context.Context, input map[string]any) ([]*Stor return err } - res := x.Get(input) - if len(res) == 0 { - return nil + // default to an empty string + str := "" + if res := x.Get(input); len(res) > 0 { + if value, ok := res[0].(string); ok { + str = value + } } - str, ok := res[0].(string) - if !ok { - return nil + + m := n.equalitySearch(ctx, path, str) + + l.Lock() + matched = append(matched, m...) + l.Unlock() + return nil + }) + } + + // Then, iterate through the inequality matches. + for item := range n.inequality { + path := item + eg.Go(func() error { + x, err := jp.ParseString(path) + if err != nil { + return err } - // This matches null, nil (as null), and any non-null items. + // default to an empty string + str := "" + if res := x.Get(input); len(res) > 0 { + if value, ok := res[0].(string); ok { + str = value + } + } + + m := n.inequalitySearch(ctx, path, str) + l.Lock() - found = append(found, n.Search(ctx, path, str)...) + matched = append(matched, m...) l.Unlock() return nil }) } - return found, eg.Wait() + return matched, eg.Wait() } // Search returns all ExpressionParts which match the given input, ignoring the variable name // entirely. -func (n *stringLookup) Search(ctx context.Context, variable string, input any) []*StoredExpressionPart { - n.lock.RLock() - defer n.lock.RUnlock() +// +// Note that Search does not match inequality items. +func (n *stringLookup) Search(ctx context.Context, variable string, input any) (matched []*StoredExpressionPart) { str, ok := input.(string) if !ok { return nil } - return n.strings[n.hash(str)] + + return n.equalitySearch(ctx, variable, str) + +} + +func (n *stringLookup) equalitySearch(ctx context.Context, variable string, input string) (matched []*StoredExpressionPart) { + n.lock.RLock() + defer n.lock.RUnlock() + + hashedInput := n.hash(input) + + // Iterate through all matching values, and only take those expressions which match our + // current variable name. + filtered := make([]*StoredExpressionPart, len(n.equality[hashedInput])) + i := 0 + for _, part := range n.equality[hashedInput] { + if part.Ident != nil && *part.Ident != variable { + // The variables don't match. + continue + } + filtered[i] = part + i++ + } + filtered = filtered[0:i] + + return filtered +} + +func (n *stringLookup) inequalitySearch(ctx context.Context, variable string, input string) (matched []*StoredExpressionPart) { + n.lock.RLock() + defer n.lock.RUnlock() + + hashedInput := n.hash(input) + + results := []*StoredExpressionPart{} + for value, exprs := range n.inequality[variable] { + if value == hashedInput { + continue + } + results = append(results, exprs...) + } + return results } // hash hashes strings quickly via xxhash. this provides a _somewhat_ collision-free @@ -102,50 +181,98 @@ func (n *stringLookup) hash(input string) string { } func (n *stringLookup) Add(ctx context.Context, p ExpressionPart) error { - if p.Predicate.Operator != operators.Equals { - return fmt.Errorf("StringHash engines only support string equality") - } + // Primarily, we match `$string == lit` and `$string != lit`. + // + // Equality operators are easy: link the matching string to + // expressions that are candidates. + switch p.Predicate.Operator { + case operators.Equals: + n.lock.Lock() + defer n.lock.Unlock() + val := n.hash(p.Predicate.LiteralAsString()) + + n.vars[p.Predicate.Ident] = struct{}{} - n.lock.Lock() - defer n.lock.Unlock() - val := n.hash(p.Predicate.LiteralAsString()) + if _, ok := n.equality[val]; !ok { + n.equality[val] = []*StoredExpressionPart{p.ToStored()} + return nil + } + n.equality[val] = append(n.equality[val], p.ToStored()) - n.vars[p.Predicate.Ident] = struct{}{} + case operators.NotEquals: + n.lock.Lock() + defer n.lock.Unlock() + val := n.hash(p.Predicate.LiteralAsString()) - if _, ok := n.strings[val]; !ok { - n.strings[val] = []*StoredExpressionPart{p.ToStored()} + // First, add the variable to inequality + if _, ok := n.inequality[p.Predicate.Ident]; !ok { + n.inequality[p.Predicate.Ident] = variableMap{ + val: []*StoredExpressionPart{p.ToStored()}, + } + return nil + } + + n.inequality[p.Predicate.Ident][val] = append(n.inequality[p.Predicate.Ident][val], p.ToStored()) return nil + default: + return fmt.Errorf("StringHash engines only support string equality/inequality") } - n.strings[val] = append(n.strings[val], p.ToStored()) return nil } func (n *stringLookup) Remove(ctx context.Context, p ExpressionPart) error { - if p.Predicate.Operator != operators.Equals { - return fmt.Errorf("StringHash engines only support string equality") - } + switch p.Predicate.Operator { + case operators.Equals: + n.lock.Lock() + defer n.lock.Unlock() - n.lock.Lock() - defer n.lock.Unlock() + val := n.hash(p.Predicate.LiteralAsString()) - val := n.hash(p.Predicate.LiteralAsString()) + coll, ok := n.equality[val] + if !ok { + // This could not exist as there's nothing mapping this variable for + // the given event name. + return ErrExpressionPartNotFound + } + + // Remove the expression part from the leaf. + for i, eval := range coll { + if p.EqualsStored(eval) { + coll = append(coll[:i], coll[i+1:]...) + n.equality[val] = coll + return nil + } + } - coll, ok := n.strings[val] - if !ok { - // This could not exist as there's nothing mapping this variable for - // the given event name. return ErrExpressionPartNotFound - } - // Remove the expression part from the leaf. - for i, eval := range coll { - if p.EqualsStored(eval) { - coll = append(coll[:i], coll[i+1:]...) - n.strings[val] = coll + case operators.NotEquals: + n.lock.Lock() + defer n.lock.Unlock() + + val := n.hash(p.Predicate.LiteralAsString()) + + // If the var isn't found, we can't remove. + if _, ok := n.inequality[p.Predicate.Ident]; !ok { + return ErrExpressionPartNotFound + } + + // then merge the expression into the value that the expression has. + if _, ok := n.inequality[p.Predicate.Ident][val]; !ok { return nil } - } - return ErrExpressionPartNotFound + for i, eval := range n.inequality[p.Predicate.Ident][val] { + if p.EqualsStored(eval) { + n.inequality[p.Predicate.Ident][val] = append(n.inequality[p.Predicate.Ident][val][:i], n.inequality[p.Predicate.Ident][val][i+1:]...) + return nil + } + } + + return ErrExpressionPartNotFound + + default: + return fmt.Errorf("StringHash engines only support string equality/inequality") + } } diff --git a/engine_stringmap_test.go b/engine_stringmap_test.go index f357d37..ae57e38 100644 --- a/engine_stringmap_test.go +++ b/engine_stringmap_test.go @@ -34,23 +34,38 @@ func TestEngineStringmap(t *testing.T) { }, } - t.Run("It adds strings", func(t *testing.T) { - var err error + // Test inequality + d := ExpressionPart{ + Predicate: &Predicate{ + Ident: "async.data.neq", + Literal: "neq-1", + Operator: operators.NotEquals, + }, + } + e := ExpressionPart{ + Predicate: &Predicate{ + Ident: "async.data.neq", + Literal: "neq-2", + Operator: operators.NotEquals, + }, + } - err = s.Add(ctx, a) - require.NoError(t, err) + // Adding expressions works + var err error - t.Run("Adding the same string twice", func(t *testing.T) { - err = s.Add(ctx, b) - require.NoError(t, err) - require.Equal(t, 2, len(s.strings[s.hash("123")])) - }) + err = s.Add(ctx, a) + require.NoError(t, err) - // A different expression - err = s.Add(ctx, c) + t.Run("Adding the same string twice", func(t *testing.T) { + err = s.Add(ctx, b) require.NoError(t, err) + require.Equal(t, 2, len(s.equality[s.hash("123")])) }) + // A different expression + err = s.Add(ctx, c) + require.NoError(t, err) + t.Run("It searches strings", func(t *testing.T) { parts := s.Search(ctx, "async.data.id", "123") require.Equal(t, 2, len(parts)) @@ -60,28 +75,160 @@ func TestEngineStringmap(t *testing.T) { require.EqualValues(t, part.PredicateID, b.Hash()) } - t.Run("It ignores variable names (for now)", func(t *testing.T) { + t.Run("It handles variable names", func(t *testing.T) { parts = s.Search(ctx, "this doesn't matter", "123") - require.Equal(t, 2, len(parts)) - for _, part := range parts { - require.EqualValues(t, part.PredicateID, a.Hash()) - require.EqualValues(t, part.PredicateID, b.Hash()) - } + require.Equal(t, 0, len(parts)) }) parts = s.Search(ctx, "async.data.another", "456") require.Equal(t, 1, len(parts)) }) - t.Run("It matches data", func(t *testing.T) { + // Inequality + err = s.Add(ctx, d) + require.NoError(t, err) + err = s.Add(ctx, e) + require.NoError(t, err) + + t.Run("inequality", func(t *testing.T) { + t.Run("first case: neq-1", func(t *testing.T) { + parts, err := s.Match(ctx, map[string]any{ + "async": map[string]any{ + "data": map[string]any{"neq": "neq-1"}, + }, + }) + require.NoError(t, err) + require.Equal(t, 1, len(parts)) + require.EqualValues(t, parts[0].PredicateID, e.Hash()) + }) + + t.Run("second case: neq-1", func(t *testing.T) { + parts, err := s.Match(ctx, map[string]any{ + "async": map[string]any{ + "data": map[string]any{"neq": "neq-2"}, + }, + }) + require.NoError(t, err) + require.Equal(t, 1, len(parts)) + require.EqualValues(t, parts[0].PredicateID, d.Hash()) + }) + + t.Run("third case: both", func(t *testing.T) { + parts, err := s.Match(ctx, map[string]any{ + "async": map[string]any{ + "data": map[string]any{"neq": "both"}, + }, + }) + require.NoError(t, err) + require.Equal(t, 2, len(parts)) + }) + }) + + t.Run("It matches data, including neq", func(t *testing.T) { + found, err := s.Match(ctx, map[string]any{ + "async": map[string]any{ + "data": map[string]any{ + "id": "123", + "neq": "lol", + }, + }, + }) + require.NoError(t, err) + require.Equal(t, 4, len(found)) // matching plus inequality + }) + + t.Run("It matches data with null neq", func(t *testing.T) { found, err := s.Match(ctx, map[string]any{ "async": map[string]any{ "data": map[string]any{ "id": "123", + // by not including neq, we ensure we test against null matches. }, }, }) require.NoError(t, err) - require.Equal(t, 2, len(found)) + require.Equal(t, 4, len(found)) // matching plus inequality + }) + +} + +func TestEngineStringmap_DuplicateValues(t *testing.T) { + ctx := context.Background() + s := newStringEqualityMatcher().(*stringLookup) + a := ExpressionPart{ + Predicate: &Predicate{ + Ident: "async.data.var_a", + Literal: "123", + Operator: operators.Equals, + }, + } + b := ExpressionPart{ + Predicate: &Predicate{ + Ident: "async.data.var_b", + Literal: "123", + Operator: operators.Equals, + }, + } + err := s.Add(ctx, a) + require.NoError(t, err) + err = s.Add(ctx, b) + require.NoError(t, err) + + // It only matches var B + parts := s.Search(ctx, "async.data.var_b", "123") + require.Equal(t, 1, len(parts)) + +} + +func TestEngineStringmap_DuplicateNeq(t *testing.T) { + ctx := context.Background() + s := newStringEqualityMatcher().(*stringLookup) + a := ExpressionPart{ + Predicate: &Predicate{ + Ident: "async.data.var_a", + Literal: "a", + Operator: operators.Equals, + }, + } + b := ExpressionPart{ + Predicate: &Predicate{ + Ident: "async.data.var_b", + Literal: "b", + Operator: operators.Equals, + }, + } + c := ExpressionPart{ + Predicate: &Predicate{ + Ident: "async.data.var_c", + Literal: "123", + Operator: operators.NotEquals, + }, + } + err := s.Add(ctx, a) + require.NoError(t, err) + err = s.Add(ctx, b) + require.NoError(t, err) + err = s.Add(ctx, c) + require.NoError(t, err) + + parts, err := s.Match(ctx, map[string]any{ + "async": map[string]any{ + "data": map[string]any{ + "var_a": "a", + "var_b": "nah", + }, + }, }) + + require.NoError(t, err) + require.Equal(t, 2, len(parts)) + for _, v := range parts { + // Never matches B, as B isn't complete. + require.NotEqualValues(t, v.PredicateID, b.Hash()) + require.Contains(t, []uint64{ + a.Hash(), + c.Hash(), + }, v.PredicateID) + } + } diff --git a/expr.go b/expr.go index 22deb78..401b106 100644 --- a/expr.go +++ b/expr.go @@ -171,27 +171,17 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu s sync.Mutex ) - // TODO: Concurrently match constant expressions using a semaphore for capacity. + eg := errgroup.Group{} - // Match constant expressions always. a.lock.RLock() - constantEvals := make([]Evaluable, len(a.constants)) - n := 0 for uuid := range a.constants { - if eval, ok := a.evals[uuid]; ok { - constantEvals[n] = eval - n++ - } - } - a.lock.RUnlock() - - eg := errgroup.Group{} - for _, item := range constantEvals { - if item == nil { + item, ok := a.evals[uuid] + if !ok || item == nil { continue } if err := a.sem.Acquire(ctx, 1); err != nil { + a.lock.RUnlock() return result, matched, err } @@ -230,6 +220,7 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu return nil }) } + a.lock.RUnlock() if werr := eg.Wait(); werr != nil { err = errors.Join(err, werr) @@ -240,18 +231,6 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu err = errors.Join(err, merr) } - // Load all evaluable instances directly from the match - a.lock.RLock() - n = 0 - evaluables := make([]Evaluable, len(matches)) - for _, el := range matches { - if eval, ok := a.evals[el.Parsed.EvaluableID]; ok { - evaluables[n] = eval - n++ - } - } - a.lock.RUnlock() - // Each match here is a potential success. When other trees and operators which are walkable // are added (eg. >= operators on strings), ensure that we find the correct number of matches // for each group ID and then skip evaluating expressions if the number of matches is <= the group @@ -259,17 +238,18 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu seenMu := &sync.Mutex{} seen := map[uuid.UUID]struct{}{} - eg = errgroup.Group{} - for _, match := range evaluables { - if match == nil { + a.lock.RLock() + for _, expr := range matches { + eval, ok := a.evals[expr.Parsed.EvaluableID] + if !ok || eval == nil { continue } if err := a.sem.Acquire(ctx, 1); err != nil { + a.lock.RUnlock() return result, matched, err } - expr := match eg.Go(func() error { defer a.sem.Release(1) defer func() { @@ -281,11 +261,11 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu }() seenMu.Lock() - if _, ok := seen[expr.GetID()]; ok { + if _, ok := seen[eval.GetID()]; ok { seenMu.Unlock() return nil } else { - seen[expr.GetID()] = struct{}{} + seen[eval.GetID()] = struct{}{} seenMu.Unlock() } @@ -294,19 +274,20 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu // NOTE: We don't need to add lifted expression variables, // because match.Parsed.Evaluable() returns the original expression // string. - ok, evalerr := a.eval(ctx, expr, data) + ok, evalerr := a.eval(ctx, eval, data) if evalerr != nil { return evalerr } if ok { s.Lock() - result = append(result, expr) + result = append(result, eval) s.Unlock() } return nil }) } + a.lock.RUnlock() if werr := eg.Wait(); werr != nil { err = errors.Join(err, werr) @@ -329,11 +310,14 @@ func (a *aggregator) AggregateMatch(ctx context.Context, data map[string]any) ([ // else we know a required comparason did not match. // // Note that having a count >= the group ID value does not guarantee that the expression is valid. - counts := map[groupID]int{} + // + // Note that we break this down per evaluable ID (UUID) + totalCounts := map[uuid.UUID]map[groupID]int{} // Store all expression parts per group ID for returning. - found := map[groupID][]*StoredExpressionPart{} + found := map[uuid.UUID]map[groupID][]*StoredExpressionPart{} for _, engine := range a.engines { + // we explicitly ignore the deny path for now. matched, err := engine.Match(ctx, data) if err != nil { return nil, err @@ -341,43 +325,64 @@ func (a *aggregator) AggregateMatch(ctx context.Context, data map[string]any) ([ // Add all found items from the engine to the above list. for _, eval := range matched { - counts[eval.GroupID] += 1 + idCount, idFound := totalCounts[eval.Parsed.EvaluableID], found[eval.Parsed.EvaluableID] - if _, ok := found[eval.GroupID]; !ok { - found[eval.GroupID] = []*StoredExpressionPart{} + if idCount == nil { + idCount = map[groupID]int{} + idFound = map[groupID][]*StoredExpressionPart{} } - found[eval.GroupID] = append(found[eval.GroupID], eval) + + idCount[eval.GroupID] += 1 + if _, ok := idFound[eval.GroupID]; !ok { + idFound[eval.GroupID] = []*StoredExpressionPart{} + } + idFound[eval.GroupID] = append(idFound[eval.GroupID], eval) + + // Update mapping + totalCounts[eval.Parsed.EvaluableID] = idCount + found[eval.Parsed.EvaluableID] = idFound } + } - // Validate that groups meet the minimum size. - for groupID, matchingCount := range counts { - requiredSize := int(groupID.Size()) // The total req size from the group ID + seen := map[uuid.UUID]struct{}{} - if matchingCount >= requiredSize { - // The matching count met the group size; all results are safe. - result = append(result, found[groupID]...) - continue - } + // Validate that groups meet the minimum size. + for evalID, counts := range totalCounts { + for groupID, matchingCount := range counts { + + requiredSize := int(groupID.Size()) // The total req size from the group ID + + if matchingCount >= requiredSize { + for _, i := range found[evalID][groupID] { + if _, ok := seen[i.Parsed.EvaluableID]; ok { + continue + } + seen[i.Parsed.EvaluableID] = struct{}{} + result = append(result, i) + } + continue + } - // If this is a partial eval, always add it if there's a match for now. + // If this is a partial eval, always add it if there's a match for now. - // The GroupID required more comparisons to equate to true than - // we had, so this could never evaluate to true. Skip this. - // - // NOTE: We currently don't add items with OR predicates to the - // matching engine, so we cannot use group sizes if the expr part - // has an OR. - for _, i := range found[groupID] { - // if this is purely aggregateable, we're safe to rely on group IDs. + // The GroupID required more comparisons to equate to true than + // we had, so this could never evaluate to true. Skip this. // - // So, we only need to care if this expression is mixed. If it's mixed, - // we can ignore group IDs for the time being. - if _, ok := a.mixed[i.Parsed.EvaluableID]; ok { - // this wasn't fully aggregatable so evaluate it. - result = append(result, i) - } + // NOTE: We currently don't add items with OR predicates to the + // matching engine, so we cannot use group sizes if the expr part + // has an OR. + for _, i := range found[evalID][groupID] { + // if this is purely aggregateable, we're safe to rely on group IDs. + // + // So, we only need to care if this expression is mixed. If it's mixed, + // we can ignore group IDs for the time being. + if _, ok := a.mixed[i.Parsed.EvaluableID]; ok { + // this wasn't fully aggregatable so evaluate it. + result = append(result, i) + } + } } } @@ -636,8 +641,8 @@ func engineType(p Predicate) EngineType { // return EngineTypeNone return EngineTypeBTree case string: - if p.Operator == operators.Equals { - // StringHash is only used for matching on equality. + if p.Operator == operators.Equals || p.Operator == operators.NotEquals { + // StringHash is only used for matching on in/equality. return EngineTypeStringHash } case nil: diff --git a/expr_test.go b/expr_test.go index d12c0fb..66e0236 100644 --- a/expr_test.go +++ b/expr_test.go @@ -133,6 +133,70 @@ func TestEvaluate_Strings(t *testing.T) { addOtherExpressions(n, e, loader) + require.EqualValues(t, n+1, e.Len()) + // These should all be fast matches. + require.EqualValues(t, n+1, e.FastLen()) + require.EqualValues(t, 0, e.MixedLen()) + require.EqualValues(t, 0, e.SlowLen()) + + t.Run("It matches items", func(t *testing.T) { + pre := time.Now() + evals, executed, err := e.Evaluate(ctx, map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "account_id": "yes", + "match": "true", + }, + }, + }) + total := time.Since(pre) + fmt.Printf("Matched in %v ns\n", total.Nanoseconds()) + fmt.Printf("Matched in %v ms (%d)\n", total.Milliseconds(), executed) + + require.NoError(t, err) + require.EqualValues(t, []Evaluable{expected}, evals) + // We may match more than 1 as the string matcher engine + // returns false positives + require.Equal(t, executed, int32(1)) + }) + + t.Run("It handles non-matching data", func(t *testing.T) { + pre := time.Now() + evals, executed, err := e.Evaluate(ctx, map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "account_id": "yes", + "match": "no", + }, + }, + }) + total := time.Since(pre) + fmt.Printf("Matched in %v ns\n", total.Nanoseconds()) + fmt.Printf("Matched in %v ms (%d)\n", total.Milliseconds(), executed) + + require.NoError(t, err) + require.EqualValues(t, 0, len(evals)) + require.EqualValues(t, 0, executed) + }) +} + +func TestEvaluate_Strings_Inequality(t *testing.T) { + ctx := context.Background() + parser := NewTreeParser(NewCachingCompiler(newEnv(), nil)) + + expected := tex(`event.data.account_id == "yes" && event.data.neq != "neq"`) + loader := newEvalLoader() + loader.AddEval(expected) + + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) + + _, err := e.Add(ctx, expected) + require.NoError(t, err) + + n := 100_000 + + addOtherExpressions(n, e, loader) + require.EqualValues(t, n+1, e.Len()) t.Run("It matches items", func(t *testing.T) { @@ -142,6 +206,7 @@ func TestEvaluate_Strings(t *testing.T) { "data": map[string]any{ "account_id": "yes", "match": "true", + "neq": "nah", }, }, }) @@ -150,6 +215,7 @@ func TestEvaluate_Strings(t *testing.T) { fmt.Printf("Matched in %v ms (%d)\n", total.Milliseconds(), matched) require.NoError(t, err) + require.EqualValues(t, 1, len(evals)) require.EqualValues(t, []Evaluable{expected}, evals) // We may match more than 1 as the string matcher engine // returns false positives @@ -163,6 +229,7 @@ func TestEvaluate_Strings(t *testing.T) { "data": map[string]any{ "account_id": "yes", "match": "no", + "neq": "nah", }, }, }) @@ -171,8 +238,8 @@ func TestEvaluate_Strings(t *testing.T) { fmt.Printf("Matched in %v ms\n", total.Milliseconds()) require.NoError(t, err) - require.EqualValues(t, 0, len(evals)) - require.EqualValues(t, 0, matched) + require.EqualValues(t, 1, len(evals)) + require.EqualValues(t, 1, matched) }) } @@ -317,7 +384,7 @@ func TestEvaluate_Concurrently(t *testing.T) { _, err := e.Add(ctx, expected) require.NoError(t, err) - addOtherExpressions(100_000, e, loader) + addOtherExpressions(1_000, e, loader) t.Run("It matches items", func(t *testing.T) { wg := sync.WaitGroup{} @@ -735,11 +802,12 @@ func TestAddRemove(t *testing.T) { e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) ok, err := e.Add(ctx, loader.AddEval(tex(`event.data.foo == "yea" && event.data.bar != "baz"`))) require.NoError(t, err) - require.Equal(t, ok, float64(0.5)) + // now fully aggregated + require.Equal(t, ok, float64(1)) require.Equal(t, 1, e.Len()) require.Equal(t, 0, e.SlowLen()) - require.Equal(t, 0, e.FastLen()) - require.Equal(t, 1, e.MixedLen()) + require.Equal(t, 1, e.FastLen()) + require.Equal(t, 0, e.MixedLen()) // Matching this expr should now fail. eval, count, err := e.Evaluate(ctx, map[string]any{ @@ -765,7 +833,7 @@ func TestAddRemove(t *testing.T) { }, }) - require.EqualValues(t, 1, count) + require.EqualValues(t, 0, count) require.EqualValues(t, 0, len(eval)) require.NoError(t, err) }) @@ -1091,21 +1159,27 @@ func testBoolEvaluator(ctx context.Context, e Evaluable, input map[string]any) ( } func addOtherExpressions(n int, e AggregateEvaluator, loader *evalLoader) { + + r := rand.New(rand.NewSource(123)) + var l sync.Mutex + ctx := context.Background() wg := sync.WaitGroup{} for i := 0; i < n; i++ { - wg.Add(1) //nolint:all + wg.Add(1) go func() { defer wg.Done() byt := make([]byte, 8) - _, err := rand.Read(byt) + l.Lock() + _, err := r.Read(byt) + l.Unlock() if err != nil { panic(err) } str := hex.EncodeToString(byt) - expr := tex(fmt.Sprintf(`event.data.account_id == "%s"`, str)) + expr := tex(fmt.Sprintf(`event.data.account_id == "%s" && event.data.neq != "neq"`, str)) loader.AddEval(expr) _, err = e.Add(ctx, expr) if err != nil { diff --git a/parser.go b/parser.go index f8ef615..b982da2 100644 --- a/parser.go +++ b/parser.go @@ -2,7 +2,6 @@ package expr import ( "context" - "crypto/sha256" "encoding/binary" "fmt" "math/rand" @@ -95,8 +94,7 @@ func (p *parser) Parse(ctx context.Context, eval Evaluable) (*ParsedExpression, // // We only overwrite this if rander is not nil so that we can inject rander during tests. id := eval.GetID() - digest := sha256.Sum256(id[:]) - seed := int64(binary.NativeEndian.Uint64(digest[:8])) + seed := int64(binary.NativeEndian.Uint64(id[:8])) r = rand.New(rand.NewSource(seed)).Read }