Skip to content

Commit

Permalink
Add CelOptions to disable string conversion and list/string concatena…
Browse files Browse the repository at this point in the history
…tion

Fixes #502

PiperOrigin-RevId: 700727733
  • Loading branch information
l46kok authored and copybara-github committed Nov 27, 2024
1 parent e104ba7 commit cff6d49
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 4 deletions.
1 change: 1 addition & 0 deletions bundle/src/test/java/dev/cel/bundle/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ java_library(
"//checker:proto_type_mask",
"//common",
"//common:compiler_common",
"//common:error_codes",
"//common:options",
"//common:proto_ast",
"//common/ast",
Expand Down
54 changes: 54 additions & 0 deletions bundle/src/test/java/dev/cel/bundle/CelImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import dev.cel.checker.ProtoTypeMask;
import dev.cel.checker.TypeProvider;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelErrorCode;
import dev.cel.common.CelIssue;
import dev.cel.common.CelOptions;
import dev.cel.common.CelProtoAbstractSyntaxTree;
Expand Down Expand Up @@ -1961,6 +1962,59 @@ public void program_nativeTypeUnknownsEnabled_asCallArguments() throws Exception
assertThat(result.attributes()).isEmpty();
}

@Test
@TestParameters("{expression: 'string(123)'}")
@TestParameters("{expression: 'string(123u)'}")
@TestParameters("{expression: 'string(1.5)'}")
@TestParameters("{expression: 'string(\"foo\")'}")
@TestParameters("{expression: 'string(b\"foo\")'}")
@TestParameters("{expression: 'string(timestamp(100))'}")
@TestParameters("{expression: 'string(duration(\"1h\"))'}")
public void program_stringConversionDisabled_throws(String expression) throws Exception {
Cel cel =
CelFactory.standardCelBuilder()
.setOptions(
CelOptions.current()
.enableTimestampEpoch(true)
.enableStringConversion(false)
.build())
.build();
CelAbstractSyntaxTree ast = cel.compile(expression).getAst();

CelEvaluationException e =
assertThrows(CelEvaluationException.class, () -> cel.createProgram(ast).eval());
assertThat(e).hasMessageThat().contains("No matching overload for function 'string'");
assertThat(e.getErrorCode()).isEqualTo(CelErrorCode.OVERLOAD_NOT_FOUND);
}

@Test
public void program_stringConcatenationDisabled_throws() throws Exception {
Cel cel =
CelFactory.standardCelBuilder()
.setOptions(CelOptions.current().enableStringConcatenation(false).build())
.build();
CelAbstractSyntaxTree ast = cel.compile("'foo' + 'bar'").getAst();

CelEvaluationException e =
assertThrows(CelEvaluationException.class, () -> cel.createProgram(ast).eval());
assertThat(e).hasMessageThat().contains("No matching overload for function '_+_'");
assertThat(e.getErrorCode()).isEqualTo(CelErrorCode.OVERLOAD_NOT_FOUND);
}

@Test
public void program_listConcatenationDisabled_throws() throws Exception {
Cel cel =
CelFactory.standardCelBuilder()
.setOptions(CelOptions.current().enableListConcatenation(false).build())
.build();
CelAbstractSyntaxTree ast = cel.compile("[1] + [2]").getAst();

CelEvaluationException e =
assertThrows(CelEvaluationException.class, () -> cel.createProgram(ast).eval());
assertThat(e).hasMessageThat().contains("No matching overload for function '_+_'");
assertThat(e.getErrorCode()).isEqualTo(CelErrorCode.OVERLOAD_NOT_FOUND);
}

@Test
public void toBuilder_isImmutable() {
CelBuilder celBuilder = CelFactory.standardCelBuilder();
Expand Down
29 changes: 28 additions & 1 deletion common/src/main/java/dev/cel/common/CelOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ public enum ProtoUnsetFieldOptions {

public abstract ProtoUnsetFieldOptions fromProtoUnsetFieldOption();

public abstract boolean enableStringConversion();

public abstract boolean enableStringConcatenation();

public abstract boolean enableListConcatenation();

public abstract Builder toBuilder();

public ImmutableSet<ExprFeatures> toExprFeatures() {
Expand Down Expand Up @@ -200,7 +206,10 @@ public static Builder newBuilder() {
.enableCelValue(false)
.comprehensionMaxIterations(-1)
.unwrapWellKnownTypesOnFunctionDispatch(true)
.fromProtoUnsetFieldOption(ProtoUnsetFieldOptions.BIND_DEFAULT);
.fromProtoUnsetFieldOption(ProtoUnsetFieldOptions.BIND_DEFAULT)
.enableStringConversion(true)
.enableStringConcatenation(true)
.enableListConcatenation(true);
}

/**
Expand Down Expand Up @@ -504,6 +513,24 @@ public abstract static class Builder {
*/
public abstract Builder fromProtoUnsetFieldOption(ProtoUnsetFieldOptions value);

/**
* Enables string() overloads for the runtime. This option exists to maintain parity with
* cel-cpp interpreter options.
*/
public abstract Builder enableStringConversion(boolean value);

/**
* Enables string concatenation overload for the runtime. This option exists to maintain parity
* with cel-cpp interpreter options.
*/
public abstract Builder enableStringConcatenation(boolean value);

/**
* Enables list concatenation overload for the runtime. This option exists to maintain parity
* with cel-cpp interpreter options.
*/
public abstract Builder enableListConcatenation(boolean value);

public abstract CelOptions build();
}
}
16 changes: 13 additions & 3 deletions runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import dev.cel.common.types.CelTypes;
import dev.cel.common.values.CelValueProvider;
import dev.cel.common.values.ProtoMessageValueProvider;
import dev.cel.runtime.CelStandardFunctions.StandardFunction.Overload.Arithmetic;
import dev.cel.runtime.CelStandardFunctions.StandardFunction.Overload.Comparison;
import dev.cel.runtime.CelStandardFunctions.StandardFunction.Overload.Conversions;
import java.util.Arrays;
Expand Down Expand Up @@ -317,13 +318,22 @@ private ImmutableSet<CelFunctionBinding> newStandardFunctionBindings(
return options.enableTimestampEpoch();
}
break;
case STRING:
return options.enableStringConversion();
case ADD:
Arithmetic arithmetic = (Arithmetic) standardOverload;
if (arithmetic.equals(Arithmetic.ADD_STRING)) {
return options.enableStringConcatenation();
}
if (arithmetic.equals(Arithmetic.ADD_LIST)) {
return options.enableListConcatenation();
}
break;
default:
if (standardOverload instanceof Comparison
&& !options.enableHeterogeneousNumericComparisons()) {
Comparison comparison = (Comparison) standardOverload;
if (comparison.isHeterogeneousComparison()) {
return false;
}
return !comparison.isHeterogeneousComparison();
}
break;
}
Expand Down

0 comments on commit cff6d49

Please sign in to comment.