Skip to content

Commit

Permalink
Sets cost estimation and tracking options
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones committed Oct 26, 2023
1 parent 3a8e854 commit 0176a41
Show file tree
Hide file tree
Showing 7 changed files with 422 additions and 78 deletions.
10 changes: 9 additions & 1 deletion cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ type Env struct {
appliedFeatures map[int]bool
libraries map[string]bool
validators []ASTValidator
costOptions []checker.CostOption

// Internal parser representation
prsr *parser.Parser
Expand Down Expand Up @@ -191,6 +192,7 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
libraries: map[string]bool{},
validators: []ASTValidator{},
progOpts: []ProgramOption{},
costOptions: []checker.CostOption{},
}).configure(opts)
}

Expand Down Expand Up @@ -365,6 +367,8 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
}
validatorsCopy := make([]ASTValidator, len(e.validators))
copy(validatorsCopy, e.validators)
costOptsCopy := make([]checker.CostOption, len(e.costOptions))
copy(costOptsCopy, e.costOptions)

ext := &Env{
Container: e.Container,
Expand All @@ -380,6 +384,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
provider: provider,
chkOpts: chkOptsCopy,
prsrOpts: prsrOptsCopy,
costOptions: costOptsCopy,
}
return ext.configure(opts)
}
Expand Down Expand Up @@ -556,7 +561,10 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
// EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and
// extension functions provided by estimator.
//
// The opts supplied by the caller are combined with the cost options configured on the Env: the caller's opts are
// placed first, the Env's costOptions follow, and the combined list is handed to checker.Cost.
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) {
return checker.Cost(ast.impl, estimator, opts...)
// NOTE(review): the capacity below only reserves room for e.costOptions; reserving
// len(opts)+len(e.costOptions) would avoid a regrow when opts is non-empty.
extendedOpts := make([]checker.CostOption, 0, len(e.costOptions))
extendedOpts = append(extendedOpts, opts...)
extendedOpts = append(extendedOpts, e.costOptions...)
return checker.Cost(ast.impl, estimator, extendedOpts...)
}

// configure applies a series of EnvOptions to the current environment.
Expand Down
15 changes: 15 additions & 0 deletions cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/dynamicpb"

"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/types"
Expand Down Expand Up @@ -471,6 +472,13 @@ func InterruptCheckFrequency(checkFrequency uint) ProgramOption {
}
}

// CostTrackerOptions returns a ProgramOption which appends the given interpreter.CostTrackerOption values to the
// program's list of cost options.
//
// The options accumulate: each use of CostTrackerOptions adds to the options already present on the program
// rather than replacing them.
func CostTrackerOptions(costOpts ...interpreter.CostTrackerOption) ProgramOption {
	appendCostOpts := func(target *prog) (*prog, error) {
		target.costOptions = append(target.costOptions, costOpts...)
		return target, nil
	}
	return appendCostOpts
}

// CostTracking enables cost tracking and registers an ActualCostEstimator that can optionally provide a runtime cost estimate for any function calls.
func CostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOption {
return func(p *prog) (*prog, error) {
Expand Down Expand Up @@ -630,6 +638,13 @@ func ParserExpressionSizeLimit(limit int) EnvOption {
}
}

// CostEstimatorOptions returns an EnvOption which appends the given checker.CostOption values to the
// environment's list of cost options.
//
// The options accumulate: each use of CostEstimatorOptions adds to the options already present on the
// environment rather than replacing them.
func CostEstimatorOptions(costOpts ...checker.CostOption) EnvOption {
	appendCostOpts := func(env *Env) (*Env, error) {
		env.costOptions = append(env.costOptions, costOpts...)
		return env, nil
	}
	return appendCostOpts
}

func maybeInteropProvider(provider any) (types.Provider, error) {
switch p := provider.(type) {
case types.Provider:
Expand Down
30 changes: 24 additions & 6 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (ed *EvalDetails) State() interpreter.EvalState {
// ActualCost returns the tracked cost through the course of execution when `CostTracking` is enabled.
// Otherwise, returns nil if the cost was not enabled.
func (ed *EvalDetails) ActualCost() *uint64 {
if ed.costTracker == nil {
if ed == nil || ed.costTracker == nil {
return nil
}
cost := ed.costTracker.ActualCost()
Expand All @@ -129,10 +129,14 @@ type prog struct {
// Interpretable configured from an Ast and aggregate decorator set based on program options.
interpretable interpreter.Interpretable
callCostEstimator interpreter.ActualCostEstimator
costOptions []interpreter.CostTrackerOption
costLimit *uint64
}

func (p *prog) clone() *prog {
costOptsCopy := make([]interpreter.CostTrackerOption, len(p.costOptions))
copy(costOptsCopy, p.costOptions)

return &prog{
Env: p.Env,
evalOpts: p.evalOpts,
Expand All @@ -154,9 +158,10 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
// Ensure the default attribute factory is set after the adapter and provider are
// configured.
p := &prog{
Env: e,
decorators: []interpreter.InterpretableDecorator{},
dispatcher: disp,
Env: e,
decorators: []interpreter.InterpretableDecorator{},
dispatcher: disp,
costOptions: []interpreter.CostTrackerOption{},
}

// Configure the program via the ProgramOption values.
Expand Down Expand Up @@ -213,6 +218,12 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
factory := func(state interpreter.EvalState, costTracker *interpreter.CostTracker) (Program, error) {
costTracker.Estimator = p.callCostEstimator
costTracker.Limit = p.costLimit
for _, costOpt := range p.costOptions {
err := costOpt(costTracker)
if err != nil {
return nil, err
}
}
// Limit capacity to guarantee a reallocation when calling 'append(decs, ...)' below. This
// prevents the underlying memory from being shared between factory function calls causing
// undesired mutations.
Expand Down Expand Up @@ -325,7 +336,11 @@ type progGen struct {
// the test is successful.
func newProgGen(factory progFactory) (Program, error) {
// Test the factory to make sure that configuration errors are spotted at config
_, err := factory(interpreter.NewEvalState(), &interpreter.CostTracker{})
tracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, err
}
_, err = factory(interpreter.NewEvalState(), tracker)
if err != nil {
return nil, err
}
Expand All @@ -338,7 +353,10 @@ func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
costTracker := &interpreter.CostTracker{}
costTracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, nil, err
}
det := &EvalDetails{state: state, costTracker: costTracker}

// Generate a new instance of the interpretable using the factory configured during the call to
Expand Down
46 changes: 36 additions & 10 deletions checker/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package checker

import (
"fmt"
"math"

"github.com/google/cel-go/common"
Expand Down Expand Up @@ -256,9 +257,10 @@ type coster struct {
// iterRanges tracks the iterRange of each iterVar.
iterRanges iterRangeScopes
// computedSizes tracks the computed sizes of call results.
computedSizes map[int64]SizeEstimate
checkedAST *ast.AST
estimator CostEstimator
computedSizes map[int64]SizeEstimate
checkedAST *ast.AST
estimator CostEstimator
functionEstimators map[string]FunctionEstimator
// presenceTestCost will either be a zero or one based on whether has() macros count against cost computations.
presenceTestCost CostEstimate
}
Expand Down Expand Up @@ -299,15 +301,31 @@ func PresenceTestHasCost(hasCost bool) CostOption {
}
}

// FunctionEstimator provides a CallEstimate given the target and arguments for a specific function, overload pair.
//
// The estimator argument is the CostEstimator that was supplied to the Cost() call. Returning nil indicates that
// no estimate is available for the call.
type FunctionEstimator func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate

// FunctionCostEstimate binds a FunctionEstimator to a specific function, overload pair.
//
// When a FunctionCostEstimate is provided, it takes precedence over the default cost calculation of the
// CostEstimator provided to the Cost() call for that function, overload pair.
func FunctionCostEstimate(function, overloadID string, functionCoster FunctionEstimator) CostOption {
	// Estimators are indexed by "<function>|<overloadID>".
	key := fmt.Sprintf("%s|%s", function, overloadID)
	return func(c *coster) error {
		c.functionEstimators[key] = functionCoster
		return nil
	}
}

// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checked *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
c := &coster{
checkedAST: checked,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
checkedAST: checked,
estimator: estimator,
functionEstimators: map[string]FunctionEstimator{},
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
}
for _, opt := range opts {
err := opt(c)
Expand Down Expand Up @@ -518,7 +536,15 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
}
return sum
}

if len(c.functionEstimators) != 0 {
functionKey := fmt.Sprintf("%s|%s", function, overloadID)
if estimator, found := c.functionEstimators[functionKey]; found {
if est := estimator(c.estimator, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
}
}
}
if est := c.estimator.EstimateCallCost(function, overloadID, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
Expand Down
58 changes: 57 additions & 1 deletion ext/sets.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
package ext

import (
"math"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
)

// Sets returns a cel.EnvOption to configure namespaced set relationship
Expand Down Expand Up @@ -95,12 +99,24 @@ func (setsLib) CompileOptions() []cel.EnvOption {
cel.Function("sets.intersects",
cel.Overload("list_sets_intersects_list", []*cel.Type{listType, listType}, cel.BoolType,
cel.BinaryBinding(setsIntersects))),
cel.CostEstimatorOptions(
checker.FunctionCostEstimate("sets.contains", "list_sets_contains_list", estimateSetsCost(1)),
checker.FunctionCostEstimate("sets.intersects", "list_sets_intersects_list", estimateSetsCost(1)),
// equivalence requires potentially two m*n comparisons to ensure each list is contained by the other
checker.FunctionCostEstimate("sets.equivalent", "list_sets_equivalent_list", estimateSetsCost(2)),
),
}
}

// ProgramOptions implements the Library interface method.
//
// The returned options pass one interpreter.FunctionCostTracker per sets overload to cel.CostTrackerOptions.
func (setsLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
return []cel.ProgramOption{
cel.CostTrackerOptions(
// contains and intersects are tracked with a cost factor of 1; equivalent is tracked with a factor of 2.
interpreter.FunctionCostTracker("sets.contains", "list_sets_contains_list", trackSetsCost(1)),
interpreter.FunctionCostTracker("sets.intersects", "list_sets_intersects_list", trackSetsCost(1)),
interpreter.FunctionCostTracker("sets.equivalent", "list_sets_equivalent_list", trackSetsCost(2)),
),
}
}

func setsIntersects(listA, listB ref.Val) ref.Val {
Expand Down Expand Up @@ -136,3 +152,43 @@ func setsEquivalent(listA, listB ref.Val) ref.Val {
}
return setsContains(listB, listA)
}

// estimateSetsCost builds a checker.FunctionEstimator for a two-argument sets function.
//
// The estimate is the product of the size estimates of the first two arguments, scaled by costFactor, plus the
// fixed callCostEstimate. The target node is not consulted.
func estimateSetsCost(costFactor float64) checker.FunctionEstimator {
	return func(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
		lhs, rhs := estimateSize(estimator, args[0]), estimateSize(estimator, args[1])
		pairwise := lhs.Multiply(rhs)
		total := pairwise.MultiplyByCostFactor(costFactor).Add(callCostEstimate)
		return &checker.CallEstimate{CostEstimate: total}
	}
}

// estimateSize returns a size estimate for node.
//
// The node's own computed size is preferred; when that is absent the estimator is asked; when neither yields a
// value the result spans the full range from 0 to math.MaxUint64.
func estimateSize(estimator checker.CostEstimator, node checker.AstNode) checker.SizeEstimate {
	computed := node.ComputedSize()
	if computed != nil {
		return *computed
	}
	estimated := estimator.EstimateSize(node)
	if estimated != nil {
		return *estimated
	}
	unbounded := checker.SizeEstimate{Min: 0, Max: math.MaxUint64}
	return unbounded
}

// trackSetsCost builds an interpreter.FunctionTracker for a two-argument sets function.
//
// The tracked cost is the product of the actual sizes of the first two arguments, scaled by costFactor and
// truncated to an unsigned integer, plus the fixed callCost. The call result is not consulted.
func trackSetsCost(costFactor float64) interpreter.FunctionTracker {
	return func(args []ref.Val, _ ref.Val) *uint64 {
		lhs, rhs := actualSize(args[0]), actualSize(args[1])
		scaled := float64(lhs*rhs) * costFactor
		total := callCost + uint64(scaled)
		return &total
	}
}

// actualSize reports the runtime size of value.
//
// Values implementing traits.Sizer report their Size(), which is asserted to be a types.Int; every other value
// counts as size 1.
func actualSize(value ref.Val) uint64 {
	sizer, isSizer := value.(traits.Sizer)
	if !isSizer {
		return 1
	}
	return uint64(sizer.Size().(types.Int))
}

var (
// callCostEstimate is the fixed cost added to every estimate produced by estimateSetsCost.
callCostEstimate = checker.CostEstimate{Min: 1, Max: 1}
// callCost is the fixed cost added to every runtime cost produced by trackSetsCost.
callCost = uint64(1)
)
Loading

0 comments on commit 0176a41

Please sign in to comment.