Skip to content

Commit

Permalink
Make GroupIDs deterministic based off of the Evaluable Identifier
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyhb committed Jan 5, 2024
1 parent db40612 commit 3014841
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 7 deletions.
5 changes: 5 additions & 0 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ func NewAggregateEvaluator(
}

type Evaluable interface {
// Identifier returns a unique identifier for the evaluable item. If there are
// two instances of the same expression, the identifier should return a unique
// string for each instance of the expression (eg. for two pauses).
Identifier() string

// Expression returns an expression as a raw string.
Expression() string
}
Expand Down
1 change: 1 addition & 0 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ func TestAdd(t *testing.T) {
type tex string

func (e tex) Expression() string { return string(e) }
func (e tex) Identifier() string { return string(e) }

func testBoolEvaluator(ctx context.Context, e Evaluable, input map[string]any) (bool, error) {
env, _ := cel.NewEnv(
Expand Down
6 changes: 6 additions & 0 deletions groupid.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ type groupID [8]byte

var rander = rand.Read

type RandomReader func(p []byte) (n int, err error)

func (g groupID) String() string {
return hex.EncodeToString(g[:])
}
Expand All @@ -22,6 +24,10 @@ func (g groupID) Size() uint16 {
}

func newGroupID(size uint16) groupID {
return newGroupIDWithReader(size, rander)
}

func newGroupIDWithReader(size uint16, rander RandomReader) groupID {
id := make([]byte, 8)
binary.NativeEndian.PutUint16(id, size)
_, _ = rander(id[2:])
Expand Down
23 changes: 20 additions & 3 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package expr

import (
"context"
"crypto/sha256"
"encoding/binary"
"fmt"
"math/rand"
"strconv"
"strings"

Expand Down Expand Up @@ -52,6 +55,8 @@ func NewTreeParser(ep CELParser) (TreeParser, error) {

type parser struct {
ep CELParser

rander RandomReader
}

func (p *parser) Parse(ctx context.Context, eval Evaluable) (*ParsedExpression, error) {
Expand All @@ -60,13 +65,25 @@ func (p *parser) Parse(ctx context.Context, eval Evaluable) (*ParsedExpression,
return nil, issues.Err()
}

if p.rander == nil {
// Create a new deterministic random reader based off of the evaluable's identifier.
// This means that every time we parse an expression with the given identifier, the
// group IDs will be deterministic as the randomness is sourced from the ID.
//
// We only overwrite this if rander is not nil so that we can inject rander during tests.
digest := sha256.Sum256([]byte(eval.Identifier()))
seed := int64(binary.NativeEndian.Uint64(digest[:8]))
p.rander = rand.New(rand.NewSource(seed)).Read
}

node := newNode()
_, err := navigateAST(
expr{
ast: ast.NativeRep().Expr(),
},
node,
vars,
p.rander,
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -327,7 +344,7 @@ type expr struct {
// It does this by iterating through the expression, amending the current `group` until
// an or expression is found. When an or expression is found, we create another group which
// is mutated by the iteration.
func navigateAST(nav expr, parent *Node, vars LiftedArgs) ([]*Node, error) {
func navigateAST(nav expr, parent *Node, vars LiftedArgs, rand RandomReader) ([]*Node, error) {
// on the very first call to navigateAST, ensure that we set the first node
// inside the nodemap.
result := []*Node{}
Expand Down Expand Up @@ -376,7 +393,7 @@ func navigateAST(nav expr, parent *Node, vars LiftedArgs) ([]*Node, error) {
newParent := newNode()

// For each item in the stack, recurse into that AST.
_, err := navigateAST(or, newParent, vars)
_, err := navigateAST(or, newParent, vars, rand)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -428,7 +445,7 @@ func navigateAST(nav expr, parent *Node, vars LiftedArgs) ([]*Node, error) {
total += 1
}

parent.GroupID = newGroupID(uint16(total))
parent.GroupID = newGroupIDWithReader(uint16(total), rand)
// For each sub-group, add the same group IDs to children if there's no nesting.
for n, item := range parent.Ands {
if len(item.Ands) == 0 && len(item.Ors) == 0 && item.Predicate != nil {
Expand Down
11 changes: 7 additions & 4 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ func TestParse(t *testing.T) {
t.Helper()

for _, test := range tests {
parser, err := newParser()
p, err := newParser()
p.(*parser).rander = rander
require.NoError(t, err)

eval := tex(test.input)
actual, err := parser.Parse(ctx, eval)
actual, err := p.Parse(ctx, eval)

require.NotNil(t, actual.Root.GroupID)

Expand Down Expand Up @@ -1092,10 +1093,12 @@ func TestParse_LiftedVars(t *testing.T) {
t.Helper()

for _, test := range tests {
parser, err := NewTreeParser(cachingCelParser)
p, err := NewTreeParser(cachingCelParser)
// overwrite rander so that the parser uses the same nil bytes
p.(*parser).rander = rander
require.NoError(t, err)
eval := tex(test.input)
actual, err := parser.Parse(ctx, eval)
actual, err := p.Parse(ctx, eval)

// Shortcut to ensure the evaluable instance matches
if test.expected.Evaluable == nil {
Expand Down

0 comments on commit 3014841

Please sign in to comment.