From 67e5aeb8fb32a25efedeef777a9bc9fb6cd155a5 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Wed, 25 Oct 2023 20:01:05 -0700 Subject: [PATCH 1/8] Sets cost estimation and tracking options --- cel/cel_test.go | 10 +- cel/env.go | 10 +- cel/options.go | 18 ++ cel/program.go | 30 ++- checker/cost.go | 47 ++++- checker/cost_test.go | 78 +++++--- ext/sets.go | 58 +++++- ext/sets_test.go | 337 +++++++++++++++++++++++++++----- interpreter/runtimecost.go | 50 +++-- interpreter/runtimecost_test.go | 28 +++ 10 files changed, 559 insertions(+), 107 deletions(-) diff --git a/cel/cel_test.go b/cel/cel_test.go index 0f8e36a1..4addeb76 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -1475,7 +1475,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) { name string expr string decls []EnvOption - hints map[string]int64 + hints map[string]uint64 want checker.CostEstimate in any }{ @@ -1499,7 +1499,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) { Variable("str1", StringType), Variable("str2", StringType), }, - hints: map[string]int64{"str1": 10, "str2": 10}, + hints: map[string]uint64{"str1": 10, "str2": 10}, want: checker.CostEstimate{Min: 2, Max: 6}, in: map[string]any{"str1": "val1111111", "str2": "val2222222"}, }, @@ -1510,7 +1510,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() if tc.hints == nil { - tc.hints = map[string]int64{} + tc.hints = map[string]uint64{} } env := testEnv(t, tc.decls...) ast, iss := env.Compile(tc.expr) @@ -2768,12 +2768,12 @@ func BenchmarkDynamicDispatch(b *testing.B) { // TODO: ideally testCostEstimator and testRuntimeCostEstimator would be shared in a test fixtures package type testCostEstimator struct { - hints map[string]int64 + hints map[string]uint64 } func (tc testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok { - return &checker.SizeEstimate{Min: 0, Max: uint64(l)} + return &checker.SizeEstimate{Min: 0, Max: l} } return nil } diff --git a/cel/env.go b/cel/env.go index 786a13c4..82c31683 100644 --- a/cel/env.go +++ b/cel/env.go @@ -129,6 +129,7 @@ type Env struct { appliedFeatures map[int]bool libraries map[string]bool validators []ASTValidator + costOptions []checker.CostOption // Internal parser representation prsr *parser.Parser @@ -191,6 +192,7 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) { libraries: map[string]bool{}, validators: []ASTValidator{}, progOpts: []ProgramOption{}, + costOptions: []checker.CostOption{}, }).configure(opts) } @@ -365,6 +367,8 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) { } validatorsCopy := make([]ASTValidator, len(e.validators)) copy(validatorsCopy, e.validators) + costOptsCopy := make([]checker.CostOption, len(e.costOptions)) + copy(costOptsCopy, e.costOptions) ext := &Env{ Container: e.Container, @@ -380,6 +384,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) { provider: provider, chkOpts: chkOptsCopy, prsrOpts: prsrOptsCopy, + costOptions: costOptsCopy, } return ext.configure(opts) } @@ -556,7 +561,10 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) { // EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and // extension functions provided by estimator. func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) { - return checker.Cost(ast.impl, estimator, opts...) + extendedOpts := make([]checker.CostOption, 0, len(e.costOptions)) + extendedOpts = append(extendedOpts, opts...) + extendedOpts = append(extendedOpts, e.costOptions...) + return checker.Cost(ast.impl, estimator, extendedOpts...) } // configure applies a series of EnvOptions to the current environment. diff --git a/cel/options.go b/cel/options.go index 74e3b0ba..89825c36 100644 --- a/cel/options.go +++ b/cel/options.go @@ -23,6 +23,7 @@ import ( "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/dynamicpb" + "github.com/google/cel-go/checker" "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/functions" "github.com/google/cel-go/common/types" @@ -471,6 +472,16 @@ func InterruptCheckFrequency(checkFrequency uint) ProgramOption { } } +// CostTrackerOptions configures a set of options for cost-tracking. +// +// Note, CostTrackerOptions is a no-op unless CostTracking is also enabled. +func CostTrackerOptions(costOpts ...interpreter.CostTrackerOption) ProgramOption { + return func(p *prog) (*prog, error) { + p.costOptions = append(p.costOptions, costOpts...) + return p, nil + } +} + // CostTracking enables cost tracking and registers a ActualCostEstimator that can optionally provide a runtime cost estimate for any function calls. func CostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOption { return func(p *prog) (*prog, error) { @@ -630,6 +641,13 @@ func ParserExpressionSizeLimit(limit int) EnvOption { } } +func CostEstimatorOptions(costOpts ...checker.CostOption) EnvOption { + return func(e *Env) (*Env, error) { + e.costOptions = append(e.costOptions, costOpts...) + return e, nil + } +} + func maybeInteropProvider(provider any) (types.Provider, error) { switch p := provider.(type) { case types.Provider: diff --git a/cel/program.go b/cel/program.go index cec4839d..50fb120f 100644 --- a/cel/program.go +++ b/cel/program.go @@ -105,7 +105,7 @@ func (ed *EvalDetails) State() interpreter.EvalState { // ActualCost returns the tracked cost through the course of execution when `CostTracking` is enabled. // Otherwise, returns nil if the cost was not enabled. func (ed *EvalDetails) ActualCost() *uint64 { - if ed.costTracker == nil { + if ed == nil || ed.costTracker == nil { return nil } cost := ed.costTracker.ActualCost() @@ -129,10 +129,14 @@ type prog struct { // Interpretable configured from an Ast and aggregate decorator set based on program options. interpretable interpreter.Interpretable callCostEstimator interpreter.ActualCostEstimator + costOptions []interpreter.CostTrackerOption costLimit *uint64 } func (p *prog) clone() *prog { + costOptsCopy := make([]interpreter.CostTrackerOption, len(p.costOptions)) + copy(costOptsCopy, p.costOptions) + return &prog{ Env: p.Env, evalOpts: p.evalOpts, @@ -154,9 +158,10 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) { // Ensure the default attribute factory is set after the adapter and provider are // configured. p := &prog{ - Env: e, - decorators: []interpreter.InterpretableDecorator{}, - dispatcher: disp, + Env: e, + decorators: []interpreter.InterpretableDecorator{}, + dispatcher: disp, + costOptions: []interpreter.CostTrackerOption{}, } // Configure the program via the ProgramOption values. @@ -213,6 +218,12 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) { factory := func(state interpreter.EvalState, costTracker *interpreter.CostTracker) (Program, error) { costTracker.Estimator = p.callCostEstimator costTracker.Limit = p.costLimit + for _, costOpt := range p.costOptions { + err := costOpt(costTracker) + if err != nil { + return nil, err + } + } // Limit capacity to guarantee a reallocation when calling 'append(decs, ...)' below. This // prevents the underlying memory from being shared between factory function calls causing // undesired mutations. @@ -325,7 +336,11 @@ type progGen struct { // the test is successful. func newProgGen(factory progFactory) (Program, error) { // Test the factory to make sure that configuration errors are spotted at config - _, err := factory(interpreter.NewEvalState(), &interpreter.CostTracker{}) + tracker, err := interpreter.NewCostTracker(nil) + if err != nil { + return nil, err + } + _, err = factory(interpreter.NewEvalState(), tracker) if err != nil { return nil, err } @@ -338,7 +353,10 @@ func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) { // new EvalState instance for each call to ensure that unique evaluations yield unique stateful // results. state := interpreter.NewEvalState() - costTracker := &interpreter.CostTracker{} + costTracker, err := interpreter.NewCostTracker(nil) + if err != nil { + return nil, nil, err + } det := &EvalDetails{state: state, costTracker: costTracker} // Generate a new instance of the interpretable using the factory configured during the call to diff --git a/checker/cost.go b/checker/cost.go index b6109d91..c8e21297 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -15,6 +15,7 @@ package checker import ( + "fmt" "math" "github.com/google/cel-go/common" @@ -256,9 +257,10 @@ type coster struct { // iterRanges tracks the iterRange of each iterVar. iterRanges iterRangeScopes // computedSizes tracks the computed sizes of call results. - computedSizes map[int64]SizeEstimate - checkedAST *ast.AST - estimator CostEstimator + computedSizes map[int64]SizeEstimate + checkedAST *ast.AST + estimator CostEstimator + functionEstimators map[string]FunctionEstimator // presenceTestCost will either be a zero or one based on whether has() macros count against cost computations. presenceTestCost CostEstimate } @@ -287,6 +289,7 @@ func (vs iterRangeScopes) peek(varName string) (int64, bool) { type CostOption func(*coster) error // PresenceTestHasCost determines whether presence testing has a cost of one or zero. +// // Defaults to presence test has a cost of one. func PresenceTestHasCost(hasCost bool) CostOption { return func(c *coster) error { @@ -299,15 +302,31 @@ func PresenceTestHasCost(hasCost bool) CostOption { } } +// FunctionEstimator provides a CallEstimate given the target and arguments for a specific function, overload pair. +type FunctionEstimator func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate + +// FunctionCostEstimate binds a FunctionCoster to a specific function, overload pair. +// +// When a FunctionCostEstimate is provided, it will override the cost calculation of the CostEstimator provided to +// the Cost() call. +func FunctionCostEstimate(function, overloadID string, functionCoster FunctionEstimator) CostOption { + return func(c *coster) error { + functionKey := fmt.Sprintf("%s|%s", function, overloadID) + c.functionEstimators[functionKey] = functionCoster + return nil + } +} + // Cost estimates the cost of the parsed and type checked CEL expression. func Cost(checked *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) { c := &coster{ - checkedAST: checked, - estimator: estimator, - exprPath: map[int64][]string{}, - iterRanges: map[string][]int64{}, - computedSizes: map[int64]SizeEstimate{}, - presenceTestCost: CostEstimate{Min: 1, Max: 1}, + checkedAST: checked, + estimator: estimator, + functionEstimators: map[string]FunctionEstimator{}, + exprPath: map[int64][]string{}, + iterRanges: map[string][]int64{}, + computedSizes: map[int64]SizeEstimate{}, + presenceTestCost: CostEstimate{Min: 1, Max: 1}, } for _, opt := range opts { err := opt(c) @@ -518,7 +537,15 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args } return sum } - + if len(c.functionEstimators) != 0 { + functionKey := fmt.Sprintf("%s|%s", function, overloadID) + if estimator, found := c.functionEstimators[functionKey]; found { + if est := estimator(c.estimator, target, args); est != nil { + callEst := *est + return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize} + } + } + } if est := c.estimator.EstimateCallCost(function, overloadID, target, args); est != nil { callEst := *est return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize} diff --git a/checker/cost_test.go b/checker/cost_test.go index 92f98d84..55d852d8 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -15,6 +15,7 @@ package checker import ( + "math" "strings" "testing" @@ -44,7 +45,7 @@ func TestCost(t *testing.T) { name string expr string vars []*decls.VariableDecl - hints map[string]int64 + hints map[string]uint64 options []CostOption wanted CostEstimate }{ @@ -129,14 +130,14 @@ func TestCost(t *testing.T) { { name: "all comprehension", vars: []*decls.VariableDecl{decls.NewVariable("input", allList)}, - hints: map[string]int64{"input": 100}, + hints: map[string]uint64{"input": 100}, expr: `input.all(x, true)`, wanted: CostEstimate{Min: 2, Max: 302}, }, { name: "nested all comprehension", vars: []*decls.VariableDecl{decls.NewVariable("input", nestedList)}, - hints: map[string]int64{"input": 50, "input.@items": 10}, + hints: map[string]uint64{"input": 50, "input.@items": 10}, expr: `input.all(x, x.all(y, true))`, wanted: CostEstimate{Min: 2, Max: 1752}, }, @@ -148,7 +149,7 @@ func TestCost(t *testing.T) { { name: "variable cost function", vars: []*decls.VariableDecl{decls.NewVariable("input", types.StringType)}, - hints: map[string]int64{"input": 500}, + hints: map[string]uint64{"input": 500}, expr: `input.matches('[0-9]')`, wanted: CostEstimate{Min: 3, Max: 103}, }, @@ -257,14 +258,14 @@ func TestCost(t *testing.T) { { name: "bytes to string conversion", vars: []*decls.VariableDecl{decls.NewVariable("input", types.BytesType)}, - hints: map[string]int64{"input": 500}, + hints: map[string]uint64{"input": 500}, expr: `string(input)`, wanted: CostEstimate{Min: 1, Max: 51}, }, { name: "bytes to string conversion equality", vars: []*decls.VariableDecl{decls.NewVariable("input", types.BytesType)}, - hints: map[string]int64{"input": 500}, + hints: map[string]uint64{"input": 500}, // equality check ensures that the resultSize calculation is included in cost expr: `string(input) == string(input)`, wanted: CostEstimate{Min: 3, Max: 152}, @@ -272,14 +273,14 @@ func TestCost(t *testing.T) { { name: "string to bytes conversion", vars: []*decls.VariableDecl{decls.NewVariable("input", types.StringType)}, - hints: map[string]int64{"input": 500}, + hints: map[string]uint64{"input": 500}, expr: `bytes(input)`, wanted: CostEstimate{Min: 1, Max: 51}, }, { name: "string to bytes conversion equality", vars: []*decls.VariableDecl{decls.NewVariable("input", types.StringType)}, - hints: map[string]int64{"input": 500}, + hints: map[string]uint64{"input": 500}, // equality check ensures that the resultSize calculation is included in cost expr: `bytes(input) == bytes(input)`, wanted: CostEstimate{Min: 3, Max: 302}, @@ -296,7 +297,7 @@ func TestCost(t *testing.T) { decls.NewVariable("input", types.StringType), decls.NewVariable("arg1", types.StringType), }, - hints: map[string]int64{"input": 500, "arg1": 500}, + hints: map[string]uint64{"input": 500, "arg1": 500}, wanted: CostEstimate{Min: 2, Max: 2502}, }, { @@ -305,7 +306,7 @@ func TestCost(t *testing.T) { vars: []*decls.VariableDecl{ decls.NewVariable("input", types.StringType), }, - hints: map[string]int64{"input": 500}, + hints: map[string]uint64{"input": 500}, wanted: CostEstimate{Min: 3, Max: 103}, }, { @@ -315,7 +316,7 @@ func TestCost(t *testing.T) { decls.NewVariable("input", types.StringType), decls.NewVariable("arg1", types.StringType), }, - hints: map[string]int64{"arg1": 500}, + hints: map[string]uint64{"arg1": 500}, wanted: CostEstimate{Min: 2, Max: 52}, }, { @@ -325,7 +326,7 @@ func TestCost(t *testing.T) { decls.NewVariable("input", types.StringType), decls.NewVariable("arg1", types.StringType), }, - hints: map[string]int64{"arg1": 500}, + hints: map[string]uint64{"arg1": 500}, wanted: CostEstimate{Min: 2, Max: 52}, }, { @@ -352,7 +353,7 @@ func TestCost(t *testing.T) { decls.NewVariable("input1", allList), decls.NewVariable("input2", allList), }, - hints: map[string]int64{"input1": 1, "input2": 1}, + hints: map[string]uint64{"input1": 1, "input2": 1}, wanted: CostEstimate{Min: 4, Max: 7}, }, { @@ -361,7 +362,7 @@ func TestCost(t *testing.T) { vars: []*decls.VariableDecl{ decls.NewVariable("input", allMap), }, - hints: map[string]int64{"input": 10}, + hints: map[string]uint64{"input": 10}, wanted: CostEstimate{Min: 2, Max: 82}, }, { @@ -370,7 +371,7 @@ func TestCost(t *testing.T) { vars: []*decls.VariableDecl{ decls.NewVariable("input", nestedMap), }, - hints: map[string]int64{"input": 5, "input.@values": 10}, + hints: map[string]uint64{"input": 5, "input.@values": 10}, wanted: CostEstimate{Min: 2, Max: 187}, }, { @@ -379,7 +380,7 @@ func TestCost(t *testing.T) { vars: []*decls.VariableDecl{ decls.NewVariable("input", nestedMap), }, - hints: map[string]int64{"input": 5, "input.@keys": 10}, + hints: map[string]uint64{"input": 5, "input.@keys": 10}, wanted: CostEstimate{Min: 2, Max: 32}, }, { @@ -388,7 +389,7 @@ func TestCost(t *testing.T) { vars: []*decls.VariableDecl{ decls.NewVariable("input", nestedMap), }, - hints: map[string]int64{"input": 2, "input.@values": 2, "input.@keys": 5}, + hints: map[string]uint64{"input": 2, "input.@values": 2, "input.@keys": 5}, wanted: CostEstimate{Min: 2, Max: 34}, }, { @@ -397,7 +398,7 @@ func TestCost(t *testing.T) { vars: []*decls.VariableDecl{ decls.NewVariable("input", nestedMap), }, - hints: map[string]int64{"input": 2, "input.@values": 2, "input.@keys": 5}, + hints: map[string]uint64{"input": 2, "input.@values": 2, "input.@keys": 5}, wanted: CostEstimate{Min: 2, Max: 34}, }, { @@ -407,7 +408,7 @@ func TestCost(t *testing.T) { decls.NewVariable("list1", types.NewListType(types.IntType)), decls.NewVariable("list2", types.NewListType(types.IntType)), }, - hints: map[string]int64{"list1": 10, "list2": 10}, + hints: map[string]uint64{"list1": 10, "list2": 10}, wanted: CostEstimate{Min: 4, Max: 64}, }, { @@ -417,9 +418,30 @@ func TestCost(t *testing.T) { decls.NewVariable("str1", types.StringType), decls.NewVariable("str2", types.StringType), }, - hints: map[string]int64{"str1": 10, "str2": 10}, + hints: map[string]uint64{"str1": 10, "str2": 10}, wanted: CostEstimate{Min: 2, Max: 6}, }, + { + name: "str concat custom cost estimate", + expr: `"abcdefg".contains(str1 + str2)`, + vars: []*decls.VariableDecl{ + decls.NewVariable("str1", types.StringType), + decls.NewVariable("str2", types.StringType), + }, + hints: map[string]uint64{"str1": 10, "str2": 10}, + options: []CostOption{ + FunctionCostEstimate(overloads.Contains, overloads.ContainsString, + func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate { + if target != nil && len(args) == 1 { + strSize := estimateSize(estimator, *target).MultiplyByCostFactor(0.2) + subSize := estimateSize(estimator, args[0]).MultiplyByCostFactor(0.2) + return &CallEstimate{CostEstimate: strSize.Multiply(subSize)} + } + return nil + }), + }, + wanted: CostEstimate{Min: 2, Max: 12}, + }, { name: "list size comparison", expr: `list1.size() == list2.size()`, @@ -486,7 +508,7 @@ func TestCost(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { if tc.hints == nil { - tc.hints = map[string]int64{} + tc.hints = map[string]uint64{} } p, err := parser.NewParser(parser.Macros(parser.AllMacros...)) if err != nil { @@ -531,12 +553,12 @@ func TestCost(t *testing.T) { } type testCostEstimator struct { - hints map[string]int64 + hints map[string]uint64 } func (tc testCostEstimator) EstimateSize(element AstNode) *SizeEstimate { if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok { - return &SizeEstimate{Min: 0, Max: uint64(l)} + return &SizeEstimate{Min: 0, Max: l} } return nil } @@ -548,3 +570,13 @@ func (tc testCostEstimator) EstimateCallCost(function, overloadID string, target } return nil } + +func estimateSize(estimator CostEstimator, node AstNode) SizeEstimate { + if l := node.ComputedSize(); l != nil { + return *l + } + if l := estimator.EstimateSize(node); l != nil { + return *l + } + return SizeEstimate{Min: 0, Max: math.MaxUint64} +} diff --git a/ext/sets.go b/ext/sets.go index 4820d619..ea18618d 100644 --- a/ext/sets.go +++ b/ext/sets.go @@ -15,10 +15,14 @@ package ext import ( + "math" + "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" + "github.com/google/cel-go/interpreter" ) // Sets returns a cel.EnvOption to configure namespaced set relationship @@ -95,12 +99,24 @@ func (setsLib) CompileOptions() []cel.EnvOption { cel.Function("sets.intersects", cel.Overload("list_sets_intersects_list", []*cel.Type{listType, listType}, cel.BoolType, cel.BinaryBinding(setsIntersects))), + cel.CostEstimatorOptions( + checker.FunctionCostEstimate("sets.contains", "list_sets_contains_list", estimateSetsCost(1)), + checker.FunctionCostEstimate("sets.intersects", "list_sets_intersects_list", estimateSetsCost(1)), + // equivalence requires potentially two m*n comparisons to ensure each list is contained by the other + checker.FunctionCostEstimate("sets.equivalent", "list_sets_equivalent_list", estimateSetsCost(2)), + ), } } // ProgramOptions implements the Library interface method. func (setsLib) ProgramOptions() []cel.ProgramOption { - return []cel.ProgramOption{} + return []cel.ProgramOption{ + cel.CostTrackerOptions( + interpreter.FunctionCostTracker("sets.contains", "list_sets_contains_list", trackSetsCost(1)), + interpreter.FunctionCostTracker("sets.intersects", "list_sets_intersects_list", trackSetsCost(1)), + interpreter.FunctionCostTracker("sets.equivalent", "list_sets_equivalent_list", trackSetsCost(2)), + ), + } } func setsIntersects(listA, listB ref.Val) ref.Val { @@ -136,3 +152,43 @@ func setsEquivalent(listA, listB ref.Val) ref.Val { } return setsContains(listB, listA) } + +func estimateSetsCost(costFactor float64) checker.FunctionEstimator { + return func(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + arg0Size := estimateSize(estimator, args[0]) + arg1Size := estimateSize(estimator, args[1]) + costEstimate := arg0Size.Multiply(arg1Size).MultiplyByCostFactor(costFactor).Add(callCostEstimate) + return &checker.CallEstimate{CostEstimate: costEstimate} + } +} + +func estimateSize(estimator checker.CostEstimator, node checker.AstNode) checker.SizeEstimate { + if l := node.ComputedSize(); l != nil { + return *l + } + if l := estimator.EstimateSize(node); l != nil { + return *l + } + return checker.SizeEstimate{Min: 0, Max: math.MaxUint64} +} + +func trackSetsCost(costFactor float64) interpreter.FunctionTracker { + return func(args []ref.Val, _ ref.Val) *uint64 { + lhsSize := actualSize(args[0]) + rhsSize := actualSize(args[1]) + cost := callCost + uint64(float64(lhsSize*rhsSize)*costFactor) + return &cost + } +} + +func actualSize(value ref.Val) uint64 { + if sz, ok := value.(traits.Sizer); ok { + return uint64(sz.Size().(types.Int)) + } + return 1 +} + +var ( + callCostEstimate = checker.CostEstimate{Min: 1, Max: 1} + callCost = uint64(1) +) diff --git a/ext/sets_test.go b/ext/sets_test.go index 5fbe3b9a..1d402b27 100644 --- a/ext/sets_test.go +++ b/ext/sets_test.go @@ -15,67 +15,266 @@ package ext import ( - "fmt" + "reflect" + "strings" "testing" "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker" ) func TestSets(t *testing.T) { setsTests := []struct { - expr string + expr string + vars []cel.EnvOption + in map[string]any + hints map[string]uint64 + estimatedCost checker.CostEstimate + actualCost uint64 }{ // set containment - {expr: `sets.contains([], [])`}, - {expr: `sets.contains([1], [])`}, - {expr: `sets.contains([1], [1])`}, - {expr: `sets.contains([1], [1, 1])`}, - {expr: `sets.contains([1, 1], [1])`}, - {expr: `sets.contains([2, 1], [1])`}, - {expr: `sets.contains([1, 2, 3, 4], [2, 3])`}, - {expr: `sets.contains([1], [1.0, 1])`}, - {expr: `sets.contains([1, 2], [2u, 2.0])`}, - {expr: `sets.contains([1, 2u], [2, 2.0])`}, - {expr: `sets.contains([1, 2.0, 3u], [1.0, 2u, 3])`}, - {expr: `sets.contains([[1], [2, 3]], [[2, 3.0]])`}, - {expr: `!sets.contains([1], [2])`}, - {expr: `!sets.contains([1], [1, 2])`}, - {expr: `!sets.contains([1], ["1", 1])`}, - {expr: `!sets.contains([1], [1.1, 1u])`}, - // set equivalence - {expr: `sets.equivalent([], [])`}, - {expr: `sets.equivalent([1], [1])`}, - {expr: `sets.equivalent([1], [1, 1])`}, - {expr: `sets.equivalent([1, 1], [1])`}, - {expr: `sets.equivalent([1], [1u, 1.0])`}, - {expr: `sets.equivalent([1], [1u, 1.0])`}, - {expr: `sets.equivalent([1, 2, 3], [3u, 2.0, 1])`}, - {expr: `sets.equivalent([[1.0], [2, 3]], [[1], [2, 3.0]])`}, - {expr: `!sets.equivalent([2, 1], [1])`}, - {expr: `!sets.equivalent([1], [1, 2])`}, - {expr: `!sets.equivalent([1, 2], [2u, 2, 2.0])`}, - {expr: `!sets.equivalent([1, 2], [1u, 2, 2.3])`}, + { + expr: `sets.contains(x, [1, 2, 3])`, + vars: []cel.EnvOption{cel.Variable("x", cel.ListType(cel.IntType))}, + in: map[string]any{"x": []int64{5, 4, 3, 2, 1}}, + hints: map[string]uint64{"x": 10}, + // min cost is input 'x' length 0, 10 for list creation, 2 for arg costs + // max cost is input 'x' lenght 10, 10 for list creation, 2 for arg costs + estimatedCost: checker.CostEstimate{Min: 12, Max: 42}, + // actual cost is 'x' length 5 * list literal length 3, 10 for list creation, 2 for arg cost + actualCost: 27, + }, + { + expr: `sets.contains(x, [1, 1, 1, 1, 1])`, + vars: []cel.EnvOption{cel.Variable("x", cel.ListType(cel.IntType))}, + in: map[string]any{"x": []int64{5, 4, 3, 2, 1}}, + // min cost is input 'x' length 0, 10 for list creation, 2 for arg costs + // max cost is effectively infinite due to missing size hint for 'x' + estimatedCost: checker.CostEstimate{Min: 12, Max: 9223372036854775820}, + // actual cost is 'x' length 5 * list literal length 5, 10 for list creation, 2 for arg cost + actualCost: 37, + }, + { + expr: `sets.contains([], [])`, + estimatedCost: checker.CostEstimate{Min: 21, Max: 21}, + actualCost: 21, + }, + { + expr: `sets.contains([1], [])`, + estimatedCost: checker.CostEstimate{Min: 21, Max: 21}, + actualCost: 21, + }, + { + expr: `sets.contains([1], [1])`, + estimatedCost: checker.CostEstimate{Min: 22, Max: 22}, + actualCost: 22, + }, + { + expr: `sets.contains([1], [1, 1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.contains([1, 1], [1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.contains([2, 1], [1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.contains([1, 2, 3, 4], [2, 3])`, + estimatedCost: checker.CostEstimate{Min: 29, Max: 29}, + actualCost: 29, + }, + { + expr: `sets.contains([1], [1.0, 1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.contains([1, 2], [2u, 2.0])`, + estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + actualCost: 25, + }, + { + expr: `sets.contains([1, 2u], [2, 2.0])`, + estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + actualCost: 25, + }, + { + expr: `sets.contains([1, 2.0, 3u], [1.0, 2u, 3])`, + estimatedCost: checker.CostEstimate{Min: 30, Max: 30}, + actualCost: 30, + }, + { + expr: `sets.contains([[1], [2, 3]], [[2, 3.0]])`, + // 10 for each list creation, top-level list sizes are 2, 1 + estimatedCost: checker.CostEstimate{Min: 53, Max: 53}, + actualCost: 53, + }, + { + expr: `!sets.contains([1], [2])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `!sets.contains([1], [1, 2])`, + estimatedCost: checker.CostEstimate{Min: 24, Max: 24}, + actualCost: 24, + }, + { + expr: `!sets.contains([1], ["1", 1])`, + estimatedCost: checker.CostEstimate{Min: 24, Max: 24}, + actualCost: 24, + }, + { + expr: `!sets.contains([1], [1.1, 1u])`, + estimatedCost: checker.CostEstimate{Min: 24, Max: 24}, + actualCost: 24, + }, + + // set equivalence (note the cost factor is higher as it's basically two contains checks) + { + expr: `sets.equivalent([], [])`, + estimatedCost: checker.CostEstimate{Min: 21, Max: 21}, + actualCost: 21, + }, + { + expr: `sets.equivalent([1], [1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.equivalent([1], [1, 1])`, + estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + actualCost: 25, + }, + { + expr: `sets.equivalent([1, 1], [1])`, + estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + actualCost: 25, + }, + { + expr: `sets.equivalent([1], [1u, 1.0])`, + estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + actualCost: 25, + }, + { + expr: `sets.equivalent([1], [1u, 1.0])`, + estimatedCost: checker.CostEstimate{Min: 25, Max: 25}, + actualCost: 25, + }, + { + expr: `sets.equivalent([1, 2, 3], [3u, 2.0, 1])`, + estimatedCost: checker.CostEstimate{Min: 39, Max: 39}, + actualCost: 39, + }, + { + expr: `sets.equivalent([[1.0], [2, 3]], [[1], [2, 3.0]])`, + estimatedCost: checker.CostEstimate{Min: 69, Max: 69}, + actualCost: 69, + }, + { + expr: `!sets.equivalent([2, 1], [1])`, + estimatedCost: checker.CostEstimate{Min: 26, Max: 26}, + actualCost: 26, + }, + { + expr: `!sets.equivalent([1], [1, 2])`, + estimatedCost: checker.CostEstimate{Min: 26, Max: 26}, + actualCost: 26, + }, + { + expr: `!sets.equivalent([1, 2], [2u, 2, 2.0])`, + estimatedCost: checker.CostEstimate{Min: 34, Max: 34}, + actualCost: 34, + }, + { + expr: `!sets.equivalent([1, 2], [1u, 2, 2.3])`, + estimatedCost: checker.CostEstimate{Min: 34, Max: 34}, + actualCost: 34, + }, + // set intersection - {expr: `sets.intersects([1], [1])`}, - {expr: `sets.intersects([1], [1, 1])`}, - {expr: `sets.intersects([1, 1], [1])`}, - {expr: `sets.intersects([2, 1], [1])`}, - {expr: `sets.intersects([1], [1, 2])`}, - {expr: `sets.intersects([1], [1.0, 2])`}, - {expr: `sets.intersects([1, 2], [2u, 2, 2.0])`}, - {expr: `sets.intersects([1, 2], [1u, 2, 2.3])`}, - {expr: `sets.intersects([[1], [2, 3]], [[1, 2], [2, 3.0]])`}, - {expr: `!sets.intersects([], [])`}, - {expr: `!sets.intersects([1], [])`}, - {expr: `!sets.intersects([1], [2])`}, - {expr: `!sets.intersects([1], ["1", 2])`}, - {expr: `!sets.intersects([1], [1.1, 2u])`}, + { + expr: `sets.intersects([1], [1])`, + estimatedCost: checker.CostEstimate{Min: 22, Max: 22}, + actualCost: 22, + }, + { + expr: `sets.intersects([1], [1, 1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.intersects([1, 1], [1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.intersects([2, 1], [1])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.intersects([1], [1, 2])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.intersects([1], [1.0, 2])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `sets.intersects([1, 2], [2u, 2, 2.0])`, + estimatedCost: checker.CostEstimate{Min: 27, Max: 27}, + actualCost: 27, + }, + { + expr: `sets.intersects([1, 2], [1u, 2, 2.3])`, + estimatedCost: checker.CostEstimate{Min: 27, Max: 27}, + actualCost: 27, + }, + { + expr: `sets.intersects([[1], [2, 3]], [[1, 2], [2, 3.0]])`, + estimatedCost: checker.CostEstimate{Min: 65, Max: 65}, + actualCost: 65, + }, + { + expr: `!sets.intersects([], [])`, + estimatedCost: checker.CostEstimate{Min: 22, Max: 22}, + actualCost: 22, + }, + { + expr: `!sets.intersects([1], [])`, + estimatedCost: checker.CostEstimate{Min: 22, Max: 22}, + actualCost: 22, + }, + { + expr: `!sets.intersects([1], [2])`, + estimatedCost: checker.CostEstimate{Min: 23, Max: 23}, + actualCost: 23, + }, + { + expr: `!sets.intersects([1], ["1", 2])`, + estimatedCost: checker.CostEstimate{Min: 24, Max: 24}, + actualCost: 24, + }, + { + expr: `!sets.intersects([1], [1.1, 2u])`, + estimatedCost: checker.CostEstimate{Min: 24, Max: 24}, + actualCost: 24, + }, } - env := testSetsEnv(t) - for i, tst := range setsTests { + for _, tst := range setsTests { tc := tst - t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + t.Run(tc.expr, func(t *testing.T) { + env := testSetsEnv(t, tc.vars...) var asts []*cel.Ast pAst, iss := env.Parse(tc.expr) if iss.Err() != nil { @@ -86,20 +285,43 @@ func TestSets(t *testing.T) { if iss.Err() != nil { t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err()) } + + hints := map[string]uint64{} + if len(tc.hints) != 0 { + hints = tc.hints + } + est, err := env.EstimateCost(cAst, testSetsCostEstimator{hints: hints}) + if err != nil { + t.Fatalf("env.EstimateCost() failed: %v", err) + } + if !reflect.DeepEqual(est, tc.estimatedCost) { + t.Errorf("env.EstimateCost() got %v, wanted %v", est, tc.estimatedCost) + } asts = append(asts, cAst) for _, ast := range asts { - prg, err := env.Program(ast) + prgOpts := []cel.ProgramOption{} + if ast.IsChecked() { + prgOpts = append(prgOpts, cel.CostTracking(nil)) + } + prg, err := env.Program(ast, prgOpts...) if err != nil { t.Fatalf("env.Program() failed: %v", err) } - out, _, err := prg.Eval(cel.NoVars()) + in := tc.in + if in == nil { + in = map[string]any{} + } + out, det, err := prg.Eval(in) if err != nil { t.Fatalf("prg.Eval() failed: %v", err) } if out.Value() != true { t.Errorf("prg.Eval() got %v, wanted true for expr: %s", out.Value(), tc.expr) } + if det.ActualCost() != nil && *det.ActualCost() != tc.actualCost { + t.Errorf("prg.Eval() had cost %v, wanted %v", *det.ActualCost(), tc.actualCost) + } } }) } @@ -114,3 +336,18 @@ func testSetsEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env { } return env } + +type testSetsCostEstimator struct { + hints map[string]uint64 +} + +func (tc testSetsCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { + if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok { + return &checker.SizeEstimate{Min: 0, Max: l} + } + return nil +} + +func (testSetsCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + return nil +} diff --git a/interpreter/runtimecost.go b/interpreter/runtimecost.go index 96faed2e..a9f9ede4 100644 --- a/interpreter/runtimecost.go +++ b/interpreter/runtimecost.go @@ -15,6 +15,7 @@ package interpreter import ( + "fmt" "math" "github.com/google/cel-go/common" @@ -133,6 +134,7 @@ func PresenceTestHasCost(hasCost bool) CostTrackerOption { func NewCostTracker(estimator ActualCostEstimator, opts ...CostTrackerOption) (*CostTracker, error) { tracker := &CostTracker{ Estimator: estimator, + functionTrackers: map[string]FunctionTracker{}, presenceTestHasCost: true, } for _, opt := range opts { @@ -144,9 +146,25 @@ func NewCostTracker(estimator ActualCostEstimator, opts ...CostTrackerOption) (* return tracker, nil } +// FunctionCostTracker binds a function overload to a runtime FunctionTracker implementation. +// +// FunctionCostTracker instances augment or override ActualCostEstimator decisions, allowing for versioned and/or +// optional cost tracking changes. +func FunctionCostTracker(function, overloadID string, fnTracker FunctionTracker) CostTrackerOption { + return func(tracker *CostTracker) error { + functionKey := fmt.Sprintf("%s|%s", function, overloadID) + tracker.functionTrackers[functionKey] = fnTracker + return nil + } +} + +// FunctionTracker computes the actual cost of evaluating the functions with the given arguments and result. +type FunctionTracker func(args []ref.Val, result ref.Val) *uint64 + // CostTracker represents the information needed for tracking runtime cost. type CostTracker struct { Estimator ActualCostEstimator + functionTrackers map[string]FunctionTracker Limit *uint64 presenceTestHasCost bool @@ -159,10 +177,20 @@ func (c *CostTracker) ActualCost() uint64 { return c.cost } -func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, result ref.Val) uint64 { +func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result ref.Val) uint64 { var cost uint64 + if len(c.functionTrackers) != 0 { + functionKey := fmt.Sprintf("%s|%s", call.Function(), call.OverloadID()) + if tracker, found := c.functionTrackers[functionKey]; found { + callCost := tracker(args, result) + if callCost != nil { + cost += *callCost + return cost + } + } + } if c.Estimator != nil { - callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), argValues, result) + callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), args, result) if callCost != nil { cost += *callCost return cost @@ -173,11 +201,11 @@ func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resu switch call.OverloadID() { // O(n) functions case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString: - cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor)) + cost += uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor)) case overloads.InList: // If a list is composed entirely of constant values this is O(1), but we don't account for that here. // We just assume all list containment checks are O(n). - cost += c.actualSize(argValues[1]) + cost += c.actualSize(args[1]) // O(min(m, n)) functions case overloads.LessString, overloads.GreaterString, overloads.LessEqualsString, overloads.GreaterEqualsString, overloads.LessBytes, overloads.GreaterBytes, overloads.LessEqualsBytes, overloads.GreaterEqualsBytes, @@ -185,8 +213,8 @@ func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resu // When we check the equality of 2 scalar values (e.g. 2 integers, 2 floating-point numbers, 2 booleans etc.), // the CostTracker.actualSize() function by definition returns 1 for each operand, resulting in an overall cost // of 1. - lhsSize := c.actualSize(argValues[0]) - rhsSize := c.actualSize(argValues[1]) + lhsSize := c.actualSize(args[0]) + rhsSize := c.actualSize(args[1]) minSize := lhsSize if rhsSize < minSize { minSize = rhsSize @@ -195,23 +223,23 @@ func (c *CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resu // O(m+n) functions case overloads.AddString, overloads.AddBytes: // In the worst case scenario, we would need to reallocate a new backing store and copy both operands over. - cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])+c.actualSize(argValues[1])) * common.StringTraversalCostFactor)) + cost += uint64(math.Ceil(float64(c.actualSize(args[0])+c.actualSize(args[1])) * common.StringTraversalCostFactor)) // O(nm) functions case overloads.MatchesString: // https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL // Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0 // in case where string is empty but regex is still expensive. - strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(argValues[0]))) * common.StringTraversalCostFactor)) + strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(args[0]))) * common.StringTraversalCostFactor)) // We don't know how many expressions are in the regex, just the string length (a huge // improvement here would be to somehow get a count the number of expressions in the regex or // how many states are in the regex state machine and use that to measure regex cost). // For now, we're making a guess that each expression in a regex is typically at least 4 chars // in length. - regexCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.RegexStringLengthCostFactor)) + regexCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.RegexStringLengthCostFactor)) cost += strCost * regexCost case overloads.ContainsString: - strCost := uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor)) - substrCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.StringTraversalCostFactor)) + strCost := uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor)) + substrCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.StringTraversalCostFactor)) cost += strCost * substrCost default: diff --git a/interpreter/runtimecost_test.go b/interpreter/runtimecost_test.go index 9a700a4b..d0c8c1de 100644 --- a/interpreter/runtimecost_test.go +++ b/interpreter/runtimecost_test.go @@ -16,6 +16,7 @@ package interpreter import ( "fmt" + "math" "math/rand" "reflect" "strings" @@ -29,6 +30,7 @@ import ( "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/parser" proto3pb "github.com/google/cel-go/test/proto3pb" @@ -727,6 +729,25 @@ func TestRuntimeCost(t *testing.T) { want: 6, in: map[string]any{"str1": "val1", "str2": "val2222222"}, }, + { + name: "str concat custom cost tracker", + expr: `"abcdefg".contains(str1 + str2)`, + vars: []*decls.VariableDecl{ + decls.NewVariable("str1", types.StringType), + decls.NewVariable("str2", types.StringType), + }, + options: []CostTrackerOption{ + FunctionCostTracker(overloads.Contains, overloads.ContainsString, + func(args []ref.Val, result ref.Val) *uint64 { + strCost := uint64(math.Ceil(float64(actualSize(args[0])) * 0.2)) + substrCost := uint64(math.Ceil(float64(actualSize(args[1])) * 0.2)) + cost := strCost * substrCost + return &cost + }), + }, + want: 10, + in: map[string]any{"str1": "val1", "str2": "val2222222"}, + }, { name: "at limit", expr: `"abcdefg".contains(str1 + str2)`, @@ -803,3 +824,10 @@ func TestRuntimeCost(t *testing.T) { }) } } + +func actualSize(val ref.Val) uint64 { + if sz, ok := val.(traits.Sizer); ok { + return uint64(sz.Size().(types.Int)) + } + return 1 +} From 51e493665cf201f94a77ba1add96d1aac4abbcea Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Thu, 26 Oct 2023 15:31:25 -0700 Subject: [PATCH 2/8] Added doc comment on CostEstimatorOptions --- cel/options.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/cel/options.go b/cel/options.go index 89825c36..3c53e21a 100644 --- a/cel/options.go +++ b/cel/options.go @@ -472,6 +472,14 @@ func InterruptCheckFrequency(checkFrequency uint) ProgramOption { } } +// CostEstimatorOptions configure type-check time options for estimating expression cost. +func CostEstimatorOptions(costOpts ...checker.CostOption) EnvOption { + return func(e *Env) (*Env, error) { + e.costOptions = append(e.costOptions, costOpts...) + return e, nil + } +} + // CostTrackerOptions configures a set of options for cost-tracking. // // Note, CostTrackerOptions is a no-op unless CostTracking is also enabled. @@ -641,13 +649,6 @@ func ParserExpressionSizeLimit(limit int) EnvOption { } } -func CostEstimatorOptions(costOpts ...checker.CostOption) EnvOption { - return func(e *Env) (*Env, error) { - e.costOptions = append(e.costOptions, costOpts...) - return e, nil - } -} - func maybeInteropProvider(provider any) (types.Provider, error) { switch p := provider.(type) { case types.Provider: From 33c6213dee0d848a3b76d3e641bb3fccd3e9201c Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Fri, 27 Oct 2023 11:43:55 -0700 Subject: [PATCH 3/8] Fix BUILD rules and sets test expectations --- ext/BUILD.bazel | 2 +- ext/sets_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/BUILD.bazel b/ext/BUILD.bazel index c08e85cd..57013022 100644 --- a/ext/BUILD.bazel +++ b/ext/BUILD.bazel @@ -22,7 +22,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//cel:go_default_library", - "//checker/decls:go_default_library", + "//checker:go_default_library", "//common/ast:go_default_library", "//common/overloads:go_default_library", "//common/types:go_default_library", diff --git a/ext/sets_test.go b/ext/sets_test.go index 1d402b27..2c9c791a 100644 --- a/ext/sets_test.go +++ b/ext/sets_test.go @@ -50,7 +50,7 @@ func TestSets(t *testing.T) { in: map[string]any{"x": []int64{5, 4, 3, 2, 1}}, // min cost is input 'x' length 0, 10 for list creation, 2 for arg costs // max cost is effectively infinite due to missing size hint for 'x' - estimatedCost: checker.CostEstimate{Min: 12, Max: 9223372036854775820}, + estimatedCost: checker.CostEstimate{Min: 12, Max: 18446744073709551615}, // actual cost is 'x' length 5 * list literal length 5, 10 for list creation, 2 for arg cost actualCost: 37, }, From 26b823069aeb0b1a80af5849d921aebcd9d30f35 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Fri, 27 Oct 2023 11:53:10 -0700 Subject: [PATCH 4/8] Shift API to just specify support for estimation / tracking by overload --- checker/cost.go | 19 ++++++++----------- checker/cost_test.go | 2 +- ext/sets.go | 12 ++++++------ interpreter/runtimecost.go | 19 ++++++++----------- interpreter/runtimecost_test.go | 2 +- 5 files changed, 24 insertions(+), 30 deletions(-) diff --git a/checker/cost.go b/checker/cost.go index c8e21297..04bfa70b 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -15,7 +15,6 @@ package checker import ( - "fmt" "math" "github.com/google/cel-go/common" @@ -260,7 +259,7 @@ type coster struct { computedSizes map[int64]SizeEstimate checkedAST *ast.AST estimator CostEstimator - functionEstimators map[string]FunctionEstimator + overloadEstimators map[string]FunctionEstimator // presenceTestCost will either be a zero or one based on whether has() macros count against cost computations. presenceTestCost CostEstimate } @@ -305,14 +304,13 @@ func PresenceTestHasCost(hasCost bool) CostOption { // FunctionEstimator provides a CallEstimate given the target and arguments for a specific function, overload pair. type FunctionEstimator func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate -// FunctionCostEstimate binds a FunctionCoster to a specific function, overload pair. +// OverloadCostEstimate binds a FunctionCoster to a specific function overload ID. // -// When a FunctionCostEstimate is provided, it will override the cost calculation of the CostEstimator provided to +// When a OverloadCostEstimate is provided, it will override the cost calculation of the CostEstimator provided to // the Cost() call. -func FunctionCostEstimate(function, overloadID string, functionCoster FunctionEstimator) CostOption { +func OverloadCostEstimate(overloadID string, functionCoster FunctionEstimator) CostOption { return func(c *coster) error { - functionKey := fmt.Sprintf("%s|%s", function, overloadID) - c.functionEstimators[functionKey] = functionCoster + c.overloadEstimators[overloadID] = functionCoster return nil } } @@ -322,7 +320,7 @@ func Cost(checked *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEs c := &coster{ checkedAST: checked, estimator: estimator, - functionEstimators: map[string]FunctionEstimator{}, + overloadEstimators: map[string]FunctionEstimator{}, exprPath: map[int64][]string{}, iterRanges: map[string][]int64{}, computedSizes: map[int64]SizeEstimate{}, @@ -537,9 +535,8 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args } return sum } - if len(c.functionEstimators) != 0 { - functionKey := fmt.Sprintf("%s|%s", function, overloadID) - if estimator, found := c.functionEstimators[functionKey]; found { + if len(c.overloadEstimators) != 0 { + if estimator, found := c.overloadEstimators[overloadID]; found { if est := estimator(c.estimator, target, args); est != nil { callEst := *est return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize} diff --git a/checker/cost_test.go b/checker/cost_test.go index 55d852d8..c6c95643 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -430,7 +430,7 @@ func TestCost(t *testing.T) { }, hints: map[string]uint64{"str1": 10, "str2": 10}, options: []CostOption{ - FunctionCostEstimate(overloads.Contains, overloads.ContainsString, + OverloadCostEstimate(overloads.ContainsString, func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate { if target != nil && len(args) == 1 { strSize := estimateSize(estimator, *target).MultiplyByCostFactor(0.2) diff --git a/ext/sets.go b/ext/sets.go index ea18618d..61c3205c 100644 --- a/ext/sets.go +++ b/ext/sets.go @@ -100,10 +100,10 @@ func (setsLib) CompileOptions() []cel.EnvOption { cel.Overload("list_sets_intersects_list", []*cel.Type{listType, listType}, cel.BoolType, cel.BinaryBinding(setsIntersects))), cel.CostEstimatorOptions( - checker.FunctionCostEstimate("sets.contains", "list_sets_contains_list", estimateSetsCost(1)), - checker.FunctionCostEstimate("sets.intersects", "list_sets_intersects_list", estimateSetsCost(1)), + checker.OverloadCostEstimate("list_sets_contains_list", estimateSetsCost(1)), + checker.OverloadCostEstimate("list_sets_intersects_list", estimateSetsCost(1)), // equivalence requires potentially two m*n comparisons to ensure each list is contained by the other - checker.FunctionCostEstimate("sets.equivalent", "list_sets_equivalent_list", estimateSetsCost(2)), + checker.OverloadCostEstimate("list_sets_equivalent_list", estimateSetsCost(2)), ), } } @@ -112,9 +112,9 @@ func (setsLib) CompileOptions() []cel.EnvOption { func (setsLib) ProgramOptions() []cel.ProgramOption { return []cel.ProgramOption{ cel.CostTrackerOptions( - interpreter.FunctionCostTracker("sets.contains", "list_sets_contains_list", trackSetsCost(1)), - interpreter.FunctionCostTracker("sets.intersects", "list_sets_intersects_list", trackSetsCost(1)), - interpreter.FunctionCostTracker("sets.equivalent", "list_sets_equivalent_list", trackSetsCost(2)), + interpreter.OverloadCostTracker("list_sets_contains_list", trackSetsCost(1)), + interpreter.OverloadCostTracker("list_sets_intersects_list", trackSetsCost(1)), + interpreter.OverloadCostTracker("list_sets_equivalent_list", trackSetsCost(2)), ), } } diff --git a/interpreter/runtimecost.go b/interpreter/runtimecost.go index a9f9ede4..b9b307c1 100644 --- a/interpreter/runtimecost.go +++ b/interpreter/runtimecost.go @@ -15,7 +15,6 @@ package interpreter import ( - "fmt" "math" "github.com/google/cel-go/common" @@ -134,7 +133,7 @@ func PresenceTestHasCost(hasCost bool) CostTrackerOption { func NewCostTracker(estimator ActualCostEstimator, opts ...CostTrackerOption) (*CostTracker, error) { tracker := &CostTracker{ Estimator: estimator, - functionTrackers: map[string]FunctionTracker{}, + overloadTrackers: map[string]FunctionTracker{}, presenceTestHasCost: true, } for _, opt := range opts { @@ -146,14 +145,13 @@ func NewCostTracker(estimator ActualCostEstimator, opts ...CostTrackerOption) (* return tracker, nil } -// FunctionCostTracker binds a function overload to a runtime FunctionTracker implementation. +// OverloadCostTracker binds an overload ID to a runtime FunctionTracker implementation. // -// FunctionCostTracker instances augment or override ActualCostEstimator decisions, allowing for versioned and/or +// OverloadCostTracker instances augment or override ActualCostEstimator decisions, allowing for versioned and/or // optional cost tracking changes. -func FunctionCostTracker(function, overloadID string, fnTracker FunctionTracker) CostTrackerOption { +func OverloadCostTracker(overloadID string, fnTracker FunctionTracker) CostTrackerOption { return func(tracker *CostTracker) error { - functionKey := fmt.Sprintf("%s|%s", function, overloadID) - tracker.functionTrackers[functionKey] = fnTracker + tracker.overloadTrackers[overloadID] = fnTracker return nil } } @@ -164,7 +162,7 @@ type FunctionTracker func(args []ref.Val, result ref.Val) *uint64 // CostTracker represents the information needed for tracking runtime cost. type CostTracker struct { Estimator ActualCostEstimator - functionTrackers map[string]FunctionTracker + overloadTrackers map[string]FunctionTracker Limit *uint64 presenceTestHasCost bool @@ -179,9 +177,8 @@ func (c *CostTracker) ActualCost() uint64 { func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result ref.Val) uint64 { var cost uint64 - if len(c.functionTrackers) != 0 { - functionKey := fmt.Sprintf("%s|%s", call.Function(), call.OverloadID()) - if tracker, found := c.functionTrackers[functionKey]; found { + if len(c.overloadTrackers) != 0 { + if tracker, found := c.overloadTrackers[call.OverloadID()]; found { callCost := tracker(args, result) if callCost != nil { cost += *callCost diff --git a/interpreter/runtimecost_test.go b/interpreter/runtimecost_test.go index d0c8c1de..1c6ac124 100644 --- a/interpreter/runtimecost_test.go +++ b/interpreter/runtimecost_test.go @@ -737,7 +737,7 @@ func TestRuntimeCost(t *testing.T) { decls.NewVariable("str2", types.StringType), }, options: []CostTrackerOption{ - FunctionCostTracker(overloads.Contains, overloads.ContainsString, + OverloadCostTracker(overloads.ContainsString, func(args []ref.Val, result ref.Val) *uint64 { strCost := uint64(math.Ceil(float64(actualSize(args[0])) * 0.2)) substrCost := uint64(math.Ceil(float64(actualSize(args[1])) * 0.2)) From e5d4cea73c7824a7ef7dd4932cf111006c72fdf0 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Fri, 27 Oct 2023 12:11:31 -0700 Subject: [PATCH 5/8] Test fix --- ext/sets_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/sets_test.go b/ext/sets_test.go index 2c9c791a..ddaf96d8 100644 --- a/ext/sets_test.go +++ b/ext/sets_test.go @@ -15,6 +15,7 @@ package ext import ( + "math" "reflect" "strings" "testing" @@ -50,7 +51,7 @@ func TestSets(t *testing.T) { in: map[string]any{"x": []int64{5, 4, 3, 2, 1}}, // min cost is input 'x' length 0, 10 for list creation, 2 for arg costs // max cost is effectively infinite due to missing size hint for 'x' - estimatedCost: checker.CostEstimate{Min: 12, Max: 18446744073709551615}, + estimatedCost: checker.CostEstimate{Min: 12, Max: math.MaxUint64}, // actual cost is 'x' length 5 * list literal length 5, 10 for list creation, 2 for arg cost actualCost: 37, }, From b327c0b6eca92df0fd5923faa0c16fc26c5f37fe Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Fri, 27 Oct 2023 13:46:43 -0700 Subject: [PATCH 6/8] Fix ContextEval support for estimator, tracker overloads --- cel/cel_test.go | 44 ++++++++++++++++++++++++++++++-------------- cel/program.go | 5 ++++- 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/cel/cel_test.go b/cel/cel_test.go index 4addeb76..97b81ef4 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -92,7 +92,12 @@ func Test_ExampleWithBuiltins(t *testing.T) { } func TestEval(t *testing.T) { - env, err := NewEnv(Variable("input", ListType(IntType))) + env, err := NewEnv( + Variable("input", ListType(IntType)), + CostEstimatorOptions( + checker.OverloadCostEstimate(overloads.TimestampToYear, estimateTimestampToYear), + ), + ) if err != nil { t.Fatalf("NewEnv() failed: %v", err) } @@ -115,6 +120,9 @@ func TestEval(t *testing.T) { ctx := context.Background() prgOpts := []ProgramOption{ CostTracking(testRuntimeCostEstimator{}), + CostTrackerOptions( + interpreter.OverloadCostTracker(overloads.TimestampToYear, trackTimestampToYear), + ), EvalOptions(OptOptimize, OptTrackCost), InterruptCheckFrequency(100), } @@ -1512,7 +1520,13 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) { if tc.hints == nil { tc.hints = map[string]uint64{} } - env := testEnv(t, tc.decls...) + envOpts := []EnvOption{ + CostEstimatorOptions( + checker.OverloadCostEstimate(overloads.TimestampToYear, estimateTimestampToYear), + ), + } + envOpts = append(envOpts, tc.decls...) + env := testEnv(t, envOpts...) ast, iss := env.Compile(tc.expr) if iss.Err() != nil { t.Fatalf("env.Compile(%v) failed: %v", tc.expr, iss.Err()) @@ -1531,7 +1545,12 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) { t.Fatalf(`Env.Check(ast *Ast) failed to check expression: %v`, iss.Err()) } // Evaluate expression. - program, err := env.Program(checkedAst, CostTracking(testRuntimeCostEstimator{})) + program, err := env.Program(checkedAst, + CostTracking(testRuntimeCostEstimator{}), + CostTrackerOptions( + interpreter.OverloadCostTracker(overloads.TimestampToYear, trackTimestampToYear), + ), + ) if err != nil { t.Fatalf(`Env.Program(ast *Ast, opts ...ProgramOption) failed to construct program: %v`, err) } @@ -2779,16 +2798,15 @@ func (tc testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeE } func (tc testCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { - switch overloadID { - case overloads.TimestampToYear: - return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}} - } return nil } -type testRuntimeCostEstimator struct { +func estimateTimestampToYear(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}} } +type testRuntimeCostEstimator struct{} + var timeToYearCost uint64 = 7 func (e testRuntimeCostEstimator) CallCost(function, overloadID string, args []ref.Val, result ref.Val) *uint64 { @@ -2804,13 +2822,11 @@ func (e testRuntimeCostEstimator) CallCost(function, overloadID string, args []r argsSize[i] = 1 } } + return nil +} - switch overloadID { - case overloads.TimestampToYear: - return &timeToYearCost - default: - return nil - } +func trackTimestampToYear(args []ref.Val, result ref.Val) *uint64 { + return &timeToYearCost } func testEnv(t testing.TB, opts ...EnvOption) *Env { diff --git a/cel/program.go b/cel/program.go index 50fb120f..ece9fbda 100644 --- a/cel/program.go +++ b/cel/program.go @@ -384,7 +384,10 @@ func (gen *progGen) ContextEval(ctx context.Context, input any) (ref.Val, *EvalD // new EvalState instance for each call to ensure that unique evaluations yield unique stateful // results. state := interpreter.NewEvalState() - costTracker := &interpreter.CostTracker{} + costTracker, err := interpreter.NewCostTracker(nil) + if err != nil { + return nil, nil, err + } det := &EvalDetails{state: state, costTracker: costTracker} // Generate a new instance of the interpretable using the factory configured during the call to From e0530623f62f2aed2be5c32579631c0d9d4b48a0 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Fri, 27 Oct 2023 13:59:49 -0700 Subject: [PATCH 7/8] Add a guard around the set estimates --- ext/sets.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ext/sets.go b/ext/sets.go index 61c3205c..833c15f6 100644 --- a/ext/sets.go +++ b/ext/sets.go @@ -155,10 +155,13 @@ func setsEquivalent(listA, listB ref.Val) ref.Val { func estimateSetsCost(costFactor float64) checker.FunctionEstimator { return func(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { - arg0Size := estimateSize(estimator, args[0]) - arg1Size := estimateSize(estimator, args[1]) - costEstimate := arg0Size.Multiply(arg1Size).MultiplyByCostFactor(costFactor).Add(callCostEstimate) - return &checker.CallEstimate{CostEstimate: costEstimate} + if len(args) == 2 { + arg0Size := estimateSize(estimator, args[0]) + arg1Size := estimateSize(estimator, args[1]) + costEstimate := arg0Size.Multiply(arg1Size).MultiplyByCostFactor(costFactor).Add(callCostEstimate) + return &checker.CallEstimate{CostEstimate: costEstimate} + } + return nil } } From 7d5fa424153b59c69bf1188590810381c7dadc9e Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Fri, 27 Oct 2023 14:05:59 -0700 Subject: [PATCH 8/8] Adjust cost factor clamping --- checker/cost.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/checker/cost.go b/checker/cost.go index 04bfa70b..3470d0a3 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -226,7 +226,7 @@ func addUint64NoOverflow(x, y uint64) uint64 { // multiplyUint64NoOverflow multiplies non-negative ints. If the result is exceeds math.MaxUint64, math.MaxUint64 // is returned. func multiplyUint64NoOverflow(x, y uint64) uint64 { - if x > 0 && y > 0 && x > math.MaxUint64/y { + if y != 0 && x > math.MaxUint64/y { return math.MaxUint64 } return x * y @@ -238,7 +238,11 @@ func multiplyByCostFactor(x uint64, y float64) uint64 { if xFloat > 0 && y > 0 && xFloat > math.MaxUint64/y { return math.MaxUint64 } - return uint64(math.Ceil(xFloat * y)) + ceil := math.Ceil(xFloat * y) + if ceil >= doubleTwoTo64 { + return math.MaxUint64 + } + return uint64(ceil) } var ( @@ -692,3 +696,7 @@ func isScalar(t *types.Type) bool { } return false } + +var ( + doubleTwoTo64 = math.Ldexp(1.0, 64) +)