Skip to content

Commit

Permalink
Set membership test rewriting optimizer (#865)
Browse files Browse the repository at this point in the history
* Set membership test rewriting optimizer
* Additional tests for macros and enums
  • Loading branch information
TristonianJones authored Dec 5, 2023
1 parent 967fca9 commit 52e5dcc
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 1 deletion.
1 change: 1 addition & 0 deletions ext/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ go_library(
"//checker:go_default_library",
"//common/ast:go_default_library",
"//common/overloads:go_default_library",
"//common/operators:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
Expand Down
64 changes: 64 additions & 0 deletions ext/sets.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (

"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
Expand Down Expand Up @@ -119,6 +121,68 @@ func (setsLib) ProgramOptions() []cel.ProgramOption {
}
}

// NewSetMembershipOptimizer rewrites set membership tests using the `in` operator against a list
// of constant values of enum, int, uint, string, or boolean type into a set membership test against
// a map where the map keys are the elements of the list.
func NewSetMembershipOptimizer() (cel.ASTOptimizer, error) {
return setsLib{}, nil
}

func (setsLib) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST {
root := ast.NavigateAST(a)
matches := ast.MatchDescendants(root, matchInConstantList(a))
for _, match := range matches {
call := match.AsCall()
listArg := call.Args()[1]
entries := make([]ast.EntryExpr, len(listArg.AsList().Elements()))
for i, elem := range listArg.AsList().Elements() {
var entry ast.EntryExpr
if r, found := a.ReferenceMap()[elem.ID()]; found && r.Value != nil {
entry = ctx.NewMapEntry(ctx.NewLiteral(r.Value), ctx.NewLiteral(types.True), false)
} else {
entry = ctx.NewMapEntry(elem, ctx.NewLiteral(types.True), false)
}
entries[i] = entry
}
mapArg := ctx.NewMap(entries)
ctx.UpdateExpr(listArg, mapArg)
}
return a
}

func matchInConstantList(a *ast.AST) ast.ExprMatcher {
return func(e ast.NavigableExpr) bool {
if e.Kind() != ast.CallKind {
return false
}
call := e.AsCall()
if call.FunctionName() != operators.In {
return false
}
aggregateVal := call.Args()[1]
if aggregateVal.Kind() != ast.ListKind {
return false
}
listVal := aggregateVal.AsList()
for _, elem := range listVal.Elements() {
if r, found := a.ReferenceMap()[elem.ID()]; found {
if r.Value != nil {
continue
}
}
if elem.Kind() != ast.LiteralKind {
return false
}
lit := elem.AsLiteral()
if !(lit.Type() == cel.StringType || lit.Type() == cel.IntType ||
lit.Type() == cel.UintType || lit.Type() == cel.BoolType) {
return false
}
}
return true
}
}

func setsIntersects(listA, listB ref.Val) ref.Val {
lA := listA.(traits.Lister)
lB := listB.(traits.Lister)
Expand Down
170 changes: 169 additions & 1 deletion ext/sets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ import (

"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/test/proto3pb"
)

func TestSets(t *testing.T) {
Expand Down Expand Up @@ -328,9 +331,174 @@ func TestSets(t *testing.T) {
}
}

func TestSetsMembershipRewriter(t *testing.T) {
tests := []struct {
expr string
optimized string
opts []cel.EnvOption
in map[string]any
out ref.Val
}{
{
expr: `a in [1, 2, 3, 4]`,
optimized: `a in {1: true, 2: true, 3: true, 4: true}`,
opts: []cel.EnvOption{
cel.Variable("a", cel.IntType),
},
in: map[string]any{
"a": 3,
},
out: types.True,
},
{
expr: `a in ['1', '2', '3', 4]`,
optimized: `a in {"1": true, "2": true, "3": true, 4: true}`,
opts: []cel.EnvOption{
cel.Variable("a", cel.IntType),
},
in: map[string]any{
"a": 3,
},
out: types.False,
},
{
expr: `a in [1u, '2', '3', 4]`,
optimized: `a in {1u: true, "2": true, "3": true, 4: true}`,
opts: []cel.EnvOption{
cel.Variable("a", cel.IntType),
},
in: map[string]any{
"a": 4,
},
out: types.True,
},
{
expr: `a in [1u, 2.0, '3', 4]`,
optimized: `a in [1u, 2.0, "3", 4]`,
opts: []cel.EnvOption{
cel.Variable("a", cel.IntType),
},
in: map[string]any{
"a": 4,
},
out: types.True,
},
{
expr: `a in [b, 32]`,
optimized: `a in [b, 32]`,
opts: []cel.EnvOption{
cel.Variable("a", cel.IntType),
cel.Variable("b", cel.IntType),
},
in: map[string]any{
"a": 4,
"b": 4,
},
out: types.True,
},
{
expr: `a in {b: c}`,
optimized: `a in {b: c}`,
opts: []cel.EnvOption{
cel.Variable("a", cel.IntType),
cel.Variable("b", cel.IntType),
cel.Variable("c", cel.IntType),
},
in: map[string]any{
"a": 4,
"b": 42,
"c": 123,
},
out: types.False,
},
{
expr: `a in {3: true}`,
optimized: `a in {3: true}`,
opts: []cel.EnvOption{
cel.Variable("a", cel.IntType),
},
in: map[string]any{
"a": 4,
},
out: types.False,
},
{
expr: `a in ["hello", "world"].map(i, i in ["goodbye", "world"], i + i)`,
optimized: `a in ["hello", "world"].map(i, i in {"goodbye": true, "world": true}, i + i)`,
opts: []cel.EnvOption{
cel.Variable("a", cel.StringType),
},
in: map[string]any{
"a": "worldworld",
},
out: types.True,
},
{
expr: `a in [test.GlobalEnum.GOO, test.GlobalEnum.GAR, test.GlobalEnum.GAZ]`,
optimized: `a in {0: true, 1: true, 2: true}`,
opts: []cel.EnvOption{
cel.Container("google.expr.proto3"),
cel.Variable("a", cel.IntType),
cel.Types(&proto3pb.TestAllTypes{}),
},
in: map[string]any{
"a": proto3pb.GlobalEnum_GAZ,
},
out: types.True,
},
}
for _, tst := range tests {
tc := tst
t.Run(tc.expr, func(t *testing.T) {
env := testSetsEnv(t, tc.opts...)
var asts []*cel.Ast
a, iss := env.Compile(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Compile(%v) failed: %v", tc.expr, iss.Err())
}
asts = append(asts, a)
setsOpt, err := NewSetMembershipOptimizer()
if err != nil {
t.Fatalf("NewSetMembershipOptimizer() failed with error: %v", err)
}
opt := cel.NewStaticOptimizer(setsOpt)
optAST, iss := opt.Optimize(env, a)
if iss.Err() != nil {
t.Fatalf("opt.Optimize() failed: %v", iss.Err())
}
optExpr, err := cel.AstToString(optAST)
if err != nil {
t.Fatalf("cel.AstToString() failed :%v", err)
}
if tc.optimized != optExpr {
t.Errorf("got %v, wanted optimized expr %v", optExpr, tc.optimized)
}
asts = append(asts, optAST)

for _, ast := range asts {
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
in := tc.in
if in == nil {
in = map[string]any{}
}
out, _, err := prg.Eval(in)
if err != nil {
t.Fatalf("prg.Eval() failed: %v", err)
}
if out != tc.out {
t.Errorf("prg.Eval() got %v, wanted %v for expr: %s", out, tc.out, tc.expr)
}
}
})
}
}

func testSetsEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env {
t.Helper()
baseOpts := []cel.EnvOption{Sets()}
baseOpts := []cel.EnvOption{cel.EnableMacroCallTracking(), Sets()}
env, err := cel.NewEnv(append(baseOpts, opts...)...)
if err != nil {
t.Fatalf("cel.NewEnv(Sets()) failed: %v", err)
Expand Down

0 comments on commit 52e5dcc

Please sign in to comment.