diff --git a/cel/env.go b/cel/env.go index a8650c4e..ab736b77 100644 --- a/cel/env.go +++ b/cel/env.go @@ -459,6 +459,12 @@ func (e *Env) ParseSource(src Source) (*Ast, *Issues) { // Program generates an evaluable instance of the Ast within the environment (Env). func (e *Env) Program(ast *Ast, opts ...ProgramOption) (Program, error) { + return e.PlanProgram(ast.NativeRep(), opts...) +} + +// PlanProgram generates an evaluable instance of the AST in the go-native representation within +// the environment (Env). +func (e *Env) PlanProgram(a *celast.AST, opts ...ProgramOption) (Program, error) { optSet := e.progOpts if len(opts) != 0 { mergedOpts := []ProgramOption{} @@ -466,7 +472,7 @@ func (e *Env) Program(ast *Ast, opts ...ProgramOption) (Program, error) { mergedOpts = append(mergedOpts, opts...) optSet = mergedOpts } - return newProgram(e, ast, optSet) + return newProgram(e, a, optSet) } // CELTypeAdapter returns the `types.Adapter` configured for the environment. diff --git a/cel/optimizer.go b/cel/optimizer.go index ec02773a..c149abb7 100644 --- a/cel/optimizer.go +++ b/cel/optimizer.go @@ -211,6 +211,16 @@ type OptimizerContext struct { *Issues } +// ExtendEnv augments the context's environment with the additional options. +func (opt *OptimizerContext) ExtendEnv(opts ...EnvOption) error { + e, err := opt.Env.Extend(opts...) + if err != nil { + return err + } + opt.Env = e + return nil +} + // ASTOptimizer applies an optimization over an AST and returns the optimized result. type ASTOptimizer interface { // Optimize optimizes a type-checked AST within an Environment and accumulates any issues. 
diff --git a/cel/program.go b/cel/program.go index 4d34305c..6f477afc 100644 --- a/cel/program.go +++ b/cel/program.go @@ -19,6 +19,7 @@ import ( "fmt" "sync" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/interpreter" @@ -151,7 +152,7 @@ func (p *prog) clone() *prog { // ProgramOption values. // // If the program cannot be configured the prog will be nil, with a non-nil error response. -func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) { +func newProgram(e *Env, a *ast.AST, opts []ProgramOption) (Program, error) { // Build the dispatcher, interpreter, and default program value. disp := interpreter.NewDispatcher() @@ -255,9 +256,9 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) { return p.initInterpretable(a, decorators) } -func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) { +func (p *prog) initInterpretable(a *ast.AST, decs []interpreter.InterpretableDecorator) (*prog, error) { // When the AST has been exprAST it contains metadata that can be used to speed up program execution. - interpretable, err := p.interpreter.NewInterpretable(a.impl, decs...) + interpretable, err := p.interpreter.NewInterpretable(a, decs...) 
if err != nil { return nil, err } diff --git a/conformance/BUILD.bazel b/conformance/BUILD.bazel index 23bc5800..a9630d4f 100644 --- a/conformance/BUILD.bazel +++ b/conformance/BUILD.bazel @@ -32,6 +32,7 @@ _ALL_TESTS = [ "@dev_cel_expr//tests/simple:testdata/timestamps.textproto", "@dev_cel_expr//tests/simple:testdata/unknowns.textproto", "@dev_cel_expr//tests/simple:testdata/wrappers.textproto", + "@dev_cel_expr//tests/simple:testdata/block_ext.textproto", ] _TESTS_TO_SKIP = [ @@ -68,6 +69,7 @@ go_test( deps = [ "//cel:go_default_library", "//common:go_default_library", + "//common/ast:go_default_library", "//common/types:go_default_library", "//common/types/ref:go_default_library", "//ext:go_default_library", diff --git a/conformance/conformance_test.go b/conformance/conformance_test.go index 8653d010..8a1384f5 100644 --- a/conformance/conformance_test.go +++ b/conformance/conformance_test.go @@ -13,6 +13,7 @@ import ( "github.com/google/cel-go/cel" "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/ext" @@ -89,6 +90,7 @@ func init() { ext.Math(), ext.Protos(), ext.Strings(), + cel.Lib(celBlockLib{}), } var err error @@ -279,3 +281,89 @@ func TestConformance(t *testing.T) { } } } + +type celBlockLib struct{} + +func (celBlockLib) LibraryName() string { + return "cel.lib.ext.cel.block.conformance" +} + +func (celBlockLib) CompileOptions() []cel.EnvOption { + // Simulate indexed arguments which would normally have strong types associated + // with the values as part of a static optimization pass + maxIndices := 30 + indexOpts := make([]cel.EnvOption, maxIndices) + for i := 0; i < maxIndices; i++ { + indexOpts[i] = cel.Variable(fmt.Sprintf("@index%d", i), cel.DynType) + } + return append([]cel.EnvOption{ + cel.Macros( + // cel.block([args], expr) + cel.ReceiverMacro("block", 2, celBlock), + // cel.index(int) + 
cel.ReceiverMacro("index", 1, celIndex), + // cel.iterVar(int, int) + cel.ReceiverMacro("iterVar", 2, celCompreVar("cel.iterVar", "@it")), + // cel.accuVar(int, int) + cel.ReceiverMacro("accuVar", 2, celCompreVar("cel.accuVar", "@ac")), + ), + }, indexOpts...) +} + +func (celBlockLib) ProgramOptions() []cel.ProgramOption { + return []cel.ProgramOption{} +} + +func celBlock(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { + if !isCELNamespace(target) { + return nil, nil + } + bindings := args[0] + if bindings.Kind() != ast.ListKind { + return bindings, mef.NewError(bindings.ID(), "cel.block requires the first arg to be a list literal") + } + return mef.NewCall("cel.@block", args...), nil +} + +func celIndex(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { + if !isCELNamespace(target) { + return nil, nil + } + index := args[0] + if !isNonNegativeInt(index) { + return index, mef.NewError(index.ID(), "cel.index requires a single non-negative int constant arg") + } + indexVal := index.AsLiteral().(types.Int) + return mef.NewIdent(fmt.Sprintf("@index%d", indexVal)), nil +} + +func celCompreVar(funcName, varPrefix string) cel.MacroFactory { + return func(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { + if !isCELNamespace(target) { + return nil, nil + } + depth := args[0] + if !isNonNegativeInt(depth) { + return depth, mef.NewError(depth.ID(), fmt.Sprintf("%s requires two non-negative int constant args", funcName)) + } + unique := args[1] + if !isNonNegativeInt(unique) { + return unique, mef.NewError(unique.ID(), fmt.Sprintf("%s requires two non-negative int constant args", funcName)) + } + depthVal := depth.AsLiteral().(types.Int) + uniqueVal := unique.AsLiteral().(types.Int) + return mef.NewIdent(fmt.Sprintf("%s:%d:%d", varPrefix, depthVal, uniqueVal)), nil + } +} + +func isCELNamespace(target ast.Expr) bool { + return target.Kind() == ast.IdentKind && 
target.AsIdent() == "cel" +} + +func isNonNegativeInt(expr ast.Expr) bool { + if expr.Kind() != ast.LiteralKind { + return false + } + val := expr.AsLiteral() + return val.Type() == cel.IntType && val.(types.Int) >= 0 +} diff --git a/ext/bindings.go b/ext/bindings.go index 2c6cc627..50cf4fb3 100644 --- a/ext/bindings.go +++ b/ext/bindings.go @@ -15,9 +15,19 @@ package ext import ( + "errors" + "fmt" + "math" + "strconv" + "strings" + "sync" + "github.com/google/cel-go/cel" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" + "github.com/google/cel-go/interpreter" ) // Bindings returns a cel.EnvOption to configure support for local variable @@ -41,35 +51,120 @@ import ( // [d, e, f].exists(elem, elem in valid_values)) // // Local bindings are not guaranteed to be evaluated before use. -func Bindings() cel.EnvOption { - return cel.Lib(celBindings{}) +func Bindings(options ...BindingsOption) cel.EnvOption { + b := &celBindings{version: math.MaxUint32} + for _, o := range options { + b = o(b) + } + return cel.Lib(b) } const ( celNamespace = "cel" bindMacro = "bind" + blockFunc = "@block" unusedIterVar = "#unused" ) -type celBindings struct{} +// BindingsOption declares a functional operator for configuring the Bindings library behavior. +type BindingsOption func(*celBindings) *celBindings + +// BindingsVersion sets the version of the bindings library to an explicit version. 
+func BindingsVersion(version uint32) BindingsOption { + return func(lib *celBindings) *celBindings { + lib.version = version + return lib + } +} + +type celBindings struct { + version uint32 +} -func (celBindings) LibraryName() string { +func (*celBindings) LibraryName() string { return "cel.lib.ext.cel.bindings" } -func (celBindings) CompileOptions() []cel.EnvOption { - return []cel.EnvOption{ +func (lib *celBindings) CompileOptions() []cel.EnvOption { + opts := []cel.EnvOption{ cel.Macros( // cel.bind(var, <init>, <expr>) cel.ReceiverMacro(bindMacro, 3, celBind), ), } + if lib.version >= 1 { + // The cel.@block signature takes a list of subexpressions and a typed expression which is + // used as the output type. + paramType := cel.TypeParamType("T") + opts = append(opts, + cel.Function("cel.@block", + cel.Overload("cel_block_list", + []*cel.Type{cel.ListType(cel.DynType), paramType}, paramType)), + ) + opts = append(opts, cel.ASTValidators(blockValidationExemption{})) + } + return opts } -func (celBindings) ProgramOptions() []cel.ProgramOption { +func (lib *celBindings) ProgramOptions() []cel.ProgramOption { + if lib.version >= 1 { + celBlockPlan := func(i interpreter.Interpretable) (interpreter.Interpretable, error) { + call, ok := i.(interpreter.InterpretableCall) + if !ok { + return i, nil + } + switch call.Function() { + case "cel.@block": + args := call.Args() + if len(args) != 2 { + return nil, fmt.Errorf("cel.@block expects two arguments, but got %d", len(args)) + } + expr := args[1] + // Non-empty block + if block, ok := args[0].(interpreter.InterpretableConstructor); ok { + slotExprs := block.InitVals() + return newDynamicBlock(slotExprs, expr), nil + } + // Constant valued block which can happen during runtime optimization. 
+ if cons, ok := args[0].(interpreter.InterpretableConst); ok { + if cons.Value().Type() == types.ListType { + l := cons.Value().(traits.Lister) + if l.Size().Equal(types.IntZero) == types.True { + return args[1], nil + } + return newConstantBlock(l, expr), nil + } + } + return nil, errors.New("cel.@block expects a list constructor as the first argument") + default: + return i, nil + } + } + return []cel.ProgramOption{cel.CustomDecorator(celBlockPlan)} + } return []cel.ProgramOption{} } +type blockValidationExemption struct{} + +// Name returns the name of the validator. +func (blockValidationExemption) Name() string { + return "cel.lib.ext.validate.functions.cel.block" +} + +// Configure implements the ASTValidatorConfigurer interface and augments the list of functions to skip +// during homogeneous aggregate literal type-checks. +func (blockValidationExemption) Configure(config cel.MutableValidatorConfig) error { + functions := config.GetOrDefault(cel.HomogeneousAggregateLiteralExemptFunctions, []string{}).([]string) + functions = append(functions, "cel.@block") + return config.Set(cel.HomogeneousAggregateLiteralExemptFunctions, functions) +} + +// Validate is a no-op as the intent is to simply disable strong type-checks for list literals +// when they occur within cel.@block calls as the arg types have already been validated. 
+func (blockValidationExemption) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) { +} + func celBind(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { if !macroTargetMatchesNamespace(celNamespace, target) { return nil, nil @@ -94,3 +189,148 @@ func celBind(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Ex resultExpr, ), nil } + +func newDynamicBlock(slotExprs []interpreter.Interpretable, expr interpreter.Interpretable) interpreter.Interpretable { + bs := &dynamicBlock{ + slotExprs: slotExprs, + expr: expr, + } + bs.slotActivationPool = &sync.Pool{ + New: func() any { + slotCount := len(slotExprs) + sa := &dynamicSlotActivation{ + slotExprs: slotExprs, + slotCount: slotCount, + slotVals: make([]*slotVal, slotCount), + } + for i := 0; i < slotCount; i++ { + sa.slotVals[i] = &slotVal{} + } + return sa + }, + } + return bs +} + +type dynamicBlock struct { + slotExprs []interpreter.Interpretable + expr interpreter.Interpretable + slotActivationPool *sync.Pool +} + +// ID implements the Interpretable interface method. +func (b *dynamicBlock) ID() int64 { + return b.expr.ID() +} + +// Eval implements the Interpretable interface method. 
+func (b *dynamicBlock) Eval(activation interpreter.Activation) ref.Val { + sa := b.slotActivationPool.Get().(*dynamicSlotActivation) + sa.Activation = activation + defer b.clearSlots(sa) + return b.expr.Eval(sa) +} + +func (b *dynamicBlock) clearSlots(sa *dynamicSlotActivation) { + sa.reset() + b.slotActivationPool.Put(sa) +} + +type slotVal struct { + value *ref.Val + visited bool +} + +type dynamicSlotActivation struct { + interpreter.Activation + slotExprs []interpreter.Interpretable + slotCount int + slotVals []*slotVal +} + +// ResolveName implements the Activation interface method but handles variables prefixed with `@index` +// as special variables which exist within the slot-based memory of the cel.@block() where each slot +// refers to an expression which must be computed only once. +func (sa *dynamicSlotActivation) ResolveName(name string) (any, bool) { + if idx, found := matchSlot(name, sa.slotCount); found { + v := sa.slotVals[idx] + if v.visited { + // Return not found if the index expression refers to itself + if v.value == nil { + return nil, false + } + return *v.value, true + } + v.visited = true + val := sa.slotExprs[idx].Eval(sa) + v.value = &val + return val, true + } + return sa.Activation.ResolveName(name) +} + +func (sa *dynamicSlotActivation) reset() { + sa.Activation = nil + for _, sv := range sa.slotVals { + sv.visited = false + sv.value = nil + } +} + +func newConstantBlock(slots traits.Lister, expr interpreter.Interpretable) interpreter.Interpretable { + count := slots.Size().(types.Int) + return &constantBlock{slots: slots, slotCount: int(count), expr: expr} +} + +type constantBlock struct { + slots traits.Lister + slotCount int + expr interpreter.Interpretable +} + +// ID implements the interpreter.Interpretable interface method. 
+func (b *constantBlock) ID() int64 { + return b.expr.ID() +} + +// Eval implements the interpreter.Interpretable interface method, and will proxy @index prefixed variable +// lookups into a set of constant slots determined from the plan step. +func (b *constantBlock) Eval(activation interpreter.Activation) ref.Val { + vars := constantSlotActivation{Activation: activation, slots: b.slots, slotCount: b.slotCount} + return b.expr.Eval(vars) +} + +type constantSlotActivation struct { + interpreter.Activation + slots traits.Lister + slotCount int +} + +// ResolveName implements Activation interface method and proxies @index prefixed lookups into the slot +// activation associated with the block scope. +func (sa constantSlotActivation) ResolveName(name string) (any, bool) { + if idx, found := matchSlot(name, sa.slotCount); found { + return sa.slots.Get(types.Int(idx)), true + } + return sa.Activation.ResolveName(name) +} + +func matchSlot(name string, slotCount int) (int, bool) { + if idx, found := strings.CutPrefix(name, indexPrefix); found { + idx, err := strconv.Atoi(idx) + // Return not found if the index is not numeric + if err != nil { + return -1, false + } + // Return not found if the index is not a valid slot + if idx < 0 || idx >= slotCount { + return -1, false + } + return idx, true + } + return -1, false +} + +var ( + indexPrefix = "@index" +) diff --git a/ext/bindings_test.go b/ext/bindings_test.go index 850e7a4a..bd6ecf7b 100644 --- a/ext/bindings_test.go +++ b/ext/bindings_test.go @@ -20,22 +20,26 @@ import ( "testing" "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" ) var bindingTests = []struct { expr string parseOnly bool }{ - {expr: `cel.bind(a, 'hell' + 'o' + '!', [a, a, a].join(', ')) == + {expr: `cel.bind(a, 'hell' + 'o' + '!', [a, a, a].join(', ')) == ['hell' + 'o' + '!', 'hell' + 'o' + '!', 
'hell' + 'o' + '!'].join(', ')`}, // Variable shadowing - {expr: `cel.bind(a, - cel.bind(a, 'world', a + '!'), + {expr: `cel.bind(a, + cel.bind(a, 'world', a + '!'), 'hello ' + a) == 'hello ' + 'world' + '!'`}, } func TestBindings(t *testing.T) { - env, err := cel.NewEnv(Bindings(), Strings()) + env, err := cel.NewEnv(Bindings(BindingsVersion(0)), Strings()) if err != nil { t.Fatalf("cel.NewEnv(Bindings(), Strings()) failed: %v", err) } @@ -125,3 +129,258 @@ func BenchmarkBindings(b *testing.B) { }) } } + +func TestBlockEval(t *testing.T) { + fac := ast.NewExprFactory() + tests := []struct { + name string + expr ast.Expr + opts []cel.EnvOption + in map[string]any + out ref.Val + }{ + { + name: "chained block", + expr: fac.NewCall( + 1, "cel.@block", + fac.NewList(2, []ast.Expr{ + fac.NewIdent(3, "x"), + fac.NewIdent(4, "@index0"), + fac.NewIdent(5, "@index1"), + }, []int32{}), + fac.NewCall(9, operators.Add, + fac.NewCall(6, operators.Add, + fac.NewIdent(7, "@index2"), + fac.NewIdent(8, "@index1")), + fac.NewIdent(10, "@index0"), + ), + ), + opts: []cel.EnvOption{ + cel.Variable("x", cel.StringType), + }, + in: map[string]any{"x": "hello"}, + out: types.String("hellohellohello"), + }, + { + name: "empty block", + expr: fac.NewCall( + 1, "cel.@block", + fac.NewList(2, []ast.Expr{}, []int32{}), + fac.NewCall(3, operators.LogicalNot, fac.NewLiteral(4, types.False)), + ), + in: map[string]any{}, + out: types.True, + }, + { + name: "mixed block constant values", + expr: fac.NewCall( + 1, "cel.@block", + fac.NewList(2, []ast.Expr{ + fac.NewLiteral(3, types.String("hello")), + fac.NewLiteral(4, types.Int(5)), + }, []int32{}), + fac.NewCall(5, operators.Equals, + fac.NewCall(6, "size", + fac.NewIdent(7, "@index0")), + fac.NewIdent(8, "@index1"), + ), + ), + opts: []cel.EnvOption{ + cel.ExtendedValidations(), + }, + in: map[string]any{}, + out: types.True, + }, + { + name: "mixed block dynamic values", + expr: fac.NewCall( + 1, "cel.@block", + fac.NewList(2, []ast.Expr{ + 
fac.NewIdent(3, "x"), + fac.NewLiteral(4, types.Int(5)), + }, []int32{}), + fac.NewCall(5, operators.Equals, + fac.NewCall(6, "size", + fac.NewIdent(7, "@index0")), + fac.NewIdent(8, "@index1"), + ), + ), + opts: []cel.EnvOption{ + cel.Variable("x", cel.StringType), + cel.ExtendedValidations(), + }, + in: map[string]any{"x": "goodbye"}, + out: types.False, + }, + { + name: "mixed block constant values dyn var", + expr: fac.NewCall( + 1, "cel.@block", + fac.NewList(2, []ast.Expr{ + fac.NewLiteral(3, types.String("hello")), + }, []int32{}), + fac.NewCall(4, operators.Equals, + fac.NewCall(5, "size", + fac.NewIdent(6, "@index0")), + fac.NewIdent(7, "y"), + ), + ), + opts: []cel.EnvOption{ + cel.Variable("y", cel.IntType), + cel.ExtendedValidations(), + }, + in: map[string]any{ + "y": 5, + }, + out: types.True, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + blockAST := ast.NewAST(tc.expr, nil) + opts := append([]cel.EnvOption{Bindings()}, tc.opts...) + env, err := cel.NewEnv(opts...) 
+ if err != nil { + t.Fatalf("cel.NewEnv(Bindings()) failed: %v", err) + } + prg, err := env.PlanProgram(blockAST, cel.EvalOptions(cel.OptOptimize)) + if err != nil { + t.Fatalf("PlanProgram() failed: %v", err) + } + out, _, err := prg.Eval(tc.in) + if err != nil { + t.Fatalf("prg.Eval() failed: %v", err) + } + if out.Equal(tc.out) != types.True { + t.Errorf("got %v, wanted %v", out, tc.out) + } + }) + } +} + +func TestBlockEval_BadPlan(t *testing.T) { + fac := ast.NewExprFactory() + blockExpr := fac.NewCall( + 1, "cel.@block", + fac.NewList(2, []ast.Expr{ + fac.NewIdent(3, "x"), + fac.NewIdent(4, "@index0"), + }, []int32{}), + fac.NewCall(6, operators.Add, + fac.NewIdent(7, "@index1"), + fac.NewIdent(8, "@index0")), + fac.NewIdent(9, "x"), + ) + blockAST := ast.NewAST(blockExpr, nil) + env, err := cel.NewEnv( + Bindings(BindingsVersion(1)), + cel.Variable("x", cel.StringType), + ) + if err != nil { + t.Fatalf("cel.NewEnv(Bindings()) failed: %v", err) + } + _, err = env.PlanProgram(blockAST) + if err == nil { + t.Fatal("PlanProgram() succeeded, expected error") + } +} + +func TestBlockEval_BadBlock(t *testing.T) { + fac := ast.NewExprFactory() + blockExpr := fac.NewCall( + 1, "cel.@block", + fac.NewCall(2, operators.Add, + fac.NewIdent(3, "@index1"), + fac.NewIdent(4, "@index0")), + fac.NewIdent(5, "x"), + ) + blockAST := ast.NewAST(blockExpr, nil) + env, err := cel.NewEnv( + Bindings(BindingsVersion(1)), + cel.Variable("x", cel.StringType), + ) + if err != nil { + t.Fatalf("cel.NewEnv(Bindings()) failed: %v", err) + } + _, err = env.PlanProgram(blockAST) + if err == nil { + t.Fatal("PlanProgram() succeeded, expected error") + } +} + +func TestBlockEval_RuntimeErrors(t *testing.T) { + fac := ast.NewExprFactory() + tests := []struct { + name string + expr ast.Expr + }{ + { + name: "bad index", + expr: fac.NewCall( + 1, "cel.@block", + fac.NewList(2, []ast.Expr{ + fac.NewIdent(3, "x"), + fac.NewIdent(4, "@indexNext"), + }, []int32{}), + fac.NewCall(6, operators.Add, 
+ fac.NewIdent(7, "@indexNext"), + fac.NewIdent(8, "@index0")), + ), + }, + { + name: "infinite recursion", + expr: fac.NewCall( + 1, "cel.@block", + fac.NewList(2, []ast.Expr{ + fac.NewIdent(3, "@index0"), + fac.NewIdent(4, "@index0"), + }, []int32{}), + fac.NewIdent(10, "@index0"), + ), + }, + { + name: "negative index", + expr: fac.NewCall( + 1, "cel.@block", + fac.NewList(2, []ast.Expr{ + fac.NewIdent(3, "@index-1"), + fac.NewIdent(4, "@index0"), + }, []int32{}), + fac.NewIdent(10, "@index0"), + ), + }, + { + name: "out of range index", + expr: fac.NewCall( + 1, "cel.@block", + fac.NewList(2, []ast.Expr{ + fac.NewIdent(3, "@index100"), + fac.NewIdent(4, "@index0"), + }, []int32{}), + fac.NewIdent(10, "@index0"), + ), + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + blockAST := ast.NewAST(tc.expr, nil) + env, err := cel.NewEnv( + Bindings(BindingsVersion(1)), + cel.Variable("x", cel.StringType), + ) + if err != nil { + t.Fatalf("cel.NewEnv(Bindings()) failed: %v", err) + } + prg, err := env.PlanProgram(blockAST) + if err != nil { + t.Fatalf("PlanProgram() failed: %v", err) + } + _, _, err = prg.Eval(map[string]any{"x": "hello"}) + if err == nil || !strings.Contains(err.Error(), "no such attribute") { + t.Fatalf("prg.Eval() got %v, expected no such attribute error", err) + } + }) + } +} diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index 61167c45..ebc432e9 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -97,7 +97,7 @@ type InterpretableCall interface { Args() []Interpretable } -// InterpretableConstructor interface for inspecting Interpretable instructions that initialize a list, map +// InterpretableConstructor interface for inspecting Interpretable instructions that initialize a list, map // or struct. 
type InterpretableConstructor interface { Interpretable diff --git a/policy/compiler_test.go b/policy/compiler_test.go index 923b9933..5865f524 100644 --- a/policy/compiler_test.go +++ b/policy/compiler_test.go @@ -25,6 +25,7 @@ import ( "github.com/google/cel-go/cel" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/ext" "github.com/google/cel-go/interpreter" ) @@ -159,7 +160,8 @@ func compile(t testing.TB, name string, parseOpts []ParserOption, envOpts []cel. cel.DefaultUTCTimeZone(true), cel.OptionalTypes(), cel.EnableMacroCallTracking(), - cel.ExtendedValidations()) + cel.ExtendedValidations(), + ext.Bindings()) if err != nil { t.Fatalf("cel.NewEnv() failed: %v", err) } diff --git a/policy/composer.go b/policy/composer.go index 022f6a7e..84f3f6a5 100644 --- a/policy/composer.go +++ b/policy/composer.go @@ -15,6 +15,9 @@ package policy import ( + "fmt" + "strings" + "github.com/google/cel-go/cel" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" @@ -39,12 +42,23 @@ type RuleComposer struct { // Compose stitches together a set of expressions within a CompiledRule into a single CEL ast. 
func (c *RuleComposer) Compose(r *CompiledRule) (*cel.Ast, *cel.Issues) { ruleRoot, _ := c.env.Compile("true") - opt := cel.NewStaticOptimizer(&ruleComposerImpl{rule: r}) + opt := cel.NewStaticOptimizer(&ruleComposerImpl{rule: r, varIndices: []varIndex{}}) return opt.Optimize(c.env, ruleRoot) } +type varIndex struct { + index int + indexVar string + localVar string + expr ast.Expr + cv *CompiledVariable +} + type ruleComposerImpl struct { - rule *CompiledRule + rule *CompiledRule + nextVarIndex int + varIndices []varIndex + maxNestedExpressionLimit int } @@ -54,14 +68,34 @@ func (opt *ruleComposerImpl) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *as // The input to optimize is a dummy expression which is completely replaced according // to the configuration of the rule composition graph. ruleExpr := opt.optimizeRule(ctx, opt.rule) - return ctx.NewAST(ruleExpr) + allVars := opt.sortedVariables() + // If there were no variables, return the expression. + if len(allVars) == 0 { + return ctx.NewAST(ruleExpr) + } + + // Otherwise populate the block. + varExprs := make([]ast.Expr, len(allVars)) + for i, vi := range allVars { + varExprs[i] = vi.expr + err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.cv.Declaration().Type())) + if err != nil { + ctx.ReportErrorAtID(ruleExpr.ID(), err.Error()) + } + } + blockExpr := ctx.NewCall("cel.@block", ctx.NewList(varExprs, []int32{}), ruleExpr) + return ctx.NewAST(blockExpr) } func (opt *ruleComposerImpl) optimizeRule(ctx *cel.OptimizerContext, r *CompiledRule) ast.Expr { matchExpr := ctx.NewCall("optional.none") matches := r.Matches() matchCount := len(matches) + // Visitor to rewrite variables-prefixed identifiers with index names. vars := r.Variables() + for _, v := range vars { + opt.registerVariable(ctx, v) + } optionalResult := true // Build the rule subgraph. 
@@ -121,17 +155,43 @@ func (opt *ruleComposerImpl) optimizeRule(ctx *cel.OptimizerContext, r *Compiled ) } - // Bind variables in reverse order to declaration on top of rule-subgraph. - for i := len(vars) - 1; i >= 0; i-- { - v := vars[i] - varAST := ctx.CopyASTAndMetadata(v.Expr().NativeRep()) - // Build up the bindings in reverse order, starting from root, all the way up to the outermost - // binding: - // currExpr = cel.bind(outerVar, outerExpr, currExpr) - varName := v.Declaration().Name() - inlined, bindMacro := ctx.NewBindMacro(matchExpr.ID(), varName, varAST, matchExpr) - ctx.UpdateExpr(matchExpr, inlined) - ctx.SetMacroCall(matchExpr.ID(), bindMacro) - } + identVisitor := opt.rewriteVariableName(ctx) + ast.PostOrderVisit(matchExpr, identVisitor) + return matchExpr } + +func (opt *ruleComposerImpl) rewriteVariableName(ctx *cel.OptimizerContext) ast.Visitor { + return ast.NewExprVisitor(func(expr ast.Expr) { + if expr.Kind() != ast.IdentKind || !strings.HasPrefix(expr.AsIdent(), "variables.") { + return + } + varName := expr.AsIdent() + for i := len(opt.varIndices) - 1; i >= 0; i-- { + v := opt.varIndices[i] + if v.localVar == varName { + ctx.UpdateExpr(expr, ctx.NewIdent(v.indexVar)) + return + } + } + }) +} + +func (opt *ruleComposerImpl) registerVariable(ctx *cel.OptimizerContext, v *CompiledVariable) { + varName := fmt.Sprintf("variables.%s", v.Name()) + indexVar := fmt.Sprintf("@index%d", opt.nextVarIndex) + varExpr := ctx.CopyASTAndMetadata(v.Expr().NativeRep()) + ast.PostOrderVisit(varExpr, opt.rewriteVariableName(ctx)) + vi := varIndex{ + index: opt.nextVarIndex, + indexVar: indexVar, + localVar: varName, + expr: varExpr, + cv: v} + opt.varIndices = append(opt.varIndices, vi) + opt.nextVarIndex++ +} + +func (opt *ruleComposerImpl) sortedVariables() []varIndex { + return opt.varIndices +} diff --git a/policy/helper_test.go b/policy/helper_test.go index 00dbd53d..d681fb9a 100644 --- a/policy/helper_test.go +++ b/policy/helper_test.go @@ -43,48 +43,48 @@ 
var ( return p, nil }}, expr: ` - cel.bind(variables.env, resource.labels.?environment.orValue("prod"), - cel.bind(variables.break_glass, resource.labels.?break_glass.orValue("false") == "true", - !(variables.break_glass || - resource.containers.all(c, c.startsWith(variables.env + "."))) - ? optional.of("only %s containers are allowed in namespace %s".format([variables.env, resource.namespace])) - : optional.none()))`, + cel.@block([ + resource.labels.?environment.orValue("prod"), + resource.labels.?break_glass.orValue("false") == "true"], + !(@index1 || resource.containers.all(c, c.startsWith(@index0 + "."))) + ? optional.of("only %s containers are allowed in namespace %s".format([@index0, resource.namespace])) + : optional.none())`, }, { name: "nested_rule", expr: ` - cel.bind(variables.permitted_regions, ["us", "uk", "es"], - cel.bind(variables.banned_regions, {"us": false, "ru": false, "ir": false}, - (resource.origin in variables.banned_regions && - !(resource.origin in variables.permitted_regions)) - ? optional.of({"banned": true}) : optional.none()).or( - optional.of((resource.origin in variables.permitted_regions) - ? {"banned": false} : {"banned": true})))`, + cel.@block([ + ["us", "uk", "es"], + {"us": false, "ru": false, "ir": false}], + ((resource.origin in @index1 && !(resource.origin in @index0)) + ? optional.of({"banned": true}) : optional.none()).or( + optional.of((resource.origin in @index0) + ? {"banned": false} : {"banned": true})))`, }, { name: "nested_rule2", expr: ` - cel.bind(variables.permitted_regions, ["us", "uk", "es"], + cel.@block([ + ["us", "uk", "es"], + {"us": false, "ru": false, "ir": false}], resource.?user.orValue("").startsWith("bad") - ? cel.bind(variables.banned_regions, {"us": false, "ru": false, "ir": false}, - (resource.origin in variables.banned_regions && - !(resource.origin in variables.permitted_regions)) - ? 
{"banned": "restricted_region"} : {"banned": "bad_actor"}) - : (!(resource.origin in variables.permitted_regions) + ? ((resource.origin in @index1 && !(resource.origin in @index0)) + ? {"banned": "restricted_region"} + : {"banned": "bad_actor"}) + : (!(resource.origin in @index0) ? {"banned": "unconfigured_region"} : {}))`, }, { name: "nested_rule3", expr: ` - cel.bind(variables.permitted_regions, ["us", "uk", "es"], + cel.@block([ + ["us", "uk", "es"], + {"us": false, "ru": false, "ir": false}], resource.?user.orValue("").startsWith("bad") - ? optional.of( - cel.bind(variables.banned_regions, {"us": false, "ru": false, "ir": false}, - (resource.origin in variables.banned_regions && - !(resource.origin in variables.permitted_regions)) - ? {"banned": "restricted_region"} : {"banned": "bad_actor"})) - : (!(resource.origin in variables.permitted_regions) - ? optional.of({"banned": "unconfigured_region"}) : optional.none()))`, + ? optional.of((resource.origin in @index1 && !(resource.origin in @index0)) + ? {"banned": "restricted_region"} : {"banned": "bad_actor"}) + : (!(resource.origin in @index0) + ? optional.of({"banned": "unconfigured_region"}) : optional.none()))`, }, { name: "context_pb", @@ -115,34 +115,27 @@ var ( { name: "required_labels", expr: ` - cel.bind(variables.want, spec.labels, - cel.bind(variables.missing, variables.want.filter(l, !(l in resource.labels)), - cel.bind(variables.invalid, - resource.labels.filter(l, l in variables.want && - variables.want[l] != resource.labels[l]), - (variables.missing.size() > 0) - ? optional.of("missing one or more required labels: %s".format([variables.missing])) - : ((variables.invalid.size() > 0) - ? optional.of("invalid values provided on one or more labels: %s".format([variables.invalid])) : optional.none()))))`, + cel.@block([ + spec.labels, + @index0.filter(l, !(l in resource.labels)), + resource.labels.filter(l, l in @index0 && @index0[l] != resource.labels[l])], + (@index1.size() > 0) + ? 
optional.of("missing one or more required labels: %s".format([@index1])) + : ((@index2.size() > 0) + ? optional.of("invalid values provided on one or more labels: %s".format([@index2])) + : optional.none()))`, }, { name: "restricted_destinations", expr: ` - cel.bind(variables.matches_origin_ip, + cel.@block([ locationCode(origin.ip) == spec.origin, - cel.bind(variables.has_nationality, has(request.auth.claims.nationality), - cel.bind(variables.matches_nationality, - variables.has_nationality && request.auth.claims.nationality == spec.origin, - cel.bind(variables.matches_dest_ip, - locationCode(destination.ip) in spec.restricted_destinations, - cel.bind(variables.matches_dest_label, - resource.labels.location in spec.restricted_destinations, - cel.bind(variables.matches_dest, - variables.matches_dest_ip || variables.matches_dest_label, - (variables.matches_nationality && variables.matches_dest) - ? true - : ((!variables.has_nationality && variables.matches_origin_ip && variables.matches_dest) - ? true : false)))))))`, + has(request.auth.claims.nationality), + @index1 && request.auth.claims.nationality == spec.origin, + locationCode(destination.ip) in spec.restricted_destinations, + resource.labels.location in spec.restricted_destinations, + @index3 || @index4], + (@index2 && @index5) ? true : ((!@index1 && @index0 && @index5) ? true : false))`, envOpts: []cel.EnvOption{ cel.Function("locationCode", cel.Overload("locationCode_string", []*cel.Type{cel.StringType}, cel.StringType, @@ -161,21 +154,21 @@ var ( { name: "limits", expr: ` - cel.bind(variables.greeting, "hello", - cel.bind(variables.farewell, "goodbye", - cel.bind(variables.person, "me", - cel.bind(variables.message_fmt, "%s, %s", + cel.@block([ + "hello", + "goodbye", + "me", + "%s, %s", + @index3.format([@index1, @index2])], (now.getHours() >= 20) - ? cel.bind(variables.message, variables.message_fmt.format([variables.farewell, variables.person]), - (now.getHours() < 21) - ? 
optional.of(variables.message + "!") - : ((now.getHours() < 22) - ? optional.of(variables.message + "!!") - : ((now.getHours() < 24) - ? optional.of(variables.message + "!!!") - : optional.none()))) - : optional.of(variables.message_fmt.format([variables.greeting, variables.person])) - ))))`, + ? ((now.getHours() < 21) + ? optional.of(@index4 + "!") + : ((now.getHours() < 22) + ? optional.of(@index4 + "!!") + : ((now.getHours() < 24) + ? optional.of(@index4 + "!!!") + : optional.none()))) + : optional.of(@index3.format([@index0, @index2])))`, }, }