Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for cel.@block during policy composition #1056

Merged
merged 5 commits into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,14 +459,20 @@ func (e *Env) ParseSource(src Source) (*Ast, *Issues) {

// Program generates an evaluable instance of the Ast within the environment (Env).
func (e *Env) Program(ast *Ast, opts ...ProgramOption) (Program, error) {
return e.PlanProgram(ast.NativeRep(), opts...)
}

// PlanProgram generates an evaluable instance of the AST in the go-native representation within
// the environment (Env).
func (e *Env) PlanProgram(a *celast.AST, opts ...ProgramOption) (Program, error) {
optSet := e.progOpts
if len(opts) != 0 {
mergedOpts := []ProgramOption{}
mergedOpts = append(mergedOpts, e.progOpts...)
mergedOpts = append(mergedOpts, opts...)
optSet = mergedOpts
}
return newProgram(e, ast, optSet)
return newProgram(e, a, optSet)
}

// CELTypeAdapter returns the `types.Adapter` configured for the environment.
Expand Down
10 changes: 10 additions & 0 deletions cel/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,16 @@ type OptimizerContext struct {
*Issues
}

// ExtendEnv auguments the context's environment with the additional options.
func (opt *OptimizerContext) ExtendEnv(opts ...EnvOption) error {
e, err := opt.Env.Extend(opts...)
if err != nil {
return err
}
opt.Env = e
return nil
}

// ASTOptimizer applies an optimization over an AST and returns the optimized result.
type ASTOptimizer interface {
// Optimize optimizes a type-checked AST within an Environment and accumulates any issues.
Expand Down
7 changes: 4 additions & 3 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"sync"

"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
Expand Down Expand Up @@ -151,7 +152,7 @@ func (p *prog) clone() *prog {
// ProgramOption values.
//
// If the program cannot be configured the prog will be nil, with a non-nil error response.
func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
func newProgram(e *Env, a *ast.AST, opts []ProgramOption) (Program, error) {
// Build the dispatcher, interpreter, and default program value.
disp := interpreter.NewDispatcher()

Expand Down Expand Up @@ -255,9 +256,9 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
return p.initInterpretable(a, decorators)
}

func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
func (p *prog) initInterpretable(a *ast.AST, decs []interpreter.InterpretableDecorator) (*prog, error) {
// When the AST has been exprAST it contains metadata that can be used to speed up program execution.
interpretable, err := p.interpreter.NewInterpretable(a.impl, decs...)
interpretable, err := p.interpreter.NewInterpretable(a, decs...)
if err != nil {
return nil, err
}
Expand Down
2 changes: 2 additions & 0 deletions conformance/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ _ALL_TESTS = [
"@dev_cel_expr//tests/simple:testdata/timestamps.textproto",
"@dev_cel_expr//tests/simple:testdata/unknowns.textproto",
"@dev_cel_expr//tests/simple:testdata/wrappers.textproto",
"@dev_cel_expr//tests/simple:testdata/block_ext.textproto",
]

_TESTS_TO_SKIP = [
Expand Down Expand Up @@ -68,6 +69,7 @@ go_test(
deps = [
"//cel:go_default_library",
"//common:go_default_library",
"//common/ast:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//ext:go_default_library",
Expand Down
88 changes: 88 additions & 0 deletions conformance/conformance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/ext"
Expand Down Expand Up @@ -89,6 +90,7 @@ func init() {
ext.Math(),
ext.Protos(),
ext.Strings(),
cel.Lib(celBlockLib{}),
}

var err error
Expand Down Expand Up @@ -279,3 +281,89 @@ func TestConformance(t *testing.T) {
}
}
}

type celBlockLib struct{}

func (celBlockLib) LibraryName() string {
return "cel.lib.ext.cel.block.conformance"
}

func (celBlockLib) CompileOptions() []cel.EnvOption {
// Simulate indexed arguments which would normally have strong types associated
// with the values as part of a static optimization pass
maxIndices := 30
indexOpts := make([]cel.EnvOption, maxIndices)
for i := 0; i < maxIndices; i++ {
indexOpts[i] = cel.Variable(fmt.Sprintf("@index%d", i), cel.DynType)
}
return append([]cel.EnvOption{
cel.Macros(
// cel.block([args], expr)
cel.ReceiverMacro("block", 2, celBlock),
// cel.index(int)
cel.ReceiverMacro("index", 1, celIndex),
// cel.iterVar(int, int)
cel.ReceiverMacro("iterVar", 2, celCompreVar("cel.iterVar", "@it")),
// cel.accuVar(int, int)
cel.ReceiverMacro("accuVar", 2, celCompreVar("cel.accuVar", "@ac")),
),
}, indexOpts...)
}

func (celBlockLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}

func celBlock(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !isCELNamespace(target) {
return nil, nil
}
bindings := args[0]
if bindings.Kind() != ast.ListKind {
return bindings, mef.NewError(bindings.ID(), "cel.block requires the first arg to be a list literal")
}
return mef.NewCall("cel.@block", args...), nil
}

func celIndex(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !isCELNamespace(target) {
return nil, nil
}
index := args[0]
if !isNonNegativeInt(index) {
return index, mef.NewError(index.ID(), "cel.index requires a single non-negative int constant arg")
}
indexVal := index.AsLiteral().(types.Int)
return mef.NewIdent(fmt.Sprintf("@index%d", indexVal)), nil
}

func celCompreVar(funcName, varPrefix string) cel.MacroFactory {
return func(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !isCELNamespace(target) {
return nil, nil
}
depth := args[0]
if !isNonNegativeInt(depth) {
return depth, mef.NewError(depth.ID(), fmt.Sprintf("%s requires two non-negative int constant args", funcName))
}
unique := args[1]
if !isNonNegativeInt(unique) {
return unique, mef.NewError(unique.ID(), fmt.Sprintf("%s requires two non-negative int constant args", funcName))
}
depthVal := depth.AsLiteral().(types.Int)
uniqueVal := unique.AsLiteral().(types.Int)
return mef.NewIdent(fmt.Sprintf("%s:%d:%d", varPrefix, depthVal, uniqueVal)), nil
}
}

func isCELNamespace(target ast.Expr) bool {
return target.Kind() == ast.IdentKind && target.AsIdent() == "cel"
}

func isNonNegativeInt(expr ast.Expr) bool {
if expr.Kind() != ast.LiteralKind {
return false
}
val := expr.AsLiteral()
return val.Type() == cel.IntType && val.(types.Int) >= 0
}
Loading