Skip to content

Commit 8fc7387

Browse files
l46kokcopybara-github
authored andcommitted
Optimize composed policies using Constant Folding and Common Subexpression Elimination
PiperOrigin-RevId: 802271809
1 parent aa20b5d commit 8fc7387

File tree

10 files changed

+141
-83
lines changed

10 files changed

+141
-83
lines changed

optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
3232
import dev.cel.common.ast.CelMutableExpr;
3333
import dev.cel.common.ast.CelMutableExpr.CelMutableCall;
34+
import dev.cel.common.ast.CelMutableExpr.CelMutableComprehension;
3435
import dev.cel.common.ast.CelMutableExpr.CelMutableList;
3536
import dev.cel.common.ast.CelMutableExpr.CelMutableMap;
3637
import dev.cel.common.ast.CelMutableExpr.CelMutableStruct;
@@ -202,7 +203,11 @@ private static boolean canFoldInOperator(CelNavigableMutableExpr navigableExpr)
202203
CelNavigableMutableExpr parent = identNode.parent().orElse(null);
203204
while (parent != null) {
204205
if (parent.getKind().equals(Kind.COMPREHENSION)) {
205-
if (parent.expr().comprehension().accuVar().equals(identNode.expr().ident().name())) {
206+
String identName = identNode.expr().ident().name();
207+
CelMutableComprehension parentComprehension = parent.expr().comprehension();
208+
if (parentComprehension.accuVar().equals(identName)
209+
|| parentComprehension.iterVar().equals(identName)
210+
|| parentComprehension.iterVar2().equals(identName)) {
206211
// Prevent folding a subexpression if it contains a variable declared by a
207212
// comprehension. The subexpression cannot be compiled without the full context of the
208213
// surrounding comprehension.

parser/src/main/java/dev/cel/parser/Operator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ static Optional<Operator> find(String text) {
180180
.put(MODULO.getFunction(), "%")
181181
.buildOrThrow();
182182

183-
/** Lookup an operator by its mangled name, as used within the AST. */
183+
/** Lookup an operator by its mangled name (ex: _&&_), as used within the AST. */
184184
public static Optional<Operator> findReverse(String op) {
185185
return Optional.ofNullable(REVERSE_OPERATORS.get(op));
186186
}

policy/src/main/java/dev/cel/policy/BUILD.bazel

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ java_library(
215215
"//optimizer",
216216
"//optimizer:optimization_exception",
217217
"//optimizer:optimizer_builder",
218+
"//optimizer/optimizers:common_subexpression_elimination",
219+
"//optimizer/optimizers:constant_folding",
218220
"//validator",
219221
"//validator:ast_validator",
220222
"//validator:validator_builder",
@@ -247,7 +249,9 @@ java_library(
247249
"//common:cel_ast",
248250
"//common:compiler_common",
249251
"//common:mutable_ast",
252+
"//common/ast",
250253
"//common/formats:value_string",
254+
"//common/navigation:mutable_navigation",
251255
"//extensions:optional_library",
252256
"//optimizer:ast_optimizer",
253257
"//optimizer:mutable_ast",

policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
import dev.cel.optimizer.CelOptimizationException;
3737
import dev.cel.optimizer.CelOptimizer;
3838
import dev.cel.optimizer.CelOptimizerFactory;
39+
import dev.cel.optimizer.optimizers.ConstantFoldingOptimizer;
40+
import dev.cel.optimizer.optimizers.SubexpressionOptimizer;
41+
import dev.cel.optimizer.optimizers.SubexpressionOptimizer.SubexpressionOptimizerOptions;
3942
import dev.cel.policy.CelCompiledRule.CelCompiledMatch;
4043
import dev.cel.policy.CelCompiledRule.CelCompiledMatch.Result;
4144
import dev.cel.policy.CelCompiledRule.CelCompiledMatch.Result.Kind;
@@ -98,7 +101,7 @@ public CelCompiledRule compileRule(CelPolicy policy) throws CelPolicyValidationE
98101
public CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledRule)
99102
throws CelPolicyValidationException {
100103
Cel cel = compiledRule.cel();
101-
CelOptimizer optimizer =
104+
CelOptimizer composingOptimizer =
102105
CelOptimizerFactory.standardCelOptimizerBuilder(cel)
103106
.addAstOptimizers(
104107
RuleComposer.newInstance(compiledRule, variablesPrefix, iterationLimit))
@@ -110,7 +113,7 @@ public CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledR
110113
// This is a minimal expression used as a basis of stitching together all the rules into a
111114
// single graph.
112115
ast = cel.compile("true").getAst();
113-
ast = optimizer.optimize(ast);
116+
ast = composingOptimizer.optimize(ast);
114117
} catch (CelValidationException | CelOptimizationException e) {
115118
if (e.getCause() instanceof RuleCompositionException) {
116119
RuleCompositionException re = (RuleCompositionException) e.getCause();
@@ -136,6 +139,28 @@ public CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledR
136139
throw new CelPolicyValidationException("Unexpected error while composing rules.", e);
137140
}
138141

142+
CelOptimizer astOptimizer =
143+
CelOptimizerFactory.standardCelOptimizerBuilder(cel)
144+
.addAstOptimizers(
145+
ConstantFoldingOptimizer.getInstance(),
146+
SubexpressionOptimizer.newInstance(
147+
SubexpressionOptimizerOptions.newBuilder()
148+
// "record" is used for recording subexpression results via
149+
// BlueprintLateFunctionBinding. Safely eliminable, since repeated
150+
// invocation does not change the intermediate results.
151+
.addEliminableFunctions("record")
152+
.populateMacroCalls(true)
153+
.enableCelBlock(true)
154+
.build()))
155+
.build();
156+
try {
157+
// Optimize the composed graph using const fold and CSE
158+
ast = astOptimizer.optimize(ast);
159+
} catch (CelOptimizationException e) {
160+
throw new CelPolicyValidationException(
161+
"Failed to optimize the composed policy. Reason: " + e.getMessage(), e);
162+
}
163+
139164
assertAstDepthIsSafe(ast, cel);
140165

141166
return ast;

policy/src/main/java/dev/cel/policy/RuleComposer.java

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,16 @@
1919
import static java.util.stream.Collectors.toCollection;
2020

2121
import com.google.auto.value.AutoValue;
22+
import com.google.common.collect.ImmutableList;
2223
import com.google.common.collect.Lists;
2324
import dev.cel.bundle.Cel;
2425
import dev.cel.common.CelAbstractSyntaxTree;
2526
import dev.cel.common.CelMutableAst;
2627
import dev.cel.common.CelValidationException;
28+
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
2729
import dev.cel.common.formats.ValueString;
30+
import dev.cel.common.navigation.CelNavigableMutableAst;
31+
import dev.cel.common.navigation.CelNavigableMutableExpr;
2832
import dev.cel.extensions.CelOptionalLibrary.Function;
2933
import dev.cel.optimizer.AstMutator;
3034
import dev.cel.optimizer.CelAstOptimizer;
@@ -151,23 +155,37 @@ private RuleOptimizationResult optimizeRule(Cel cel, CelCompiledRule compiledRul
151155
}
152156
}
153157

154-
CelMutableAst result = matchAst;
155-
for (CelCompiledVariable variable : Lists.reverse(compiledRule.variables())) {
156-
result =
157-
astMutator.replaceSubtreeWithNewBindMacro(
158-
result,
159-
variablePrefix + variable.name(),
160-
CelMutableAst.fromCelAst(variable.ast()),
161-
result.expr(),
162-
result.expr().id(),
163-
true);
164-
}
158+
CelMutableAst result = inlineCompiledVariables(matchAst, compiledRule.variables());
165159

166160
result = astMutator.renumberIdsConsecutively(result);
167161

168162
return RuleOptimizationResult.create(result, isOptionalResult);
169163
}
170164

165+
private CelMutableAst inlineCompiledVariables(
166+
CelMutableAst ast, List<CelCompiledVariable> compiledVariables) {
167+
CelMutableAst mutatedAst = ast;
168+
for (CelCompiledVariable compiledVariable : Lists.reverse(compiledVariables)) {
169+
String variableName = variablePrefix + compiledVariable.name();
170+
ImmutableList<CelNavigableMutableExpr> exprsToReplace =
171+
CelNavigableMutableAst.fromAst(mutatedAst)
172+
.getRoot()
173+
.allNodes()
174+
.filter(
175+
node ->
176+
node.expr().getKind().equals(Kind.IDENT)
177+
&& node.expr().ident().name().equals(variableName))
178+
.collect(toImmutableList());
179+
180+
for (CelNavigableMutableExpr expr : exprsToReplace) {
181+
CelMutableAst variableAst = CelMutableAst.fromCelAst(compiledVariable.ast());
182+
mutatedAst = astMutator.replaceSubtree(mutatedAst, variableAst, expr.id());
183+
}
184+
}
185+
186+
return mutatedAst;
187+
}
188+
171189
static RuleComposer newInstance(
172190
CelCompiledRule compiledRule, String variablePrefix, int iterationLimit) {
173191
return new RuleComposer(compiledRule, variablePrefix, iterationLimit);

policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ public void compileYamlPolicy_multilineContainsError_throws(
136136

137137
@Test
138138
public void compileYamlPolicy_exceedsDefaultAstDepthLimit_throws() throws Exception {
139-
String longExpr =
140-
"0+1+2+3+4+5+6+7+8+9+10+11+12+13+14+15+16+17+18+19+20+21+22+23+24+25+26+27+28+29+30+31+32+33+34+35+36+37+38+39+40+41+42+43+44+45+46+47+48+49+50";
139+
Cel cel = newCel().toCelBuilder().addVar("msg", SimpleType.DYN).build();
140+
String longExpr = "msg.b.c.d.e.f";
141141
String policyContent =
142142
String.format(
143143
"name: deeply_nested_ast\n" + "rule:\n" + " match:\n" + " - output: %s", longExpr);
@@ -146,11 +146,35 @@ public void compileYamlPolicy_exceedsDefaultAstDepthLimit_throws() throws Except
146146
CelPolicyValidationException e =
147147
assertThrows(
148148
CelPolicyValidationException.class,
149-
() -> CelPolicyCompilerFactory.newPolicyCompiler(newCel()).build().compile(policy));
149+
() ->
150+
CelPolicyCompilerFactory.newPolicyCompiler(cel)
151+
.setAstDepthLimit(5)
152+
.build()
153+
.compile(policy));
154+
155+
assertThat(e)
156+
.hasMessageThat()
157+
.isEqualTo("ERROR: <input>:-1:0: AST's depth exceeds the configured limit: 5.");
158+
}
150159

160+
@Test
161+
public void compileYamlPolicy_constantFoldingFailure_throwsDuringComposition() throws Exception {
162+
String policyContent =
163+
"name: ast_with_div_by_zero\n" //
164+
+ "rule:\n" //
165+
+ " match:\n" //
166+
+ " - output: 1 / 0";
167+
CelPolicy policy = POLICY_PARSER.parse(policyContent);
168+
169+
CelPolicyValidationException e =
170+
assertThrows(
171+
CelPolicyValidationException.class,
172+
() -> CelPolicyCompilerFactory.newPolicyCompiler(newCel()).build().compile(policy));
151173
assertThat(e)
152174
.hasMessageThat()
153-
.isEqualTo("ERROR: <input>:-1:0: AST's depth exceeds the configured limit: 50.");
175+
.isEqualTo(
176+
"Failed to optimize the composed policy. Reason: Constant folding failure. Failed to"
177+
+ " evaluate subtree due to: evaluation error: / by zero");
154178
}
155179

156180
@Test

policy/src/test/java/dev/cel/policy/PolicyTestHelper.java

Lines changed: 40 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -41,88 +41,68 @@ enum TestYamlPolicy {
4141
NESTED_RULE(
4242
"nested_rule",
4343
true,
44-
"cel.bind(variables.permitted_regions, [\"us\", \"uk\", \"es\"],"
45-
+ " cel.bind(variables.banned_regions, {\"us\": false, \"ru\": false, \"ir\": false},"
46-
+ " (resource.origin in variables.banned_regions && "
47-
+ "!(resource.origin in variables.permitted_regions)) "
48-
+ "? optional.of({\"banned\": true}) : optional.none()).or("
49-
+ "optional.of((resource.origin in variables.permitted_regions)"
50-
+ " ? {\"banned\": false} : {\"banned\": true})))"),
44+
"cel.@block([resource.origin, @index0 in [\"us\", \"uk\", \"es\"], {\"banned\": true}],"
45+
+ " ((@index0 in {\"us\": false, \"ru\": false, \"ir\": false} && !@index1) ?"
46+
+ " optional.of(@index2) : optional.none()).or(optional.of(@index1 ? {\"banned\":"
47+
+ " false} : @index2)))"),
5148
NESTED_RULE2(
5249
"nested_rule2",
5350
false,
54-
"cel.bind(variables.permitted_regions, [\"us\", \"uk\", \"es\"],"
55-
+ " resource.?user.orValue(\"\").startsWith(\"bad\") ?"
56-
+ " cel.bind(variables.banned_regions, {\"us\": false, \"ru\": false, \"ir\": false},"
57-
+ " (resource.origin in variables.banned_regions && !(resource.origin in"
58-
+ " variables.permitted_regions)) ? {\"banned\": \"restricted_region\"} : {\"banned\":"
59-
+ " \"bad_actor\"}) : (!(resource.origin in variables.permitted_regions) ? {\"banned\":"
60-
+ " \"unconfigured_region\"} : {}))"),
51+
"cel.@block([resource.origin, !(@index0 in [\"us\", \"uk\", \"es\"])],"
52+
+ " resource.?user.orValue(\"\").startsWith(\"bad\") ? ((@index0 in {\"us\": false,"
53+
+ " \"ru\": false, \"ir\": false} && @index1) ? {\"banned\": \"restricted_region\"} :"
54+
+ " {\"banned\": \"bad_actor\"}) : (@index1 ? {\"banned\": \"unconfigured_region\"} :"
55+
+ " {}))"),
6156
NESTED_RULE3(
6257
"nested_rule3",
6358
true,
64-
"cel.bind(variables.permitted_regions, [\"us\", \"uk\", \"es\"],"
65-
+ " resource.?user.orValue(\"\").startsWith(\"bad\") ?"
66-
+ " optional.of(cel.bind(variables.banned_regions, {\"us\": false, \"ru\": false,"
67-
+ " \"ir\": false}, (resource.origin in variables.banned_regions && !(resource.origin"
68-
+ " in variables.permitted_regions)) ? {\"banned\": \"restricted_region\"} :"
69-
+ " {\"banned\": \"bad_actor\"})) : (!(resource.origin in variables.permitted_regions)"
70-
+ " ? optional.of({\"banned\": \"unconfigured_region\"}) : optional.none()))"),
59+
"cel.@block([resource.origin, !(@index0 in [\"us\", \"uk\", \"es\"])],"
60+
+ " resource.?user.orValue(\"\").startsWith(\"bad\") ? optional.of((@index0 in {\"us\":"
61+
+ " false, \"ru\": false, \"ir\": false} && @index1) ? {\"banned\":"
62+
+ " \"restricted_region\"} : {\"banned\": \"bad_actor\"}) : (@index1 ?"
63+
+ " optional.of({\"banned\": \"unconfigured_region\"}) : optional.none()))"),
7164
REQUIRED_LABELS(
7265
"required_labels",
7366
true,
74-
""
75-
+ "cel.bind(variables.want, spec.labels, cel.bind(variables.missing, "
76-
+ "variables.want.filter(l, !(l in resource.labels)), cel.bind(variables.invalid, "
77-
+ "resource.labels.filter(l, l in variables.want && variables.want[l] != "
78-
+ "resource.labels[l]), (variables.missing.size() > 0) ? "
79-
+ "optional.of(\"missing one or more required labels: [\"\" + "
80-
+ "variables.missing.join(\",\") + \"\"]\") : ((variables.invalid.size() > 0) ? "
81-
+ "optional.of(\"invalid values provided on one or more labels: [\"\" + "
82-
+ "variables.invalid.join(\",\") + \"\"]\") : optional.none()))))"),
67+
"cel.@block([spec.labels.filter(@it:0:0, !(@it:0:0 in resource.labels)), spec.labels,"
68+
+ " resource.labels, @index2.filter(@it:0:0, @it:0:0 in @index1 && @index1[@it:0:0] !="
69+
+ " @index2[@it:0:0])], (@index0.size() > 0) ? optional.of(\"missing one or more"
70+
+ " required labels: [\"\" + @index0.join(\",\") + \"\"]\") : ((@index3.size() > 0) ?"
71+
+ " optional.of(\"invalid values provided on one or more labels: [\"\" +"
72+
+ " @index3.join(\",\") + \"\"]\") : optional.none()))"),
8373
RESTRICTED_DESTINATIONS(
8474
"restricted_destinations",
8575
false,
86-
"cel.bind(variables.matches_origin_ip, locationCode(origin.ip) == spec.origin,"
87-
+ " cel.bind(variables.has_nationality, has(request.auth.claims.nationality),"
88-
+ " cel.bind(variables.matches_nationality, variables.has_nationality &&"
89-
+ " request.auth.claims.nationality == spec.origin, cel.bind(variables.matches_dest_ip,"
90-
+ " locationCode(destination.ip) in spec.restricted_destinations,"
91-
+ " cel.bind(variables.matches_dest_label, resource.labels.location in"
92-
+ " spec.restricted_destinations, cel.bind(variables.matches_dest,"
93-
+ " variables.matches_dest_ip || variables.matches_dest_label,"
94-
+ " (variables.matches_nationality && variables.matches_dest) ? true :"
95-
+ " ((!variables.has_nationality && variables.matches_origin_ip &&"
96-
+ " variables.matches_dest) ? true : false)))))))"),
76+
"cel.@block([request.auth.claims, has(@index0.nationality), resource.labels.location in"
77+
+ " spec.restricted_destinations], (@index1 && @index0.nationality == spec.origin &&"
78+
+ " (locationCode(destination.ip) in spec.restricted_destinations || @index2)) ? true :"
79+
+ " ((!@index1 && locationCode(origin.ip) == spec.origin &&"
80+
+ " (locationCode(destination.ip) in spec.restricted_destinations || @index2)) ? true :"
81+
+ " false))"),
9782
K8S(
9883
"k8s",
9984
true,
100-
"cel.bind(variables.env, resource.labels.?environment.orValue(\"prod\"),"
101-
+ " cel.bind(variables.break_glass, resource.labels.?break_glass.orValue(\"false\") =="
102-
+ " \"true\", !(variables.break_glass || resource.containers.all(c,"
103-
+ " c.startsWith(variables.env + \".\"))) ? optional.of(\"only \" + variables.env + \""
104-
+ " containers are allowed in namespace \" + resource.namespace) :"
105-
+ " optional.none()))"),
85+
"cel.@block([resource.labels.?environment.orValue(\"prod\")],"
86+
+ " !(resource.labels.?break_glass.orValue(\"false\") == \"true\" ||"
87+
+ " resource.containers.all(@it:0:0, @it:0:0.startsWith(@index0 + \".\"))) ?"
88+
+ " optional.of(\"only \" + @index0 + \" containers are allowed in namespace \" +"
89+
+ " resource.namespace) : optional.none())"),
10690
PB(
10791
"pb",
10892
true,
109-
"(spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64) ? optional.of(\"invalid"
110-
+ " spec, got single_int32=\" + string(spec.single_int32) + \", wanted <= 10\") :"
111-
+ " ((spec.standalone_enum == cel.expr.conformance.proto3.TestAllTypes.NestedEnum.BAR"
112-
+ " || dev.cel.testing.testdata.proto3.StandaloneGlobalEnum.SGAR =="
93+
"cel.@block([spec.single_int32], (@index0 > 10) ? optional.of(\"invalid spec, got"
94+
+ " single_int32=\" + string(@index0) + \", wanted <= 10\") : ((spec.standalone_enum =="
95+
+ " cel.expr.conformance.proto3.TestAllTypes.NestedEnum.BAR ||"
96+
+ " dev.cel.testing.testdata.proto3.StandaloneGlobalEnum.SGAR =="
11397
+ " dev.cel.testing.testdata.proto3.StandaloneGlobalEnum.SGOO) ? optional.of(\"invalid"
114-
+ " spec, neither nested nor imported enums may refer to BAR\") :"
115-
+ " optional.none())"),
98+
+ " spec, neither nested nor imported enums may refer to BAR\") : optional.none()))"),
11699
LIMITS(
117100
"limits",
118101
true,
119-
"cel.bind(variables.greeting, \"hello\", cel.bind(variables.farewell, \"goodbye\","
120-
+ " cel.bind(variables.person, \"me\", cel.bind(variables.message_fmt, \"%s, %s\","
121-
+ " (now.getHours() >= 20) ? cel.bind(variables.message, variables.farewell + \", \" +"
122-
+ " variables.person, (now.getHours() < 21) ? optional.of(variables.message + \"!\") :"
123-
+ " ((now.getHours() < 22) ? optional.of(variables.message + \"!!\") : ((now.getHours()"
124-
+ " < 24) ? optional.of(variables.message + \"!!!\") : optional.none()))) :"
125-
+ " optional.of(variables.greeting + \", \" + variables.person)))))");
102+
"cel.@block([now.getHours()], (@index0 >= 20) ? ((@index0 < 21) ? optional.of(\"goodbye,"
103+
+ " me!\") : ((@index0 < 22) ? optional.of(\"goodbye, me!!\") : ((@index0 < 24) ?"
104+
+ " optional.of(\"goodbye, me!!!\") : optional.none()))) : optional.of(\"hello,"
105+
+ " me\"))");
126106

127107
private final String name;
128108
private final boolean producesOptionalResult;

0 commit comments

Comments
 (0)