Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix expression inlining when working with macros #853

Merged
merged 10 commits into from
Nov 10, 2023
36 changes: 17 additions & 19 deletions cel/folding.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
ctx.ReportErrorAtID(root.ID(), "constant-folding evaluation failed: %v", err.Error())
return
}
e.SetKindCase(adapted)
ctx.UpdateExpr(e, adapted)
}))

return a
Expand All @@ -134,10 +134,8 @@ func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
if err != nil {
return err
}
// Clear any macro metadata associated with the fold.
a.SourceInfo().ClearMacroCall(expr.ID())
// Update the fold expression to be a literal.
expr.SetKindCase(ctx.NewLiteral(out))
ctx.UpdateExpr(expr, ctx.NewLiteral(out))
return nil
}

Expand All @@ -159,15 +157,15 @@ func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool {
return false
}
if cond.AsLiteral() == types.True {
expr.SetKindCase(truthy)
ctx.UpdateExpr(expr, truthy)
} else {
expr.SetKindCase(falsy)
ctx.UpdateExpr(expr, falsy)
}
return true
case operators.In:
haystack := args[1]
if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 {
expr.SetKindCase(ctx.NewLiteral(types.False))
ctx.UpdateExpr(expr, ctx.NewLiteral(types.False))
return true
}
needle := args[0]
Expand All @@ -176,7 +174,7 @@ func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool {
list := haystack.AsList()
for _, e := range list.Elements() {
if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True {
expr.SetKindCase(ctx.NewLiteral(types.True))
ctx.UpdateExpr(expr, ctx.NewLiteral(types.True))
return true
}
}
Expand All @@ -202,20 +200,20 @@ func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.E
continue
}
if arg.AsLiteral() == shortcircuit {
expr.SetKindCase(arg)
ctx.UpdateExpr(expr, arg)
return true
}
}
if len(newArgs) == 0 {
newArgs = append(newArgs, args[0])
expr.SetKindCase(newArgs[0])
ctx.UpdateExpr(expr, newArgs[0])
return true
}
if len(newArgs) == 1 {
expr.SetKindCase(newArgs[0])
ctx.UpdateExpr(expr, newArgs[0])
return true
}
expr.SetKindCase(ctx.NewCall(function, newArgs...))
ctx.UpdateExpr(expr, ctx.NewCall(function, newArgs...))
return true
}

Expand Down Expand Up @@ -270,10 +268,10 @@ func pruneOptionalListElements(ctx *OptimizerContext, e ast.Expr) {
newOptIndex-- // Skipping causes the list to get smaller.
continue
}
e.SetKindCase(ctx.NewLiteral(optElemVal.GetValue()))
ctx.UpdateExpr(e, ctx.NewLiteral(optElemVal.GetValue()))
updatedElems = append(updatedElems, e)
}
e.SetKindCase(ctx.NewList(updatedElems, updatedIndices))
ctx.UpdateExpr(e, ctx.NewList(updatedElems, updatedIndices))
}

func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) {
Expand Down Expand Up @@ -303,20 +301,20 @@ func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) {
if err != nil {
ctx.ReportErrorAtID(val.ID(), "invalid map value literal %v: %v", optElemVal, err)
}
val.SetKindCase(undoOptVal)
ctx.UpdateExpr(val, undoOptVal)
updatedEntries = append(updatedEntries, e)
continue
}
modified = true
if !optElemVal.HasValue() {
continue
}
val.SetKindCase(ctx.NewLiteral(optElemVal.GetValue()))
ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue()))
updatedEntry := ctx.NewMapEntry(key, val, false)
updatedEntries = append(updatedEntries, updatedEntry)
}
if modified {
e.SetKindCase(ctx.NewMap(updatedEntries))
ctx.UpdateExpr(e, ctx.NewMap(updatedEntries))
}
}

Expand All @@ -341,12 +339,12 @@ func pruneOptionalStructFields(ctx *OptimizerContext, e ast.Expr) {
if !optElemVal.HasValue() {
continue
}
val.SetKindCase(ctx.NewLiteral(optElemVal.GetValue()))
ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue()))
updatedField := ctx.NewStructField(field.Name(), val, false)
updatedFields = append(updatedFields, updatedField)
}
if modified {
e.SetKindCase(ctx.NewStruct(s.TypeName(), updatedFields))
ctx.UpdateExpr(e, ctx.NewStruct(s.TypeName(), updatedFields))
}
}

Expand Down
119 changes: 95 additions & 24 deletions cel/inlining.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,57 +85,76 @@ func (opt *inliningOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.A
}

// For a single match, do a direct replacement of the expression sub-graph.
if len(matches) == 1 {
opt.inlineExpr(ctx, matches[0], ctx.CopyExpr(inlineVar.Expr()), inlineVar.Type())
continue
}

if !isBindable(matches, inlineVar.Expr(), inlineVar.Type()) {
if len(matches) == 1 || !isBindable(matches, inlineVar.Expr(), inlineVar.Type()) {
for _, match := range matches {
opt.inlineExpr(ctx, match, ctx.CopyExpr(inlineVar.Expr()), inlineVar.Type())
// Copy the inlined AST expr and source info.
copyExpr := copyASTAndMetadata(ctx, inlineVar.def)
opt.inlineExpr(ctx, match, copyExpr, inlineVar.Type())
}
continue
}

// For multiple matches, find the least common ancestor (lca) and insert the
// variable as a cel.bind() macro.
var lca ast.NavigableExpr = nil
ancestors := map[int64]bool{}
var lca ast.NavigableExpr = root
lcaAncestorCount := 0
ancestors := map[int64]int{}
for _, match := range matches {
// Update the identifier matches with the provided alias.
aliasExpr := ctx.NewIdent(inlineVar.Alias())
opt.inlineExpr(ctx, match, aliasExpr, inlineVar.Type())
parent, found := match, true
for found {
_, hasAncestor := ancestors[parent.ID()]
if hasAncestor && (lca == nil || lca.Depth() < parent.Depth()) {
ancestorCount, hasAncestor := ancestors[parent.ID()]
if !hasAncestor {
ancestors[parent.ID()] = 1
parent, found = parent.Parent()
continue
}
if lcaAncestorCount < ancestorCount || (lcaAncestorCount == ancestorCount && lca.Depth() < parent.Depth()) {
lca = parent
lcaAncestorCount = ancestorCount
}
ancestors[parent.ID()] = true
ancestors[parent.ID()] = ancestorCount + 1
parent, found = parent.Parent()
}
aliasExpr := ctx.NewIdent(inlineVar.Alias())
opt.inlineExpr(ctx, match, aliasExpr, inlineVar.Type())
}

// Copy the inlined AST expr and source info.
copyExpr := copyASTAndMetadata(ctx, inlineVar.def)
// Update the least common ancestor by inserting a cel.bind() call to the alias.
inlined := ctx.NewBindMacro(lca.ID(), inlineVar.Alias(), inlineVar.Expr(), lca)
inlined, bindMacro := opt.celBindMacro(ctx, lca.ID(), inlineVar.Alias(), copyExpr, lca)
opt.inlineExpr(ctx, lca, inlined, inlineVar.Type())
ctx.sourceInfo.SetMacroCall(lca.ID(), bindMacro)
}
return a
}

// copyASTAndMetadata copies the input AST and propagates the macro metadata into the AST being
// optimized.
func copyASTAndMetadata(ctx *OptimizerContext, a *ast.AST) ast.Expr {
copyExpr, copyInfo := ctx.CopyAST(a)
// Add in the macro calls from the inlined AST
for id, call := range copyInfo.MacroCalls() {
ctx.sourceInfo.SetMacroCall(id, call)
}
return copyExpr
}

// inlineExpr replaces the current expression with the inlined one, unless the location of the inlining
// happens within a presence test, e.g. has(a.b.c) -> inline alpha for a.b.c in which case an attempt is
// made to determine whether the inlined value can be presence or existence tested.
func (opt *inliningOptimizer) inlineExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) {
func (opt *inliningOptimizer) inlineExpr(ctx *OptimizerContext, prev ast.NavigableExpr, inlined ast.Expr, inlinedType *Type) {
switch prev.Kind() {
case ast.SelectKind:
sel := prev.AsSelect()
if !sel.IsTestOnly() {
prev.SetKindCase(inlined)
ctx.UpdateExpr(prev, inlined)
return
}
opt.rewritePresenceExpr(ctx, prev, inlined, inlinedType)
default:
prev.SetKindCase(inlined)
ctx.UpdateExpr(prev, inlined)
}
}

Expand All @@ -146,23 +165,24 @@ func (opt *inliningOptimizer) inlineExpr(ctx *OptimizerContext, prev, inlined as
func (opt *inliningOptimizer) rewritePresenceExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) {
// If the input inlined expression is not a select expression it won't work with the has()
// macro. Attempt to rewrite the presence test in terms of the typed input, otherwise error.
ctx.sourceInfo.ClearMacroCall(prev.ID())
if inlined.Kind() == ast.SelectKind {
inlinedSel := inlined.AsSelect()
prev.SetKindCase(
ctx.NewPresenceTest(prev.ID(), inlinedSel.Operand(), inlinedSel.FieldName()))
presenceTest, hasMacro := opt.hasMacro(ctx, prev.ID(), inlined)
ctx.UpdateExpr(prev, presenceTest)
ctx.sourceInfo.SetMacroCall(prev.ID(), hasMacro)
return
}

ctx.sourceInfo.ClearMacroCall(prev.ID())
if inlinedType.IsAssignableType(NullType) {
prev.SetKindCase(
ctx.UpdateExpr(prev,
ctx.NewCall(operators.NotEquals,
inlined,
ctx.NewLiteral(types.NullValue),
))
return
}
if inlinedType.HasTrait(traits.SizerType) {
prev.SetKindCase(
ctx.UpdateExpr(prev,
ctx.NewCall(operators.NotEquals,
ctx.NewMemberCall(overloads.Size, inlined),
ctx.NewLiteral(types.IntZero),
Expand Down Expand Up @@ -218,3 +238,54 @@ func (opt *inliningOptimizer) matchVariable(varName string) ast.ExprMatcher {
return false
}
}

func (opt *inliningOptimizer) hasMacro(ctx *OptimizerContext, macroID int64, s ast.Expr) (astExpr, macroExpr ast.Expr) {
sel := s.AsSelect()
astExpr = ctx.fac.NewPresenceTest(macroID, sel.Operand(), sel.FieldName())
macroExpr = ctx.fac.NewCall(0, "has",
ctx.NewSelect(ctx.fac.CopyExpr(sel.Operand()), sel.FieldName()))
sanitizeMacro(ctx, macroID, macroExpr)
return
}

func (opt *inliningOptimizer) celBindMacro(
ctx *OptimizerContext, macroID int64, varName string, varInit, remaining ast.Expr) (astExpr, macroExpr ast.Expr) {
varID := ctx.nextID()
remainingID := ctx.nextID()
remaining = ctx.fac.CopyExpr(remaining)
remaining.RenumberIDs(func(id int64) int64 {
if id == macroID {
return remainingID
}
return id
})
if call, exists := ctx.sourceInfo.GetMacroCall(macroID); exists {
ctx.sourceInfo.SetMacroCall(remainingID, ctx.fac.CopyExpr(call))
}

astExpr = ctx.fac.NewComprehension(macroID,
ctx.fac.NewList(ctx.nextID(), []ast.Expr{}, []int32{}),
"#unused",
varName,
ctx.fac.CopyExpr(varInit),
ctx.fac.NewLiteral(ctx.nextID(), types.False),
ctx.fac.NewIdent(varID, varName),
remaining)

macroExpr = ctx.fac.NewMemberCall(0, "bind",
ctx.fac.NewIdent(ctx.nextID(), "cel"),
ctx.fac.NewIdent(varID, varName),
ctx.fac.CopyExpr(varInit),
ctx.fac.CopyExpr(remaining))
sanitizeMacro(ctx, macroID, macroExpr)
return
}

func sanitizeMacro(ctx *OptimizerContext, macroID int64, macroExpr ast.Expr) {
macroRefVisitor := ast.NewExprVisitor(func(e ast.Expr) {
if _, exists := ctx.sourceInfo.GetMacroCall(e.ID()); exists && e.ID() != macroID {
e.SetKindCase(nil)
}
})
ast.PostOrderVisit(macroExpr, macroRefVisitor)
}
Loading