Skip to content

Commit

Permalink
Backport google#850: Sets cost estimation and tracking options
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones committed Oct 27, 2023
1 parent 52ee283 commit 2e4306b
Show file tree
Hide file tree
Showing 11 changed files with 603 additions and 124 deletions.
54 changes: 35 additions & 19 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,12 @@ func Test_ExampleWithBuiltins(t *testing.T) {
}

func TestEval(t *testing.T) {
env, err := NewEnv(Variable("input", ListType(IntType)))
env, err := NewEnv(
Variable("input", ListType(IntType)),
CostEstimatorOptions(
checker.OverloadCostEstimate(overloads.TimestampToYear, estimateTimestampToYear),
),
)
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
Expand All @@ -114,6 +119,9 @@ func TestEval(t *testing.T) {
ctx := context.Background()
prgOpts := []ProgramOption{
CostTracking(testRuntimeCostEstimator{}),
CostTrackerOptions(
interpreter.OverloadCostTracker(overloads.TimestampToYear, trackTimestampToYear),
),
EvalOptions(OptOptimize, OptTrackCost),
InterruptCheckFrequency(100),
}
Expand Down Expand Up @@ -1338,7 +1346,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
name string
expr string
decls []EnvOption
hints map[string]int64
hints map[string]uint64
want checker.CostEstimate
in any
}{
Expand All @@ -1362,7 +1370,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
Variable("str1", StringType),
Variable("str2", StringType),
},
hints: map[string]int64{"str1": 10, "str2": 10},
hints: map[string]uint64{"str1": 10, "str2": 10},
want: checker.CostEstimate{Min: 2, Max: 6},
in: map[string]any{"str1": "val1111111", "str2": "val2222222"},
},
Expand All @@ -1373,9 +1381,15 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if tc.hints == nil {
tc.hints = map[string]int64{}
tc.hints = map[string]uint64{}
}
env := testEnv(t, tc.decls...)
envOpts := []EnvOption{
CostEstimatorOptions(
checker.OverloadCostEstimate(overloads.TimestampToYear, estimateTimestampToYear),
),
}
envOpts = append(envOpts, tc.decls...)
env := testEnv(t, envOpts...)
ast, iss := env.Compile(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Compile(%v) failed: %v", tc.expr, iss.Err())
Expand All @@ -1394,7 +1408,12 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
t.Fatalf(`Env.Check(ast *Ast) failed to check expression: %v`, iss.Err())
}
// Evaluate expression.
program, err := env.Program(checkedAst, CostTracking(testRuntimeCostEstimator{}))
program, err := env.Program(checkedAst,
CostTracking(testRuntimeCostEstimator{}),
CostTrackerOptions(
interpreter.OverloadCostTracker(overloads.TimestampToYear, trackTimestampToYear),
),
)
if err != nil {
t.Fatalf(`Env.Program(ast *Ast, opts ...ProgramOption) failed to construct program: %v`, err)
}
Expand Down Expand Up @@ -2631,27 +2650,26 @@ func BenchmarkDynamicDispatch(b *testing.B) {

// TODO: ideally testCostEstimator and testRuntimeCostEstimator would be shared in a test fixtures package
type testCostEstimator struct {
hints map[string]int64
hints map[string]uint64
}

func (tc testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok {
return &checker.SizeEstimate{Min: 0, Max: uint64(l)}
return &checker.SizeEstimate{Min: 0, Max: l}
}
return nil
}

func (tc testCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
switch overloadID {
case overloads.TimestampToYear:
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}}
}
return nil
}

type testRuntimeCostEstimator struct {
func estimateTimestampToYear(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}}
}

type testRuntimeCostEstimator struct{}

var timeToYearCost uint64 = 7

func (e testRuntimeCostEstimator) CallCost(function, overloadID string, args []ref.Val, result ref.Val) *uint64 {
Expand All @@ -2667,13 +2685,11 @@ func (e testRuntimeCostEstimator) CallCost(function, overloadID string, args []r
argsSize[i] = 1
}
}
return nil
}

switch overloadID {
case overloads.TimestampToYear:
return &timeToYearCost
default:
return nil
}
func trackTimestampToYear(args []ref.Val, result ref.Val) *uint64 {
return &timeToYearCost
}

func testEnv(t testing.TB, opts ...EnvOption) *Env {
Expand Down
10 changes: 9 additions & 1 deletion cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ type Env struct {
appliedFeatures map[int]bool
libraries map[string]bool
validators []ASTValidator
costOptions []checker.CostOption

// Internal parser representation
prsr *parser.Parser
Expand Down Expand Up @@ -181,6 +182,7 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
libraries: map[string]bool{},
validators: []ASTValidator{},
progOpts: []ProgramOption{},
costOptions: []checker.CostOption{},
}).configure(opts)
}

Expand Down Expand Up @@ -356,6 +358,8 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
}
validatorsCopy := make([]ASTValidator, len(e.validators))
copy(validatorsCopy, e.validators)
costOptsCopy := make([]checker.CostOption, len(e.costOptions))
copy(costOptsCopy, e.costOptions)

ext := &Env{
Container: e.Container,
Expand All @@ -371,6 +375,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
provider: provider,
chkOpts: chkOptsCopy,
prsrOpts: prsrOptsCopy,
costOptions: costOptsCopy,
}
return ext.configure(opts)
}
Expand Down Expand Up @@ -557,7 +562,10 @@ func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...ch
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
return checker.Cost(checked, estimator, opts...)
extendedOpts := make([]checker.CostOption, 0, len(e.costOptions))
extendedOpts = append(extendedOpts, opts...)
extendedOpts = append(extendedOpts, e.costOptions...)
return checker.Cost(checked, estimator, extendedOpts...)
}

// configure applies a series of EnvOptions to the current environment.
Expand Down
19 changes: 19 additions & 0 deletions cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/dynamicpb"

"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/types"
Expand Down Expand Up @@ -469,6 +470,24 @@ func InterruptCheckFrequency(checkFrequency uint) ProgramOption {
}
}

// CostEstimatorOptions configure type-check time options for estimating expression cost.
func CostEstimatorOptions(costOpts ...checker.CostOption) EnvOption {
return func(e *Env) (*Env, error) {
e.costOptions = append(e.costOptions, costOpts...)
return e, nil
}
}

// CostTrackerOptions configures a set of options for cost-tracking.
//
// Note, CostTrackerOptions is a no-op unless CostTracking is also enabled.
func CostTrackerOptions(costOpts ...interpreter.CostTrackerOption) ProgramOption {
return func(p *prog) (*prog, error) {
p.costOptions = append(p.costOptions, costOpts...)
return p, nil
}
}

// CostTracking enables cost tracking and registers a ActualCostEstimator that can optionally provide a runtime cost estimate for any function calls.
func CostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOption {
return func(p *prog) (*prog, error) {
Expand Down
35 changes: 28 additions & 7 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (ed *EvalDetails) State() interpreter.EvalState {
// ActualCost returns the tracked cost through the course of execution when `CostTracking` is enabled.
// Otherwise, returns nil if the cost was not enabled.
func (ed *EvalDetails) ActualCost() *uint64 {
if ed.costTracker == nil {
if ed == nil || ed.costTracker == nil {
return nil
}
cost := ed.costTracker.ActualCost()
Expand All @@ -130,10 +130,14 @@ type prog struct {
// Interpretable configured from an Ast and aggregate decorator set based on program options.
interpretable interpreter.Interpretable
callCostEstimator interpreter.ActualCostEstimator
costOptions []interpreter.CostTrackerOption
costLimit *uint64
}

func (p *prog) clone() *prog {
costOptsCopy := make([]interpreter.CostTrackerOption, len(p.costOptions))
copy(costOptsCopy, p.costOptions)

return &prog{
Env: p.Env,
evalOpts: p.evalOpts,
Expand All @@ -155,9 +159,10 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
// Ensure the default attribute factory is set after the adapter and provider are
// configured.
p := &prog{
Env: e,
decorators: []interpreter.InterpretableDecorator{},
dispatcher: disp,
Env: e,
decorators: []interpreter.InterpretableDecorator{},
dispatcher: disp,
costOptions: []interpreter.CostTrackerOption{},
}

// Configure the program via the ProgramOption values.
Expand Down Expand Up @@ -242,6 +247,12 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
factory := func(state interpreter.EvalState, costTracker *interpreter.CostTracker) (Program, error) {
costTracker.Estimator = p.callCostEstimator
costTracker.Limit = p.costLimit
for _, costOpt := range p.costOptions {
err := costOpt(costTracker)
if err != nil {
return nil, err
}
}
// Limit capacity to guarantee a reallocation when calling 'append(decs, ...)' below. This
// prevents the underlying memory from being shared between factory function calls causing
// undesired mutations.
Expand Down Expand Up @@ -371,7 +382,11 @@ type progGen struct {
// the test is successful.
func newProgGen(factory progFactory) (Program, error) {
// Test the factory to make sure that configuration errors are spotted at config
_, err := factory(interpreter.NewEvalState(), &interpreter.CostTracker{})
tracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, err
}
_, err = factory(interpreter.NewEvalState(), tracker)
if err != nil {
return nil, err
}
Expand All @@ -384,7 +399,10 @@ func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
costTracker := &interpreter.CostTracker{}
costTracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, nil, err
}
det := &EvalDetails{state: state, costTracker: costTracker}

// Generate a new instance of the interpretable using the factory configured during the call to
Expand Down Expand Up @@ -412,7 +430,10 @@ func (gen *progGen) ContextEval(ctx context.Context, input any) (ref.Val, *EvalD
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
costTracker := &interpreter.CostTracker{}
costTracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, nil, err
}
det := &EvalDetails{state: state, costTracker: costTracker}

// Generate a new instance of the interpretable using the factory configured during the call to
Expand Down
56 changes: 44 additions & 12 deletions checker/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func addUint64NoOverflow(x, y uint64) uint64 {
// multiplyUint64NoOverflow multiplies non-negative ints. If the result is exceeds math.MaxUint64, math.MaxUint64
// is returned.
func multiplyUint64NoOverflow(x, y uint64) uint64 {
if x > 0 && y > 0 && x > math.MaxUint64/y {
if y != 0 && x > math.MaxUint64/y {
return math.MaxUint64
}
return x * y
Expand All @@ -242,7 +242,11 @@ func multiplyByCostFactor(x uint64, y float64) uint64 {
if xFloat > 0 && y > 0 && xFloat > math.MaxUint64/y {
return math.MaxUint64
}
return uint64(math.Ceil(xFloat * y))
ceil := math.Ceil(xFloat * y)
if ceil >= doubleTwoTo64 {
return math.MaxUint64
}
return uint64(ceil)
}

var (
Expand All @@ -260,9 +264,10 @@ type coster struct {
// iterRanges tracks the iterRange of each iterVar.
iterRanges iterRangeScopes
// computedSizes tracks the computed sizes of call results.
computedSizes map[int64]SizeEstimate
checkedAST *ast.CheckedAST
estimator CostEstimator
computedSizes map[int64]SizeEstimate
checkedAST *ast.CheckedAST
estimator CostEstimator
overloadEstimators map[string]FunctionEstimator
// presenceTestCost will either be a zero or one based on whether has() macros count against cost computations.
presenceTestCost CostEstimate
}
Expand Down Expand Up @@ -291,6 +296,7 @@ func (vs iterRangeScopes) peek(varName string) (int64, bool) {
type CostOption func(*coster) error

// PresenceTestHasCost determines whether presence testing has a cost of one or zero.
//
// Defaults to presence test has a cost of one.
func PresenceTestHasCost(hasCost bool) CostOption {
return func(c *coster) error {
Expand All @@ -303,15 +309,30 @@ func PresenceTestHasCost(hasCost bool) CostOption {
}
}

// FunctionEstimator provides a CallEstimate given the target and arguments for a specific function, overload pair.
type FunctionEstimator func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate

// OverloadCostEstimate binds a FunctionCoster to a specific function overload ID.
//
// When a OverloadCostEstimate is provided, it will override the cost calculation of the CostEstimator provided to
// the Cost() call.
func OverloadCostEstimate(overloadID string, functionCoster FunctionEstimator) CostOption {
return func(c *coster) error {
c.overloadEstimators[overloadID] = functionCoster
return nil
}
}

// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checker *ast.CheckedAST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
c := &coster{
checkedAST: checker,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
checkedAST: checker,
estimator: estimator,
overloadEstimators: map[string]FunctionEstimator{},
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
}
for _, opt := range opts {
err := opt(c)
Expand Down Expand Up @@ -532,7 +553,14 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
}
return sum
}

if len(c.overloadEstimators) != 0 {
if estimator, found := c.overloadEstimators[overloadID]; found {
if est := estimator(c.estimator, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
}
}
}
if est := c.estimator.EstimateCallCost(function, overloadID, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
Expand Down Expand Up @@ -682,3 +710,7 @@ func isScalar(t *types.Type) bool {
}
return false
}

var (
doubleTwoTo64 = math.Ldexp(1.0, 64)
)
Loading

0 comments on commit 2e4306b

Please sign in to comment.