diff --git a/interpreter/activation.go b/interpreter/activation.go index a8026445..1577f359 100644 --- a/interpreter/activation.go +++ b/interpreter/activation.go @@ -17,7 +17,6 @@ package interpreter import ( "errors" "fmt" - "sync" "github.com/google/cel-go/common/types/ref" ) @@ -167,35 +166,3 @@ type partActivation struct { func (a *partActivation) UnknownAttributePatterns() []*AttributePattern { return a.unknowns } - -// varActivation represents a single mutable variable binding. -// -// This activation type should only be used within folds as the fold loop controls the object -// life-cycle. -type varActivation struct { - parent Activation - name string - val ref.Val -} - -// Parent implements the Activation interface method. -func (v *varActivation) Parent() Activation { - return v.parent -} - -// ResolveName implements the Activation interface method. -func (v *varActivation) ResolveName(name string) (any, bool) { - if name == v.name { - return v.val, true - } - return v.parent.ResolveName(name) -} - -var ( - // pool of var activations to reduce allocations during folds. - varActivationPool = &sync.Pool{ - New: func() any { - return &varActivation{} - }, - } -) diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index 56123840..61167c45 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -16,6 +16,7 @@ package interpreter import ( "fmt" + "sync" "github.com/google/cel-go/common/functions" "github.com/google/cel-go/common/operators" @@ -720,24 +721,31 @@ func (o *evalObj) Eval(ctx Activation) ref.Val { return types.LabelErrNode(o.id, o.provider.NewValue(o.typeName, fieldVals)) } +// InitVals implements the InterpretableConstructor interface method. func (o *evalObj) InitVals() []Interpretable { return o.vals } +// Type implements the InterpretableConstructor interface method. func (o *evalObj) Type() ref.Type { - return types.NewObjectTypeValue(o.typeName) + return types.NewObjectType(o.typeName) } type evalFold struct { - id int64 - accuVar string - iterVar string - iterRange Interpretable - accu Interpretable - cond Interpretable - step Interpretable - result Interpretable - adapter types.Adapter + id int64 + accuVar string + iterVar string + iterVar2 string + iterRange Interpretable + accu Interpretable + cond Interpretable + step Interpretable + result Interpretable + adapter types.Adapter + + // note an exhaustive fold will ensure that all branches are evaluated + // when using mutable values, these branches will mutate the final result + // rather than make a throw-away computation. exhaustive bool interruptable bool } @@ -749,64 +757,30 @@ func (fold *evalFold) ID() int64 { // Eval implements the Interpretable interface method. func (fold *evalFold) Eval(ctx Activation) ref.Val { - foldRange := fold.iterRange.Eval(ctx) - if !foldRange.Type().HasTrait(traits.IterableType) { - return types.ValOrErr(foldRange, "got '%T', expected iterable type", foldRange) - } - // Configure the fold activation with the accumulator initial value. - accuCtx := varActivationPool.Get().(*varActivation) - accuCtx.parent = ctx - accuCtx.name = fold.accuVar - accuCtx.val = fold.accu.Eval(ctx) - // If the accumulator starts as an empty list, then the comprehension will build a list - // so create a mutable list to optimize the cost of the inner loop. - l, ok := accuCtx.val.(traits.Lister) - buildingList := false - if !fold.exhaustive && ok && l.Size() == types.IntZero { - buildingList = true - accuCtx.val = types.NewMutableList(fold.adapter) - } - iterCtx := varActivationPool.Get().(*varActivation) - iterCtx.parent = accuCtx - iterCtx.name = fold.iterVar - - interrupted := false - it := foldRange.(traits.Iterable).Iterator() - for it.HasNext() == types.True { - // Modify the iter var in the fold activation. - iterCtx.val = it.Next() + // Initialize the folder interface + f := newFolder(fold, ctx) + defer releaseFolder(f) - // Evaluate the condition, terminate the loop if false. - cond := fold.cond.Eval(iterCtx) - condBool, ok := cond.(types.Bool) - if !fold.exhaustive && ok && condBool != types.True { - break - } - // Evaluate the evaluation step into accu var. - accuCtx.val = fold.step.Eval(iterCtx) - if fold.interruptable { - if stop, found := ctx.ResolveName("#interrupted"); found && stop == true { - interrupted = true - break - } + foldRange := fold.iterRange.Eval(ctx) + if fold.iterVar2 != "" { + var foldable traits.Foldable + switch r := foldRange.(type) { + case traits.Mapper: + foldable = types.ToFoldableMap(r) + case traits.Lister: + foldable = types.ToFoldableList(r) + default: + return types.NewErrWithNodeID(fold.ID(), "unsupported comprehension range type: %T", foldRange) } - } - varActivationPool.Put(iterCtx) - if interrupted { - varActivationPool.Put(accuCtx) - return types.NewErr("operation interrupted") + foldable.Fold(f) + return f.evalResult() } - // Compute the result. - res := fold.result.Eval(accuCtx) - varActivationPool.Put(accuCtx) - // Convert a mutable list to an immutable one, if the comprehension has generated a list as a result. - if !types.IsUnknownOrError(res) && buildingList { - if _, ok := res.(traits.MutableLister); ok { - res = res.(traits.MutableLister).ToImmutableList() - } + if !foldRange.Type().HasTrait(traits.IterableType) { + return types.ValOrErr(foldRange, "got '%T', expected iterable type", foldRange) } - return res + iterable := foldRange.(traits.Iterable) + return f.foldIterable(iterable) } // Optional Interpretable implementations that specialize, subsume, or extend the core evaluation @@ -1262,3 +1236,172 @@ func invalidOptionalEntryInit(field any, value ref.Val) ref.Val { func invalidOptionalElementInit(value ref.Val) ref.Val { return types.NewErr("cannot initialize optional list element from non-optional value %v", value) } + +// newFolder creates or initializes a pooled folder instance. +func newFolder(eval *evalFold, ctx Activation) *folder { + f := folderPool.Get().(*folder) + f.evalFold = eval + f.Activation = ctx + return f +} + +// releaseFolder resets and releases a pooled folder instance. +func releaseFolder(f *folder) { + f.reset() + folderPool.Put(f) +} + +// folder tracks the state associated with folding a list or map with a comprehension v2 style macro. +// +// The folder embeds an interpreter.Activation and Interpretable evalFold value as well as implements +// the traits.Folder interface methods. +// +// Instances of a folder are intended to be pooled to minimize allocation overhead with this temporary +// bookkeeping object which supports lazy evaluation of the accumulator init expression which is useful +// in preserving evaluation order semantics which might otherwise be disrupted through the use of +// cel.bind or cel.@block. +type folder struct { + *evalFold + Activation + + // fold state objects. + accuVal ref.Val + iterVar1Val any + iterVar2Val any + + // bookkeeping flags to modify Activation and fold behaviors. + initialized bool + mutableValue bool + interrupted bool + computeResult bool +} + +func (f *folder) foldIterable(iterable traits.Iterable) ref.Val { + it := iterable.Iterator() + for it.HasNext() == types.True { + f.iterVar1Val = it.Next() + + cond := f.cond.Eval(f) + condBool, ok := cond.(types.Bool) + if f.interrupted || (!f.exhaustive && ok && condBool != types.True) { + return f.evalResult() + } + + // Update the accumulation value and check for eval interuption. + f.accuVal = f.step.Eval(f) + f.initialized = true + if f.interruptable && checkInterrupt(f.Activation) { + f.interrupted = true + return f.evalResult() + } + } + return f.evalResult() +} + +// FoldEntry will either fold comprehension v1 style macros if iterVar2 is unset, or comprehension v2 style +// macros if both the iterVar and iterVar2 are set to non-empty strings. +func (f *folder) FoldEntry(key, val any) bool { + // Default to referencing both values. + f.iterVar1Val = key + f.iterVar2Val = val + + // Terminate evaluation if evaluation is interrupted or the condition is not true and exhaustive + // eval is not enabled. + cond := f.cond.Eval(f) + condBool, ok := cond.(types.Bool) + if f.interrupted || (!f.exhaustive && ok && condBool != types.True) { + return false + } + + // Update the accumulation value and check for eval interuption. + f.accuVal = f.step.Eval(f) + f.initialized = true + if f.interruptable && checkInterrupt(f.Activation) { + f.interrupted = true + return false + } + return true +} + +// ResolveName overrides the default Activation lookup to perform lazy initialization of the accumulator +// and specialized lookups of iteration values with consideration for whether the final result is being +// computed and the iteration variables should be ignored. +func (f *folder) ResolveName(name string) (any, bool) { + if name == f.accuVar { + if !f.initialized { + f.initialized = true + initVal := f.accu.Eval(f.Activation) + if !f.exhaustive { + if l, isList := initVal.(traits.Lister); isList && l.Size() == types.IntZero { + initVal = types.NewMutableList(f.adapter) + f.mutableValue = true + } + if m, isMap := initVal.(traits.Mapper); isMap && m.Size() == types.IntZero { + initVal = types.NewMutableMap(f.adapter, map[ref.Val]ref.Val{}) + f.mutableValue = true + } + } + f.accuVal = initVal + } + return f.accuVal, true + } + if !f.computeResult { + if name == f.iterVar { + f.iterVar1Val = f.adapter.NativeToValue(f.iterVar1Val) + return f.iterVar1Val, true + } + if name == f.iterVar2 { + f.iterVar2Val = f.adapter.NativeToValue(f.iterVar2Val) + return f.iterVar2Val, true + } + } + return f.Activation.ResolveName(name) +} + +// evalResult computes the final result of the fold after all entries have been folded and accumulated. +func (f *folder) evalResult() ref.Val { + f.computeResult = true + if f.interrupted { + return types.NewErr("operation interrupted") + } + res := f.result.Eval(f) + // Convert a mutable list or map to an immutable one if the comprehension has generated a list or + // map as a result. + if !types.IsUnknownOrError(res) && f.mutableValue { + if _, ok := res.(traits.MutableLister); ok { + res = res.(traits.MutableLister).ToImmutableList() + } + if _, ok := res.(traits.MutableMapper); ok { + res = res.(traits.MutableMapper).ToImmutableMap() + } + } + return res +} + +// reset clears any state associated with folder evaluation. +func (f *folder) reset() { + f.evalFold = nil + f.Activation = nil + f.accuVal = nil + f.iterVar1Val = nil + f.iterVar2Val = nil + + f.initialized = false + f.mutableValue = false + f.interrupted = false + f.computeResult = false +} + +func checkInterrupt(a Activation) bool { + stop, found := a.ResolveName("#interrupted") + return found && stop == true +} + +var ( + // pool of var folders to reduce allocations during folds. + folderPool = &sync.Pool{ + New: func() any { + return &folder{} + }, + } +) diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index ad470973..00bf04dc 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -33,6 +33,7 @@ import ( "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/functions" + "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/stdlib" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" @@ -1941,6 +1942,99 @@ func TestInterpreter_PlanOptionalElements(t *testing.T) { } } +func TestInterpreter_PlanListComprehensionTwoVar(t *testing.T) { + fac := ast.NewExprFactory() + listTwoArgTuples := fac.NewComprehensionTwoVar(1, + fac.NewList(2, []ast.Expr{ + fac.NewLiteral(3, types.Int(2)), + fac.NewLiteral(4, types.Int(3)), + }, []int32{}), + "i", + "v", + "__result__", + fac.NewList(5, []ast.Expr{}, []int32{}), + fac.NewLiteral(6, types.True), + fac.NewCall(7, operators.Add, fac.NewAccuIdent(8), + fac.NewList(9, []ast.Expr{fac.NewIdent(10, "i"), fac.NewIdent(11, "v")}, []int32{})), + fac.NewAccuIdent(12), + ) + cont := containers.DefaultContainer + reg := newTestRegistry(t) + attrs := NewAttributeFactory(cont, reg, reg) + interp := newStandardInterpreter(t, cont, reg, reg, attrs) + expr, err := interp.NewInterpretable(ast.NewAST(listTwoArgTuples, nil), Optimize()) + if err != nil { + t.Fatalf("interp.NewInterpretable() failed for two-variable comprehension: %v", err) + } + result := expr.Eval(EmptyActivation()) + if types.IsError(result) { + t.Fatalf("expr.Eval() yielded error: %v", result) + } + want := []int64{0, 2, 1, 3} + out, err := result.ConvertToNative(reflect.TypeOf(want)) + if err != nil { + t.Fatalf("result.ConvertToNative() failed: %v", err) + } + if !reflect.DeepEqual(out, want) { + t.Errorf("got %v, wanted %v", out, want) + } +} + +func TestInterpreter_PlanMapComprehensionTwoVar(t *testing.T) { + fac := ast.NewExprFactory() + listTwoArgTuples := fac.NewComprehensionTwoVar(1, + fac.NewMap(2, []ast.EntryExpr{ + fac.NewMapEntry(3, fac.NewLiteral(4, types.Int(0)), fac.NewLiteral(5, types.String("first")), false), + fac.NewMapEntry(6, fac.NewLiteral(7, types.Int(1)), fac.NewLiteral(8, types.String("second")), false), + }), + "k", + "v", + "__result__", + fac.NewMap(9, []ast.EntryExpr{}), + fac.NewLiteral(10, types.True), + fac.NewCall(11, "cel.@mapInsert", + fac.NewAccuIdent(12), + fac.NewCall(13, operators.Add, fac.NewIdent(14, "k"), fac.NewLiteral(15, types.IntOne)), + fac.NewIdent(16, "v"), + ), + fac.NewAccuIdent(17), + ) + cont := containers.DefaultContainer + reg := newTestRegistry(t) + attrs := NewAttributeFactory(cont, reg, reg) + interp := newStandardInterpreter(t, cont, reg, reg, attrs, + funcDecl(t, "cel.@mapInsert", + decls.Overload("cel.@mapInsert", + []*types.Type{ + types.NewMapType(types.IntType, types.StringType), + types.IntType, + types.StringType, + }, types.NewMapType(types.IntType, types.StringType)), + decls.SingletonFunctionBinding(func(args ...ref.Val) ref.Val { + m := args[0].(traits.Mapper) + k := args[1] + v := args[2] + return types.InsertMapKeyValue(m, k, v) + }), + )) + expr, err := interp.NewInterpretable(ast.NewAST(listTwoArgTuples, nil), Optimize()) + if err != nil { + t.Fatalf("interp.NewInterpretable() failed for two-variable comprehension: %v", err) + } + result := expr.Eval(EmptyActivation()) + if types.IsError(result) { + t.Fatalf("expr.Eval() yielded error: %v", result) + } + want := map[int64]string{1: "first", 2: "second"} + out, err := result.ConvertToNative(reflect.TypeOf(want)) + if err != nil { + t.Fatalf("result.ConvertToNative() failed: %v", err) + } + if !reflect.DeepEqual(out, want) { + t.Errorf("got %v, wanted %v", out, want) + } +} + func testContainer(name string) *containers.Container { cont, _ := containers.NewContainer(containers.Name(name)) return cont @@ -2124,11 +2218,22 @@ func newStandardInterpreter(t *testing.T, container *containers.Container, provider types.Provider, adapter types.Adapter, - resolver AttributeFactory) Interpreter { + resolver AttributeFactory, + optFuncs ...*decls.FunctionDecl) Interpreter { t.Helper() - dispatcher := NewDispatcher() - addFunctionBindings(t, dispatcher) - return NewInterpreter(dispatcher, container, provider, adapter, resolver) + disp := NewDispatcher() + addFunctionBindings(t, disp) + for _, fn := range optFuncs { + bindings, err := fn.Bindings() + if err != nil { + t.Fatalf("fn.Bindings() failed for function %v. error: %v", fn.Name(), err) + } + err = disp.Add(bindings...) + if err != nil { + t.Fatalf("dispatcher.Add() failed: %v", err) + } + } + return NewInterpreter(disp, container, provider, adapter, resolver) } func addFunctionBindings(t testing.TB, dispatcher Dispatcher) { diff --git a/interpreter/planner.go b/interpreter/planner.go index cf371f95..3d918ce8 100644 --- a/interpreter/planner.go +++ b/interpreter/planner.go @@ -603,6 +603,7 @@ func (p *planner) planComprehension(expr ast.Expr) (Interpretable, error) { accuVar: fold.AccuVar(), accu: accu, iterVar: fold.IterVar(), + iterVar2: fold.IterVar2(), iterRange: iterRange, cond: cond, step: step,