diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 340353259..d7727d6a7 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -445,6 +445,7 @@ cc_test( "//eval/public:builtin_func_registrar", "//eval/public:cel_function", "//eval/public:cel_function_registry", + "//eval/public:cel_value", "//extensions/protobuf:ast_converters", "//internal:casts", "//internal:proto_matchers", @@ -499,7 +500,6 @@ cc_test( "//eval/testutil:test_message_cc_proto", "//internal:testing", "@com_google_absl//absl/status", - "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", ], @@ -542,16 +542,20 @@ cc_test( ":flat_expr_builder", ":flat_expr_builder_extensions", ":regex_precompilation_optimization", + ":resolver", "//common/ast:ast_impl", "//eval/eval:evaluator_core", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_expression", + "//eval/public:cel_function_registry", "//eval/public:cel_options", + "//eval/public:cel_type_registry", "//eval/public:cel_value", "//internal:testing", "//parser", "//runtime:runtime_issue", + "//runtime:runtime_options", "//runtime/internal:issue_collector", "//runtime/internal:runtime_env", "//runtime/internal:runtime_env_testing", diff --git a/eval/compiler/constant_folding_test.cc b/eval/compiler/constant_folding_test.cc index 7b20f3227..bc9463890 100644 --- a/eval/compiler/constant_folding_test.cc +++ b/eval/compiler/constant_folding_test.cc @@ -16,6 +16,7 @@ #include #include +#include #include "cel/expr/syntax.pb.h" #include "absl/base/nullability.h" @@ -78,8 +79,7 @@ class UpdatedConstantFoldingTest : public testing::Test { type_registry_(env_->type_registry), issue_collector_(RuntimeIssue::Severity::kError), resolver_("", function_registry_, type_registry_, - type_registry_.GetComposedTypeProvider(), - type_registry_.resolveable_enums()) {} + type_registry_.GetComposedTypeProvider()) {} protected: absl::Nonnull> env_; diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index d66037a50..80a6c848d 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -2467,12 +2467,17 @@ std::vector FlattenExpressionTable( absl::StatusOr FlatExprBuilder::CreateExpressionImpl( std::unique_ptr ast, std::vector* issues) const { + if (absl::StartsWith(container_, ".") || absl::EndsWith(container_, ".")) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid expression container: '", container_, "'")); + } + RuntimeIssue::Severity max_severity = options_.fail_on_warnings ? RuntimeIssue::Severity::kWarning : RuntimeIssue::Severity::kError; IssueCollector issue_collector(max_severity); Resolver resolver(container_, function_registry_, type_registry_, - GetTypeProvider(), type_registry_.resolveable_enums(), + GetTypeProvider(), options_.enable_qualified_type_identifiers); std::shared_ptr arena; @@ -2482,11 +2487,6 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( auto& ast_impl = AstImpl::CastFromPublicAst(*ast); - if (absl::StartsWith(container_, ".") || absl::EndsWith(container_, ".")) { - return absl::InvalidArgumentError( - absl::StrCat("Invalid expression container: '", container_, "'")); - } - for (const std::unique_ptr& transform : ast_transforms_) { CEL_RETURN_IF_ERROR(transform->UpdateAst(extension_context, ast_impl)); } diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index f0263f065..5427b00ec 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -23,10 +23,12 @@ #include #include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "base/ast.h" #include "base/type_provider.h" +#include "common/value.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/eval/evaluator_core.h" #include "runtime/function_registry.h" @@ -98,6 +100,7 @@ class FlatExprBuilder { const absl::Nonnull> env_; + cel::RuntimeOptions options_; std::string container_; bool enable_optional_types_ = false; diff --git a/eval/compiler/flat_expr_builder_extensions_test.cc b/eval/compiler/flat_expr_builder_extensions_test.cc index 0cfea503a..a8fe5a3b6 100644 --- a/eval/compiler/flat_expr_builder_extensions_test.cc +++ b/eval/compiler/flat_expr_builder_extensions_test.cc @@ -15,6 +15,7 @@ #include #include +#include #include "absl/base/nullability.h" #include "absl/status/status.h" @@ -62,8 +63,7 @@ class PlannerContextTest : public testing::Test { type_registry_(env_->type_registry), function_registry_(env_->function_registry), resolver_("", function_registry_, type_registry_, - type_registry_.GetComposedTypeProvider(), - type_registry_.resolveable_enums()), + type_registry_.GetComposedTypeProvider()), issue_collector_(RuntimeIssue::Severity::kError) {} protected: diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 5c985f3bb..8020d940c 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -154,8 +154,9 @@ TEST(FlatExprBuilderTest, SimpleEndToEnd) { CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK( - builder.GetRegistry()->Register(std::make_unique())); + ASSERT_THAT( + builder.GetRegistry()->Register(std::make_unique()), + IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -342,14 +343,16 @@ TEST(FlatExprBuilderTest, Shortcircuiting) { int count1 = 0; int count2 = 0; - ASSERT_OK(builder.GetRegistry()->Register( - std::make_unique("recorder1", &count1))); - ASSERT_OK(builder.GetRegistry()->Register( - std::make_unique("recorder2", &count2))); + ASSERT_THAT(builder.GetRegistry()->Register( + std::make_unique("recorder1", &count1)), + IsOk()); + ASSERT_THAT(builder.GetRegistry()->Register( + std::make_unique("recorder2", &count2)), + IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr_on, builder.CreateExpression(&expr, &source_info)); - ASSERT_OK(cel_expr_on->Evaluate(activation, &arena)); + ASSERT_THAT(cel_expr_on->Evaluate(activation, &arena), IsOk()); EXPECT_THAT(count1, Eq(1)); EXPECT_THAT(count2, Eq(0)); @@ -365,15 +368,17 @@ TEST(FlatExprBuilderTest, Shortcircuiting) { int count1 = 0; int count2 = 0; - ASSERT_OK(builder.GetRegistry()->Register( - std::make_unique("recorder1", &count1))); - ASSERT_OK(builder.GetRegistry()->Register( - std::make_unique("recorder2", &count2))); + ASSERT_THAT(builder.GetRegistry()->Register( + std::make_unique("recorder1", &count1)), + IsOk()); + ASSERT_THAT(builder.GetRegistry()->Register( + std::make_unique("recorder2", &count2)), + IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr_off, builder.CreateExpression(&expr, &source_info)); - ASSERT_OK(cel_expr_off->Evaluate(activation, &arena)); + ASSERT_THAT(cel_expr_off->Evaluate(activation, &arena), IsOk()); EXPECT_THAT(count1, Eq(1)); EXPECT_THAT(count2, Eq(1)); } @@ -411,13 +416,15 @@ TEST(FlatExprBuilderTest, ShortcircuitingComprehension) { auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); int count = 0; - ASSERT_OK(builder.GetRegistry()->Register( - std::make_unique("recorder_function1", &count))); + ASSERT_THAT( + builder.GetRegistry()->Register( + std::make_unique("recorder_function1", &count)), + IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr_on, builder.CreateExpression(&expr, &source_info)); - ASSERT_OK(cel_expr_on->Evaluate(activation, &arena)); + ASSERT_THAT(cel_expr_on->Evaluate(activation, &arena), IsOk()); EXPECT_THAT(count, Eq(0)); } @@ -429,11 +436,13 @@ TEST(FlatExprBuilderTest, ShortcircuitingComprehension) { auto builtin = RegisterBuiltinFunctions(builder.GetRegistry()); int count = 0; - ASSERT_OK(builder.GetRegistry()->Register( - std::make_unique("recorder_function1", &count))); + ASSERT_THAT( + builder.GetRegistry()->Register( + std::make_unique("recorder_function1", &count)), + IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr_off, builder.CreateExpression(&expr, &source_info)); - ASSERT_OK(cel_expr_off->Evaluate(activation, &arena)); + ASSERT_THAT(cel_expr_off->Evaluate(activation, &arena), IsOk()); EXPECT_THAT(count, Eq(3)); } } @@ -445,7 +454,7 @@ TEST(FlatExprBuilderTest, IdentExprUnsetName) { google::protobuf::TextFormat::ParseFromString(R"(ident_expr {})", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("'name' must not be empty"))); @@ -461,7 +470,7 @@ TEST(FlatExprBuilderTest, SelectExprUnsetField) { &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("'field' must not be empty"))); @@ -490,7 +499,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuVar) { // An empty ident without the name set should error. google::protobuf::TextFormat::ParseFromString(R"(comprehension_expr{})", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("'accu_var' must not be empty"))); @@ -505,7 +514,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetIterVar) { )", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("'iter_var' must not be empty"))); @@ -522,7 +531,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuInit) { )", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("'accu_init' must be set"))); @@ -542,7 +551,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopCondition) { )", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("'loop_condition' must be set"))); @@ -565,7 +574,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopStep) { )", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("'loop_step' must be set"))); @@ -591,7 +600,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetResult) { )", &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("'result' must be set"))); @@ -641,7 +650,7 @@ TEST(FlatExprBuilderTest, MapComprehension) { &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -673,7 +682,7 @@ TEST(FlatExprBuilderTest, InvalidContainer) { &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); builder.set_container(".bad"); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -904,7 +913,7 @@ TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); Activation activation; @@ -966,7 +975,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); Activation activation; @@ -1036,7 +1045,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); builder.set_container("com.foo"); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK((FunctionAdapter::CreateAndRegister( "com.foo.ext.and", false, [](google::protobuf::Arena*, bool lhs, bool rhs) { return lhs && rhs; }, @@ -1103,7 +1112,7 @@ TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( NewReferenceResolverExtension(ReferenceResolverOption::kCheckedOnly)); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); Activation activation; @@ -1171,7 +1180,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapAndConstantFolding) { google::protobuf::Arena arena; builder.flat_expr_builder().AddProgramOptimizer( cel::runtime_internal::CreateConstantFoldingOptimizer()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr)); Activation activation; @@ -1254,7 +1263,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForError) { &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1325,7 +1334,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { &expr); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1377,7 +1386,7 @@ TEST(FlatExprBuilderTest, ComprehensionBudget) { cel::RuntimeOptions options; options.comprehension_max_iterations = 1; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto cel_expr, builder.CreateExpression(&expr, &source_info)); @@ -1445,24 +1454,33 @@ TEST(FlatExprBuilderTest, ContainerStringFormat) { SourceInfo source_info; expr.mutable_ident_expr()->set_name("ident"); - CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - builder.set_container(""); - ASSERT_OK(builder.CreateExpression(&expr, &source_info)); - - builder.set_container("random.namespace"); - ASSERT_OK(builder.CreateExpression(&expr, &source_info)); - - // Leading '.' - builder.set_container(".random.namespace"); - EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Invalid expression container"))); + { + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.set_container(""); + ASSERT_THAT(builder.CreateExpression(&expr, &source_info), IsOk()); + } - // Trailing '.' - builder.set_container("random.namespace."); - EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Invalid expression container"))); + { + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + builder.set_container("random.namespace"); + ASSERT_THAT(builder.CreateExpression(&expr, &source_info), IsOk()); + } + { + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + // Leading '.' + builder.set_container(".random.namespace"); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid expression container"))); + } + { + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + // Trailing '.' + builder.set_container("random.namespace."); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid expression container"))); + } } void EvalExpressionWithEnum(absl::string_view enum_name, @@ -1493,7 +1511,7 @@ void EvalExpressionWithEnum(absl::string_view enum_name, google::protobuf::Arena arena; Activation activation; auto eval = cel_expr->Evaluate(activation, &arena); - ASSERT_OK(eval); + ASSERT_THAT(eval, IsOk()); *result = eval.value(); } @@ -1691,22 +1709,25 @@ TEST(FlatExprBuilderTest, Ternary) { // On True, value 1 { CelValue result; - ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(true), - CelValue::CreateInt64(1), - CelValue::CreateInt64(2), &arena, &result)); + ASSERT_THAT(RunTernaryExpression(CelValue::CreateBool(true), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(1)); // Unknown handling UnknownSet unknown_set; - ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(true), - CelValue::CreateUnknownSet(&unknown_set), - CelValue::CreateInt64(2), &arena, &result)); + ASSERT_THAT(RunTernaryExpression(CelValue::CreateBool(true), + CelValue::CreateUnknownSet(&unknown_set), + CelValue::CreateInt64(2), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsUnknownSet()); - ASSERT_OK(RunTernaryExpression( - CelValue::CreateBool(true), CelValue::CreateInt64(1), - CelValue::CreateUnknownSet(&unknown_set), &arena, &result)); + ASSERT_THAT(RunTernaryExpression( + CelValue::CreateBool(true), CelValue::CreateInt64(1), + CelValue::CreateUnknownSet(&unknown_set), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(1)); } @@ -1714,40 +1735,45 @@ TEST(FlatExprBuilderTest, Ternary) { // On False, value 2 { CelValue result; - ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(false), - CelValue::CreateInt64(1), - CelValue::CreateInt64(2), &arena, &result)); + ASSERT_THAT(RunTernaryExpression(CelValue::CreateBool(false), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(2)); // Unknown handling UnknownSet unknown_set; - ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(false), - CelValue::CreateUnknownSet(&unknown_set), - CelValue::CreateInt64(2), &arena, &result)); + ASSERT_THAT(RunTernaryExpression(CelValue::CreateBool(false), + CelValue::CreateUnknownSet(&unknown_set), + CelValue::CreateInt64(2), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsInt64()); EXPECT_THAT(result.Int64OrDie(), Eq(2)); - ASSERT_OK(RunTernaryExpression( - CelValue::CreateBool(false), CelValue::CreateInt64(1), - CelValue::CreateUnknownSet(&unknown_set), &arena, &result)); + ASSERT_THAT(RunTernaryExpression( + CelValue::CreateBool(false), CelValue::CreateInt64(1), + CelValue::CreateUnknownSet(&unknown_set), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsUnknownSet()); } // On Error, surface error { CelValue result; - ASSERT_OK(RunTernaryExpression(CreateErrorValue(&arena, "error"), - CelValue::CreateInt64(1), - CelValue::CreateInt64(2), &arena, &result)); + ASSERT_THAT(RunTernaryExpression(CreateErrorValue(&arena, "error"), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsError()); } // On Unknown, surface Unknown { UnknownSet unknown_set; CelValue result; - ASSERT_OK(RunTernaryExpression(CelValue::CreateUnknownSet(&unknown_set), - CelValue::CreateInt64(1), - CelValue::CreateInt64(2), &arena, &result)); + ASSERT_THAT(RunTernaryExpression(CelValue::CreateUnknownSet(&unknown_set), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result), + IsOk()); ASSERT_TRUE(result.IsUnknownSet()); EXPECT_THAT(unknown_set, Eq(*result.UnknownSetOrDie())); } @@ -1763,10 +1789,12 @@ TEST(FlatExprBuilderTest, Ternary) { UnknownSet unknown_value1(UnknownAttributeSet({value1_attr})); UnknownSet unknown_value2(UnknownAttributeSet({value2_attr})); CelValue result; - ASSERT_OK(RunTernaryExpression( - CelValue::CreateUnknownSet(&unknown_selector), - CelValue::CreateUnknownSet(&unknown_value1), - CelValue::CreateUnknownSet(&unknown_value2), &arena, &result)); + ASSERT_THAT( + RunTernaryExpression(CelValue::CreateUnknownSet(&unknown_selector), + CelValue::CreateUnknownSet(&unknown_value1), + CelValue::CreateUnknownSet(&unknown_value2), + &arena, &result), + IsOk()); ASSERT_TRUE(result.IsUnknownSet()); const UnknownSet* result_set = result.UnknownSetOrDie(); EXPECT_THAT(result_set->unknown_attributes().size(), Eq(1)); @@ -1783,7 +1811,7 @@ TEST(FlatExprBuilderTest, EmptyCallList) { auto call_expr = expr.mutable_call_expr(); call_expr->set_function(op); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); auto build = builder.CreateExpression(&expr, &source_info); ASSERT_FALSE(build.ok()); } @@ -1849,7 +1877,7 @@ TEST(FlatExprBuilderTest, TypeResolve) { options.enable_qualified_type_identifiers = true; CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); builder.set_container("google.api.expr"); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); ASSERT_OK_AND_ASSIGN(auto expression, builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info())); @@ -2130,11 +2158,11 @@ TEST_P(CustomDescriptorPoolTest, TestType) { google::protobuf::Arena arena; // Setup descriptor pool and builder - ASSERT_OK(AddStandardMessageTypesToDescriptorPool(descriptor_pool)); + ASSERT_THAT(AddStandardMessageTypesToDescriptorPool(descriptor_pool), IsOk()); google::protobuf::DynamicMessageFactory message_factory(&descriptor_pool); ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse("m")); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); - ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); // Create test subject, invoke custom setter for message auto [message, reflection] = diff --git a/eval/compiler/qualified_reference_resolver_test.cc b/eval/compiler/qualified_reference_resolver_test.cc index f69c798af..aa9518ae2 100644 --- a/eval/compiler/qualified_reference_resolver_test.cc +++ b/eval/compiler/qualified_reference_resolver_test.cc @@ -34,6 +34,7 @@ #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_value.h" #include "extensions/protobuf/ast_converters.h" #include "internal/casts.h" #include "internal/proto_matchers.h" @@ -137,8 +138,7 @@ TEST(ResolveReferences, Basic) { CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); @@ -164,8 +164,7 @@ TEST(ResolveReferences, ReturnsFalseIfNoChanges) { CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); @@ -185,8 +184,7 @@ TEST(ResolveReferences, NamespacedIdent) { CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); expr_ast->reference_map()[2].set_name("foo.bar.var1"); expr_ast->reference_map()[7].set_name("namespace_x.bar"); @@ -243,8 +241,7 @@ TEST(ResolveReferences, WarningOnPresenceTest) { CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); expr_ast->reference_map()[1].set_name("foo.bar.var1"); auto result = ResolveReferences(registry, issues, *expr_ast); @@ -293,8 +290,7 @@ TEST(ResolveReferences, EnumConstReferenceUsed) { ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); expr_ast->reference_map()[2].set_name("foo.bar.var1"); expr_ast->reference_map()[5].set_name("bar.foo.Enum.ENUM_VAL1"); expr_ast->reference_map()[5].mutable_value().set_int64_value(9); @@ -327,8 +323,7 @@ TEST(ResolveReferences, EnumConstReferenceUsedSelect) { ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); expr_ast->reference_map()[2].set_name("foo.bar.var1"); expr_ast->reference_map()[2].mutable_value().set_int64_value(2); expr_ast->reference_map()[5].set_name("bar.foo.Enum.ENUM_VAL1"); @@ -361,8 +356,7 @@ TEST(ResolveReferences, ConstReferenceSkipped) { ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); expr_ast->reference_map()[2].set_name("foo.bar.var1"); expr_ast->reference_map()[2].mutable_value().set_bool_value(true); expr_ast->reference_map()[5].set_name("bar.foo.var2"); @@ -430,8 +424,7 @@ TEST(ResolveReferences, FunctionReferenceBasic) { }))); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); IssueCollector issues(RuntimeIssue::Severity::kError); expr_ast->reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); @@ -448,8 +441,7 @@ TEST(ResolveReferences, FunctionReferenceMissingOverloadDetected) { CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); IssueCollector issues(RuntimeIssue::Severity::kError); expr_ast->reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); @@ -485,8 +477,7 @@ TEST(ResolveReferences, SpecialBuiltinsNotWarned) { CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); IssueCollector issues(RuntimeIssue::Severity::kError); expr_ast->reference_map()[1].mutable_overload_id().push_back( absl::StrCat("builtin.", builtin_fn)); @@ -507,8 +498,7 @@ TEST(ResolveReferences, CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); IssueCollector issues(RuntimeIssue::Severity::kError); expr_ast->reference_map()[1].set_name("udf_boolean_and"); @@ -531,8 +521,7 @@ TEST(ResolveReferences, EmulatesEagerFailing) { CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); IssueCollector issues(RuntimeIssue::Severity::kWarning); expr_ast->reference_map()[1].set_name("udf_boolean_and"); @@ -550,8 +539,7 @@ TEST(ResolveReferences, FunctionReferenceToWrongExprKind) { CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); expr_ast->reference_map()[2].mutable_overload_id().push_back( "udf_boolean_and"); @@ -591,8 +579,7 @@ TEST(ResolveReferences, FunctionReferenceWithTargetNoChange) { "boolean_and", true, {CelValue::Type::kBool, CelValue::Type::kBool}))); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); expr_ast->reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); @@ -612,8 +599,7 @@ TEST(ResolveReferences, CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); expr_ast->reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); @@ -635,8 +621,7 @@ TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunction) { "ext.boolean_and", false, {CelValue::Type::kBool}))); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); expr_ast->reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); @@ -669,9 +654,9 @@ TEST(ResolveReferences, ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "com.google.ext.boolean_and", false, {CelValue::Type::kBool}))); cel::TypeRegistry type_registry; + std::vector namespace_prefixes{"com.google.", "google.", ""}; Resolver registry("com.google", func_registry.InternalGetRegistry(), - type_registry, type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry, type_registry.GetComposedTypeProvider()); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); @@ -728,8 +713,7 @@ TEST(ResolveReferences, FunctionReferenceWithHasTargetNoChange) { "ext.option.boolean_and", true, {CelValue::Type::kBool}))); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); expr_ast->reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); @@ -818,8 +802,7 @@ TEST(ResolveReferences, EnumConstReferenceUsedInComprehension) { ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); expr_ast->reference_map()[3].set_name("ENUM"); expr_ast->reference_map()[3].mutable_value().set_int64_value(2); expr_ast->reference_map()[7].set_name("ENUM"); @@ -921,8 +904,7 @@ TEST(ResolveReferences, ReferenceToId0Warns) { ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, - type_registry.GetComposedTypeProvider(), - type_registry.resolveable_enums()); + type_registry.GetComposedTypeProvider()); expr_ast->reference_map()[0].set_name("pkg.var"); IssueCollector issues(RuntimeIssue::Severity::kError); diff --git a/eval/compiler/regex_precompilation_optimization_test.cc b/eval/compiler/regex_precompilation_optimization_test.cc index c0587d5c1..9e05b41d3 100644 --- a/eval/compiler/regex_precompilation_optimization_test.cc +++ b/eval/compiler/regex_precompilation_optimization_test.cc @@ -28,11 +28,14 @@ #include "eval/compiler/constant_folding.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/compiler/flat_expr_builder_extensions.h" +#include "eval/compiler/resolver.h" #include "eval/eval/evaluator_core.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expression.h" +#include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" +#include "eval/public/cel_type_registry.h" #include "eval/public/cel_value.h" #include "internal/testing.h" #include "parser/parser.h" @@ -40,6 +43,7 @@ #include "runtime/internal/runtime_env.h" #include "runtime/internal/runtime_env_testing.h" #include "runtime/runtime_issue.h" +#include "runtime/runtime_options.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { @@ -63,8 +67,7 @@ class RegexPrecompilationExtensionTest : public testing::TestWithParam { function_registry_(*builder_.GetRegistry()), resolver_("", function_registry_.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), - type_registry_.GetTypeProvider(), - type_registry_.resolveable_enums()), + type_registry_.GetTypeProvider()), issue_collector_(RuntimeIssue::Severity::kError) { if (EnableRecursivePlanning()) { options_.max_recursion_depth = -1; diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index d63067257..95388d95a 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -28,7 +28,6 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" -#include "absl/strings/strip.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "common/kind.h" @@ -41,58 +40,41 @@ #include "runtime/type_registry.h" namespace google::api::expr::runtime { +namespace { -using ::cel::IntValue; using ::cel::TypeValue; using ::cel::Value; +using ::cel::runtime_internal::GetEnumValueTable; -Resolver::Resolver( - absl::string_view container, const cel::FunctionRegistry& function_registry, - const cel::TypeRegistry&, const cel::TypeReflector& type_reflector, - const absl::flat_hash_map& - resolveable_enums, - bool resolve_qualified_type_identifiers) - : namespace_prefixes_(), - enum_value_map_(), - function_registry_(function_registry), - type_reflector_(type_reflector), - resolveable_enums_(resolveable_enums), - resolve_qualified_type_identifiers_(resolve_qualified_type_identifiers) { - // The constructor for the registry determines the set of possible namespace - // prefixes which may appear within the given expression container, and also - // eagerly maps possible enum names to enum values. - - auto container_elements = absl::StrSplit(container, '.'); +std::vector MakeNamespaceCandidates(absl::string_view container) { + std::vector namespace_prefixes; std::string prefix = ""; - namespace_prefixes_.push_back(prefix); + namespace_prefixes.push_back(prefix); + auto container_elements = absl::StrSplit(container, '.'); for (const auto& elem : container_elements) { // Tolerate trailing / leading '.'. if (elem.empty()) { continue; } absl::StrAppend(&prefix, elem, "."); - namespace_prefixes_.insert(namespace_prefixes_.begin(), prefix); + // longest prefix first. + namespace_prefixes.insert(namespace_prefixes.begin(), prefix); } + return namespace_prefixes; +} - for (const auto& prefix : namespace_prefixes_) { - for (auto iter = resolveable_enums_.begin(); - iter != resolveable_enums_.end(); ++iter) { - absl::string_view enum_name = iter->first; - if (!absl::StartsWith(enum_name, prefix)) { - continue; - } - - auto remainder = absl::StripPrefix(enum_name, prefix); - const auto& enum_type = iter->second; +} // namespace - for (const auto& enumerator : enum_type.enumerators) { - auto key = absl::StrCat(remainder, !remainder.empty() ? "." : "", - enumerator.name); - enum_value_map_[key] = IntValue(enumerator.number); - } - } - } -} +Resolver::Resolver(absl::string_view container, + const cel::FunctionRegistry& function_registry, + const cel::TypeRegistry& type_registry, + const cel::TypeReflector& type_reflector, + bool resolve_qualified_type_identifiers) + : namespace_prefixes_(MakeNamespaceCandidates(container)), + enum_value_map_(GetEnumValueTable(type_registry)), + function_registry_(function_registry), + type_reflector_(type_reflector), + resolve_qualified_type_identifiers_(resolve_qualified_type_identifiers) {} std::vector Resolver::FullyQualifiedNames(absl::string_view name, int64_t expr_id) const { @@ -102,6 +84,7 @@ std::vector Resolver::FullyQualifiedNames(absl::string_view name, std::vector names; auto prefixes = GetPrefixesFor(name); + names.reserve(prefixes.size()); for (const auto& prefix : prefixes) { std::string fully_qualified_name = absl::StrCat(prefix, name); names.push_back(fully_qualified_name); @@ -125,15 +108,12 @@ absl::optional Resolver::FindConstant(absl::string_view name, for (const auto& prefix : prefixes) { std::string qualified_name = absl::StrCat(prefix, name); // Attempt to resolve the fully qualified name to a known enum. - auto enum_entry = enum_value_map_.find(qualified_name); - if (enum_entry != enum_value_map_.end()) { + auto enum_entry = enum_value_map_->find(qualified_name); + if (enum_entry != enum_value_map_->end()) { return enum_entry->second; } - // Conditionally resolve fully qualified names as type values if the option - // to do so is configured in the expression builder. If the type name is - // not qualified, then it too may be returned as a constant value. - if (resolve_qualified_type_identifiers_ || - !absl::StrContains(qualified_name, ".")) { + // Attempt to resolve the fully qualified name to a known type. + if (resolve_qualified_type_identifiers_) { auto type_value = type_reflector_.FindType(qualified_name); if (type_value.ok() && type_value->has_value()) { return TypeValue(**type_value); @@ -141,6 +121,13 @@ absl::optional Resolver::FindConstant(absl::string_view name, } } + if (!resolve_qualified_type_identifiers_ && !absl::StrContains(name, '.')) { + auto type_value = type_reflector_.FindType(name); + + if (type_value.ok() && type_value->has_value()) { + return TypeValue(**type_value); + } + } return absl::nullopt; } diff --git a/eval/compiler/resolver.h b/eval/compiler/resolver.h index c36fcafb9..fe30c2dd6 100644 --- a/eval/compiler/resolver.h +++ b/eval/compiler/resolver.h @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -27,6 +28,7 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "common/kind.h" +#include "common/type.h" #include "common/type_reflector.h" #include "common/value.h" #include "runtime/function_overload_reference.h" @@ -35,24 +37,23 @@ namespace google::api::expr::runtime { -// Resolver assists with finding functions and types within a container. +// Resolver assists with finding functions and types from the associated +// registries within a container. // -// This class builds on top of the cel::FunctionRegistry and cel::TypeRegistry -// by layering on the namespace resolution rules of CEL onto the calls provided -// by each of these libraries. -// -// TODO: refactor the Resolver to consider CheckedExpr metadata -// for reference resolution. +// container is used to construct the namespace lookup candidates. +// e.g. for "cel.dev" -> {"cel.dev.", "cel.", ""} class Resolver { public: - Resolver( - absl::string_view container, - const cel::FunctionRegistry& function_registry, - const cel::TypeRegistry& type_registry, - const cel::TypeReflector& type_reflector, - const absl::flat_hash_map& - resolveable_enums, - bool resolve_qualified_type_identifiers = true); + Resolver(absl::string_view container, + const cel::FunctionRegistry& function_registry, + const cel::TypeRegistry& type_registry, + const cel::TypeReflector& type_reflector, + bool resolve_qualified_type_identifiers = true); + + Resolver(const Resolver&) = delete; + Resolver& operator=(const Resolver&) = delete; + Resolver(Resolver&&) = delete; + Resolver& operator=(Resolver&&) = delete; ~Resolver() = default; @@ -100,11 +101,10 @@ class Resolver { absl::Span GetPrefixesFor(absl::string_view& name) const; std::vector namespace_prefixes_; - absl::flat_hash_map enum_value_map_; + std::shared_ptr> + enum_value_map_; const cel::FunctionRegistry& function_registry_; const cel::TypeReflector& type_reflector_; - const absl::flat_hash_map& - resolveable_enums_; bool resolve_qualified_type_identifiers_; }; @@ -112,7 +112,7 @@ class Resolver { // ArgumentMatcher generates a function signature matcher for CelFunctions. // TODO: this is the same behavior as parsed exprs in the CPP // evaluator (just check the right call style and number of arguments), but we -// should have enough type information in a checked expr to find a more +// should have enough type information in a checked expr to find a more // specific candidate list. inline std::vector ArgumentsMatcher(int argument_count) { std::vector argument_matcher(argument_count); diff --git a/eval/compiler/resolver_test.cc b/eval/compiler/resolver_test.cc index 6e301ea39..212790b22 100644 --- a/eval/compiler/resolver_test.cc +++ b/eval/compiler/resolver_test.cc @@ -19,7 +19,6 @@ #include #include "absl/status/status.h" -#include "absl/types/optional.h" #include "absl/types/span.h" #include "common/value.h" #include "eval/public/cel_function.h" @@ -29,6 +28,7 @@ #include "eval/testutil/test_message.pb.h" #include "internal/testing.h" #include "google/protobuf/arena.h" +#include "google/protobuf/message.h" namespace google::api::expr::runtime { @@ -61,8 +61,7 @@ TEST_F(ResolverTest, TestFullyQualifiedNames) { CelFunctionRegistry func_registry; Resolver resolver("google.api.expr", func_registry.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), - type_registry_.GetTypeProvider(), - type_registry_.resolveable_enums()); + type_registry_.GetTypeProvider()); auto names = resolver.FullyQualifiedNames("simple_name"); std::vector expected_names( @@ -75,8 +74,7 @@ TEST_F(ResolverTest, TestFullyQualifiedNamesPartiallyQualifiedName) { CelFunctionRegistry func_registry; Resolver resolver("google.api.expr", func_registry.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), - type_registry_.GetTypeProvider(), - type_registry_.resolveable_enums()); + type_registry_.GetTypeProvider()); auto names = resolver.FullyQualifiedNames("expr.simple_name"); std::vector expected_names( @@ -89,8 +87,7 @@ TEST_F(ResolverTest, TestFullyQualifiedNamesAbsoluteName) { CelFunctionRegistry func_registry; Resolver resolver("google.api.expr", func_registry.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), - type_registry_.GetTypeProvider(), - type_registry_.resolveable_enums()); + type_registry_.GetTypeProvider()); auto names = resolver.FullyQualifiedNames(".google.api.expr.absolute_name"); EXPECT_THAT(names.size(), Eq(1)); @@ -104,8 +101,7 @@ TEST_F(ResolverTest, TestFindConstantEnum) { Resolver resolver("google.api.expr.runtime.TestMessage", func_registry.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), - type_registry_.GetTypeProvider(), - type_registry_.resolveable_enums()); + type_registry_.GetTypeProvider()); auto enum_value = resolver.FindConstant("TestEnum.TEST_ENUM_1", -1); ASSERT_TRUE(enum_value); @@ -123,8 +119,7 @@ TEST_F(ResolverTest, TestFindConstantUnqualifiedType) { CelFunctionRegistry func_registry; Resolver resolver("cel", func_registry.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), - type_registry_.GetTypeProvider(), - type_registry_.resolveable_enums()); + type_registry_.GetTypeProvider()); auto type_value = resolver.FindConstant("int", -1); EXPECT_TRUE(type_value); @@ -137,8 +132,7 @@ TEST_F(ResolverTest, TestFindConstantFullyQualifiedType) { CelFunctionRegistry func_registry; Resolver resolver("cel", func_registry.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), - type_registry_.GetTypeProvider(), - type_registry_.resolveable_enums()); + type_registry_.GetTypeProvider()); auto type_value = resolver.FindConstant(".google.api.expr.runtime.TestMessage", -1); @@ -152,8 +146,7 @@ TEST_F(ResolverTest, TestFindConstantQualifiedTypeDisabled) { CelFunctionRegistry func_registry; Resolver resolver("", func_registry.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), - type_registry_.GetTypeProvider(), - type_registry_.resolveable_enums(), false); + type_registry_.GetTypeProvider(), false); auto type_value = resolver.FindConstant(".google.api.expr.runtime.TestMessage", -1); EXPECT_FALSE(type_value); @@ -161,10 +154,10 @@ TEST_F(ResolverTest, TestFindConstantQualifiedTypeDisabled) { TEST_F(ResolverTest, FindTypeBySimpleName) { CelFunctionRegistry func_registry; - Resolver resolver( - "google.api.expr.runtime", func_registry.InternalGetRegistry(), - type_registry_.InternalGetModernRegistry(), - type_registry_.GetTypeProvider(), type_registry_.resolveable_enums()); + Resolver resolver("google.api.expr.runtime", + func_registry.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider()); ASSERT_OK_AND_ASSIGN(auto type, resolver.FindType("TestMessage", -1)); EXPECT_TRUE(type.has_value()); @@ -173,10 +166,10 @@ TEST_F(ResolverTest, FindTypeBySimpleName) { TEST_F(ResolverTest, FindTypeByQualifiedName) { CelFunctionRegistry func_registry; - Resolver resolver( - "google.api.expr.runtime", func_registry.InternalGetRegistry(), - type_registry_.InternalGetModernRegistry(), - type_registry_.GetTypeProvider(), type_registry_.resolveable_enums()); + Resolver resolver("google.api.expr.runtime", + func_registry.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider()); ASSERT_OK_AND_ASSIGN( auto type, resolver.FindType(".google.api.expr.runtime.TestMessage", -1)); @@ -186,10 +179,10 @@ TEST_F(ResolverTest, FindTypeByQualifiedName) { TEST_F(ResolverTest, TestFindDescriptorNotFound) { CelFunctionRegistry func_registry; - Resolver resolver( - "google.api.expr.runtime", func_registry.InternalGetRegistry(), - type_registry_.InternalGetModernRegistry(), - type_registry_.GetTypeProvider(), type_registry_.resolveable_enums()); + Resolver resolver("google.api.expr.runtime", + func_registry.InternalGetRegistry(), + type_registry_.InternalGetModernRegistry(), + type_registry_.GetTypeProvider()); ASSERT_OK_AND_ASSIGN(auto type, resolver.FindType("UndefinedMessage", -1)); EXPECT_FALSE(type.has_value()) << type->second; @@ -206,8 +199,7 @@ TEST_F(ResolverTest, TestFindOverloads) { Resolver resolver("cel", func_registry.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), - type_registry_.GetTypeProvider(), - type_registry_.resolveable_enums()); + type_registry_.GetTypeProvider()); auto overloads = resolver.FindOverloads("fake_func", false, ArgumentsMatcher(0)); @@ -231,8 +223,7 @@ TEST_F(ResolverTest, TestFindLazyOverloads) { Resolver resolver("cel", func_registry.InternalGetRegistry(), type_registry_.InternalGetModernRegistry(), - type_registry_.GetTypeProvider(), - type_registry_.resolveable_enums()); + type_registry_.GetTypeProvider()); auto overloads = resolver.FindLazyOverloads("fake_lazy_func", false, ArgumentsMatcher(0)); diff --git a/runtime/BUILD b/runtime/BUILD index 22f854057..0c32fbdce 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -170,12 +170,15 @@ cc_library( deps = [ "//base:data", "//common:type", + "//common:value", "//runtime/internal:legacy_runtime_type_provider", "//runtime/internal:runtime_type_provider", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_protobuf//:protobuf", ], ) diff --git a/runtime/type_registry.cc b/runtime/type_registry.cc index 73a31d62c..f0520d4ef 100644 --- a/runtime/type_registry.cc +++ b/runtime/type_registry.cc @@ -14,13 +14,17 @@ #include "runtime/type_registry.h" +#include #include #include #include #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "common/value.h" #include "runtime/internal/legacy_runtime_type_provider.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -39,8 +43,42 @@ TypeRegistry::TypeRegistry( void TypeRegistry::RegisterEnum(absl::string_view enum_name, std::vector enumerators) { + { + absl::MutexLock lock(&enum_value_table_mutex_); + enum_value_table_.reset(); + } enum_types_[enum_name] = Enumeration{std::string(enum_name), std::move(enumerators)}; } +std::shared_ptr> +TypeRegistry::GetEnumValueTable() const { + { + absl::ReaderMutexLock lock(&enum_value_table_mutex_); + if (enum_value_table_ != nullptr) { + return enum_value_table_; + } + } + + absl::MutexLock lock(&enum_value_table_mutex_); + if (enum_value_table_ != nullptr) { + return enum_value_table_; + } + std::shared_ptr> result = + std::make_shared>(); + + auto& enum_value_map = *result; + for (auto iter = enum_types_.begin(); iter != enum_types_.end(); ++iter) { + absl::string_view enum_name = iter->first; + const auto& enum_type = iter->second; + for (const auto& enumerator : enum_type.enumerators) { + auto key = absl::StrCat(enum_name, ".", enumerator.name); + enum_value_map[key] = cel::IntValue(enumerator.number); + } + } + + enum_value_table_ = result; + + return result; +} } // namespace cel diff --git a/runtime/type_registry.h b/runtime/type_registry.h index 8f3c9b06d..2b247946c 100644 --- a/runtime/type_registry.h +++ b/runtime/type_registry.h @@ -21,11 +21,14 @@ #include #include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "base/type_provider.h" #include "common/type.h" +#include "common/value.h" #include "runtime/internal/legacy_runtime_type_provider.h" #include "runtime/internal/runtime_type_provider.h" #include "google/protobuf/descriptor.h" @@ -40,6 +43,12 @@ const RuntimeTypeProvider& GetRuntimeTypeProvider( const TypeRegistry& type_registry); const absl::Nonnull>& GetLegacyRuntimeTypeProvider(const TypeRegistry& type_registry); + +// Returns a memoized table of fully qualified enum values. +// +// This is populated when first requested. +std::shared_ptr> +GetEnumValueTable(const TypeRegistry& type_registry); } // namespace runtime_internal // TypeRegistry manages composing TypeProviders used with a Runtime. @@ -100,10 +109,29 @@ class TypeRegistry { runtime_internal::GetLegacyRuntimeTypeProvider( const TypeRegistry& type_registry); + friend std::shared_ptr> + runtime_internal::GetEnumValueTable(const TypeRegistry& type_registry); + + std::shared_ptr> + GetEnumValueTable() const; + runtime_internal::RuntimeTypeProvider type_provider_; absl::Nonnull> legacy_type_provider_; absl::flat_hash_map enum_types_; + + // memoized fully qualified enumerator names. + // + // populated when requested. + // + // In almost all cases, this is built once and never updated, but we can't + // guarantee that with the current CelExpressionBuilder API. + // + // The cases when invalidation may occur are likely already race conditions, + // but we provide basic thread safety to avoid issues with sanitizers. + mutable std::shared_ptr> + enum_value_table_ ABSL_GUARDED_BY(enum_value_table_mutex_); + mutable absl::Mutex enum_value_table_mutex_; }; namespace runtime_internal { @@ -115,6 +143,11 @@ inline const absl::Nonnull>& GetLegacyRuntimeTypeProvider(const TypeRegistry& type_registry) { return type_registry.legacy_type_provider_; } +inline std::shared_ptr> +GetEnumValueTable(const TypeRegistry& type_registry) { + return type_registry.GetEnumValueTable(); +} + } // namespace runtime_internal } // namespace cel