Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional nil-safety checks with corresponding test updates #1073

Merged
merged 1 commit into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1195,6 +1195,14 @@ func TestResidualAstMacros(t *testing.T) {
}
}

func TestResidualAstNil(t *testing.T) {
env := testEnv(t)
ast, err := env.ResidualAst(nil, nil)
if err == nil || !strings.Contains(err.Error(), "unsupported expr") {
t.Errorf("env.ResidualAst() got (%v, %v) wanted unsupported expr error", ast, err)
}
}

func BenchmarkEvalOptions(b *testing.B) {
env := testEnv(b,
Variable("ai", IntType),
Expand Down Expand Up @@ -1323,7 +1331,7 @@ func TestEnvExtensionIsolation(t *testing.T) {
func TestVariadicLogicalOperators(t *testing.T) {
env := testEnv(t, variadicLogicalOperatorASTs())
ast, iss := env.Compile(
`(false || false || false || false || true) &&
`(false || false || false || false || true) &&
(true && true && true && true && false)`)
if iss.Err() != nil {
t.Fatalf("Compile() failed: %v", iss.Err())
Expand Down Expand Up @@ -2293,7 +2301,7 @@ func TestOptionalValuesCompile(t *testing.T) {
if iss.Err() != nil {
t.Fatalf("%v failed: %v", tc.expr, iss.Err())
}
for id, reference := range ast.impl.ReferenceMap() {
for id, reference := range ast.NativeRep().ReferenceMap() {
other, found := tc.references[id]
if !found {
t.Errorf("Compile(%v) expected reference %d: %v", tc.expr, id, reference)
Expand Down Expand Up @@ -2955,6 +2963,15 @@ func BenchmarkDynamicDispatch(b *testing.B) {
})
}

func TestAstProgramNilValue(t *testing.T) {
var ast *Ast = nil
env := testEnv(t)
prg, err := env.Program(ast)
if err == nil || !strings.Contains(err.Error(), "unsupported expr") {
t.Errorf("env.Program() got (%v,%v) wanted unsupported expr error", prg, err)
}
}

// TODO: ideally testCostEstimator and testRuntimeCostEstimator would be shared in a test fixtures package
type testCostEstimator struct {
hints map[string]uint64
Expand Down
5 changes: 3 additions & 2 deletions cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,8 @@ func (e *Env) PartialVars(vars any) (interpreter.PartialActivation, error) {
// TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an
// Ast format and then Program again.
func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
pruned := interpreter.PruneAst(a.impl.Expr(), a.impl.SourceInfo().MacroCalls(), details.State())
ast := a.NativeRep()
pruned := interpreter.PruneAst(ast.Expr(), ast.SourceInfo().MacroCalls(), details.State())
newAST := &Ast{source: a.Source(), impl: pruned}
expr, err := AstToString(newAST)
if err != nil {
Expand All @@ -582,7 +583,7 @@ func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...ch
extendedOpts := make([]checker.CostOption, 0, len(e.costOptions))
extendedOpts = append(extendedOpts, opts...)
extendedOpts = append(extendedOpts, e.costOptions...)
return checker.Cost(ast.impl, estimator, extendedOpts...)
return checker.Cost(ast.NativeRep(), estimator, extendedOpts...)
}

// configure applies a series of EnvOptions to the current environment.
Expand Down
8 changes: 4 additions & 4 deletions cel/folding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -650,11 +650,11 @@ func TestConstantFoldingNormalizeIDs(t *testing.T) {
t.Fatalf("Compile() failed: %v", iss.Err())
}
preOpt := newIDCollector()
ast.PostOrderVisit(checked.impl.Expr(), preOpt)
ast.PostOrderVisit(checked.NativeRep().Expr(), preOpt)
if !reflect.DeepEqual(preOpt.IDs(), tc.ids) {
t.Errorf("Compile() got ids %v, expected %v", preOpt.IDs(), tc.ids)
}
for id, call := range checked.impl.SourceInfo().MacroCalls() {
for id, call := range checked.NativeRep().SourceInfo().MacroCalls() {
macroText, found := tc.macros[id]
if !found {
t.Fatalf("Compile() did not find macro %d", id)
Expand Down Expand Up @@ -682,11 +682,11 @@ func TestConstantFoldingNormalizeIDs(t *testing.T) {
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
}
postOpt := newIDCollector()
ast.PostOrderVisit(optimized.impl.Expr(), postOpt)
ast.PostOrderVisit(optimized.NativeRep().Expr(), postOpt)
if !reflect.DeepEqual(postOpt.IDs(), tc.normalizedIDs) {
t.Errorf("Optimize() got ids %v, expected %v", postOpt.IDs(), tc.normalizedIDs)
}
for id, call := range optimized.impl.SourceInfo().MacroCalls() {
for id, call := range optimized.NativeRep().SourceInfo().MacroCalls() {
macroText, found := tc.normalizedMacros[id]
if !found {
t.Fatalf("Optimize() did not find macro %d", id)
Expand Down
2 changes: 1 addition & 1 deletion cel/inlining.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func NewInlineVariable(name string, definition *Ast) *InlineVariable {
// If the variable occurs more than once, the provided alias will be used to replace the expressions
// where the variable name occurs.
func NewInlineVariableWithAlias(name, alias string, definition *Ast) *InlineVariable {
return &InlineVariable{name: name, alias: alias, def: definition.impl}
return &InlineVariable{name: name, alias: alias, def: definition.NativeRep()}
}

// NewInliningOptimizer creates and optimizer which replaces variables with expression definitions.
Expand Down
4 changes: 2 additions & 2 deletions cel/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func AstToCheckedExpr(a *Ast) (*exprpb.CheckedExpr, error) {
if !a.IsChecked() {
return nil, fmt.Errorf("cannot convert unchecked ast")
}
return ast.ToProto(a.impl)
return ast.ToProto(a.NativeRep())
}

// ParsedExprToAst converts a parsed expression proto message to an Ast.
Expand Down Expand Up @@ -99,7 +99,7 @@ func AstToParsedExpr(a *Ast) (*exprpb.ParsedExpr, error) {
// Note, the conversion may not be an exact replica of the original expression, but will produce
// a string that is semantically equivalent and whose textual representation is stable.
func AstToString(a *Ast) (string, error) {
return parser.Unparse(a.impl.Expr(), a.impl.SourceInfo())
return parser.Unparse(a.NativeRep().Expr(), a.NativeRep().SourceInfo())
}

// RefValueToValue converts between ref.Val and google.api.expr.v1alpha1.Value.
Expand Down
22 changes: 22 additions & 0 deletions cel/io_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package cel

import (
"fmt"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -144,6 +145,27 @@ func TestAstToString(t *testing.T) {
}
}

func TestAstToStringNil(t *testing.T) {
expr, err := AstToString(nil)
if err == nil || !strings.Contains(err.Error(), "unsupported expr") {
t.Errorf("env.AstToString() got (%v, %v) wanted unsupported expr error", expr, err)
}
}

func TestAstToCheckedExprNil(t *testing.T) {
expr, err := AstToCheckedExpr(nil)
if err == nil || !strings.Contains(err.Error(), "cannot convert unchecked ast") {
t.Errorf("env.AstToCheckedExpr() got (%v, %v) wanted conversion error", expr, err)
}
}

func TestAstToParsedExprNil(t *testing.T) {
expr, err := AstToParsedExpr(nil)
if err != nil {
t.Errorf("env.AstToParsedExpr() got (%v, %v) wanted conversion error", expr, err)
}
}

func TestCheckedExprToAstConstantExpr(t *testing.T) {
stdEnv, err := NewEnv()
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions cel/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer {
// If issues are encountered, the Issues.Err() return value will be non-nil.
func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
// Make a copy of the AST to be optimized.
optimized := ast.Copy(a.impl)
ids := newIDGenerator(ast.MaxID(a.impl))
optimized := ast.Copy(a.NativeRep())
ids := newIDGenerator(ast.MaxID(a.NativeRep()))

// Create the optimizer context, could be pooled in the future.
issues := NewIssues(common.NewErrors(a.Source()))
Expand Down Expand Up @@ -86,7 +86,7 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
if iss.Err() != nil {
return nil, iss
}
optimized = checked.impl
optimized = checked.NativeRep()
}

// Return the optimized result.
Expand Down
10 changes: 10 additions & 0 deletions cel/optimizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package cel_test

import (
"sort"
"strings"
"testing"

"github.com/google/cel-go/cel"
Expand Down Expand Up @@ -201,6 +202,15 @@ func TestStaticOptimizerNewAST(t *testing.T) {
}
}

func TestStaticOptimizerNilAST(t *testing.T) {
env := optimizerEnv(t)
opt := cel.NewStaticOptimizer(&identityOptimizer{t: t})
optAST, iss := opt.Optimize(env, nil)
if iss.Err() == nil || !strings.Contains(iss.Err().Error(), "unexpected unspecified type") {
t.Errorf("opt.Optimize(env, nil) got (%v, %v), wanted unexpected unspecified type", optAST, iss)
}
}

type identityOptimizer struct {
t *testing.T
}
Expand Down
3 changes: 3 additions & 0 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ type EvalDetails struct {
// State of the evaluation, non-nil if the OptTrackState or OptExhaustiveEval is specified
// within EvalOptions.
func (ed *EvalDetails) State() interpreter.EvalState {
if ed == nil {
return interpreter.NewEvalState()
}
return ed.state
}

Expand Down
6 changes: 5 additions & 1 deletion common/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,13 @@ type Errors struct {

// NewErrors creates a new instance of the Errors type.
func NewErrors(source Source) *Errors {
src := source
if src == nil {
src = NewTextSource("")
}
return &Errors{
errors: []*Error{},
source: source,
source: src,
maxErrorsToReport: 100,
}
}
Expand Down