From 7e6b6329af6b4025279c95608aab5de7c5608e82 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Mon, 27 Nov 2023 14:39:53 -0800 Subject: [PATCH] Add least and greatest macros for Math extension library PiperOrigin-RevId: 585764103 --- extensions/BUILD | 8 ++ extensions/math_ext.cc | 187 +++++++++++++++++++++++++++++++- extensions/math_ext.h | 7 ++ extensions/math_ext_test.cc | 205 +++++++++++++++++++++++++++++++++++- 4 files changed, 402 insertions(+), 5 deletions(-) diff --git a/extensions/BUILD b/extensions/BUILD index 33ab163a7..6e3be1616 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -24,9 +24,13 @@ cc_library( "//eval/public:cel_value", "//eval/public:portable_cel_function_adapter", "//internal:status_macros", + "//parser:macro", + "//parser:source_factory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -37,15 +41,19 @@ cc_test( deps = [ ":math_ext", "//eval/public:activation", + "//eval/public:builtin_func_registrar", "//eval/public:cel_expr_builder_factory", "//eval/public:cel_expression", + "//eval/public:cel_function", "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public/containers:container_backed_list_impl", "//eval/public/testing:matchers", "//internal:testing", + "//parser", "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], diff --git a/extensions/math_ext.cc b/extensions/math_ext.cc index ce7731bcb..a5b43b6e9 100644 --- a/extensions/math_ext.cc +++ b/extensions/math_ext.cc @@ -16,22 +16,32 @@ #include #include +#include +#include +#include +#include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/arena.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_number.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/portable_cel_function_adapter.h" #include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/source_factory.h" namespace cel::extensions { namespace { +using ::google::api::expr::v1alpha1::Expr; +using ::google::api::expr::parser::SourceFactory; using ::google::api::expr::runtime::CelFunctionRegistry; using ::google::api::expr::runtime::CelList; using ::google::api::expr::runtime::CelNumber; @@ -43,8 +53,21 @@ using ::google::api::expr::runtime::PortableBinaryFunctionAdapter; using ::google::api::expr::runtime::PortableUnaryFunctionAdapter; using ::google::protobuf::Arena; -constexpr absl::string_view kMathMin = "math.@min"; -constexpr absl::string_view kMathMax = "math.@max"; +static constexpr absl::string_view kMathNamespace = "math"; +static constexpr absl::string_view kLeast = "least"; +static constexpr absl::string_view kGreatest = "greatest"; + +static constexpr char kMathMin[] = "math.@min"; +static constexpr char kMathMax[] = "math.@max"; + +bool isTargetNamespace(const Expr &target) { + switch (target.expr_kind_case()) { + case Expr::kIdentExpr: + return target.ident_expr().name() == kMathNamespace; + default: + return false; + } +} struct ToValueVisitor { CelValue operator()(uint64_t v) const { return CelValue::CreateUint64(v); } @@ -77,6 +100,11 @@ CelValue MinValue(CelNumber v1, CelNumber v2) { return NumberToValue(MinNumber(v1, v2)); } +template +CelValue Identity(Arena *arena, T v1) { + return NumberToValue(CelNumber(v1)); +} + template CelValue Min(Arena *arena, T v1, U v2) { return MinValue(CelNumber(v1), CelNumber(v2)); @@ -174,10 +202,71 @@ absl::Status RegisterCrossNumericMax(CelFunctionRegistry *registry) { return absl::OkStatus(); } +bool isValidArgType(const Expr &arg) { + switch (arg.expr_kind_case()) { + case google::api::expr::v1alpha1::Expr::kConstExpr: + if (arg.const_expr().has_double_value() || + arg.const_expr().has_int64_value() || + arg.const_expr().has_uint64_value()) { + return true; + } + return false; + case google::api::expr::v1alpha1::Expr::kListExpr: + case google::api::expr::v1alpha1::Expr::kStructExpr: // fall through + return false; + default: + return true; + } +} + +absl::optional checkInvalidArgs(const std::shared_ptr &sf, + const absl::string_view macro, + const std::vector &args) { + for (const auto &arg : args) { + if (!isValidArgType(arg)) { + return absl::optional(sf->ReportError( + arg.id(), + absl::StrCat(macro, " simple literal arguments must be numeric"))); + } + } + + return absl::nullopt; +} + +bool isListLiteralWithValidArgs(const Expr &arg) { + switch (arg.expr_kind_case()) { + case google::api::expr::v1alpha1::Expr::kListExpr: { + const auto &list_expr = arg.list_expr(); + if (list_expr.elements().empty()) { + return false; + } + + for (const auto &elem : list_expr.elements()) { + if (!isValidArgType(elem)) { + return false; + } + } + return true; + } + default: { + return false; + } + } +} + } // namespace absl::Status RegisterMathExtensionFunctions(CelFunctionRegistry *registry, const InterpreterOptions &options) { + CEL_RETURN_IF_ERROR(registry->Register( + PortableUnaryFunctionAdapter::Create( + kMathMin, /*receiver_style=*/false, &Identity))); + CEL_RETURN_IF_ERROR( + registry->Register(PortableUnaryFunctionAdapter::Create( + kMathMin, /*receiver_style=*/false, &Identity))); + CEL_RETURN_IF_ERROR(registry->Register( + PortableUnaryFunctionAdapter::Create( + kMathMin, /*receiver_style=*/false, &Identity))); CEL_RETURN_IF_ERROR(registry->Register( PortableBinaryFunctionAdapter::Create( kMathMin, /*receiver_style=*/false, &Min))); @@ -194,6 +283,15 @@ absl::Status RegisterMathExtensionFunctions(CelFunctionRegistry *registry, PortableUnaryFunctionAdapter::Create( kMathMin, false, MinList))); + CEL_RETURN_IF_ERROR(registry->Register( + PortableUnaryFunctionAdapter::Create( + kMathMax, /*receiver_style=*/false, &Identity))); + CEL_RETURN_IF_ERROR( + registry->Register(PortableUnaryFunctionAdapter::Create( + kMathMax, /*receiver_style=*/false, &Identity))); + CEL_RETURN_IF_ERROR(registry->Register( + PortableUnaryFunctionAdapter::Create( + kMathMax, /*receiver_style=*/false, &Identity))); CEL_RETURN_IF_ERROR(registry->Register( PortableBinaryFunctionAdapter::Create( kMathMax, /*receiver_style=*/false, &Max))); @@ -213,4 +311,89 @@ absl::Status RegisterMathExtensionFunctions(CelFunctionRegistry *registry, return absl::OkStatus(); } +std::vector math_macros() { + absl::StatusOr least = Macro::ReceiverVarArg( + kLeast, [](const std::shared_ptr &sf, int64_t macro_id, + const Expr &target, const std::vector &args) { + if (!isTargetNamespace(target)) { + return Expr(); + } + + switch (args.size()) { + case 0: + return sf->ReportError( + target.id(), "math.least() requires at least one argument."); + case 1: { + if (!isListLiteralWithValidArgs(args[0]) && + !isValidArgType(args[0])) { + return sf->ReportError( + args[0].id(), "math.least() invalid single argument value."); + } + + return sf->NewGlobalCallForMacro(target.id(), kMathMin, args); + } + case 2: { + auto error = checkInvalidArgs(sf, "math.least()", args); + if (error.has_value()) { + return *error; + } + + return sf->NewGlobalCallForMacro(target.id(), kMathMin, args); + } + default: + auto error = checkInvalidArgs(sf, "math.least()", args); + if (error.has_value()) { + return *error; + } + + return sf->NewGlobalCallForMacro( + target.id(), kMathMin, + {sf->NewList(sf->NextMacroId(macro_id), args)}); + } + }); + absl::StatusOr greatest = Macro::ReceiverVarArg( + kGreatest, [](const std::shared_ptr &sf, int64_t macro_id, + const Expr &target, const std::vector &args) { + if (!isTargetNamespace(target)) { + return Expr(); + } + + switch (args.size()) { + case 0: { + return sf->ReportError( + target.id(), "math.greatest() requires at least one argument."); + } + case 1: { + if (!isListLiteralWithValidArgs(args[0]) && + !isValidArgType(args[0])) { + return sf->ReportError( + args[0].id(), + "math.greatest() invalid single argument value."); + } + + return sf->NewGlobalCallForMacro(target.id(), kMathMax, args); + } + case 2: { + auto error = checkInvalidArgs(sf, "math.greatest()", args); + if (error.has_value()) { + return *error; + } + return sf->NewGlobalCallForMacro(target.id(), kMathMax, args); + } + default: { + auto error = checkInvalidArgs(sf, "math.greatest()", args); + if (error.has_value()) { + return *error; + } + + return sf->NewGlobalCallForMacro( + target.id(), kMathMax, + {sf->NewList(sf->NextMacroId(macro_id), args)}); + } + } + }); + + return {*least, *greatest}; +} + } // namespace cel::extensions diff --git a/extensions/math_ext.h b/extensions/math_ext.h index 06ae317a1..fcd9aab3d 100644 --- a/extensions/math_ext.h +++ b/extensions/math_ext.h @@ -15,12 +15,19 @@ #ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_H_ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_H_ +#include + #include "absl/status/status.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" +#include "parser/macro.h" namespace cel::extensions { +// math_macros() returns the namespaced helper macros for math.least() and +// math.greatest(). +std::vector math_macros(); + // Register extension functions for supporting mathematical operations above // and beyond the set defined in the CEL standard environment. absl::Status RegisterMathExtensionFunctions( diff --git a/extensions/math_ext_test.cc b/extensions/math_ext_test.cc index 4584dd122..febe1100b 100644 --- a/extensions/math_ext_test.cc +++ b/extensions/math_ext_test.cc @@ -15,34 +15,42 @@ #include "extensions/math_ext.h" #include -#include #include "google/api/expr/v1alpha1/syntax.pb.h" -#include "google/protobuf/arena.h" #include "absl/status/status.h" #include "absl/types/optional.h" +#include "absl/types/span.h" #include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_expression.h" +#include "eval/public/cel_function.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/testing/matchers.h" #include "internal/testing.h" +#include "parser/parser.h" +#include "google/protobuf/arena.h" namespace cel::extensions { namespace { using ::google::api::expr::v1alpha1::Expr; +using ::google::api::expr::v1alpha1::ParsedExpr; using ::google::api::expr::v1alpha1::SourceInfo; +using ::google::api::expr::parser::ParseWithMacros; using ::google::api::expr::runtime::Activation; using ::google::api::expr::runtime::CelExpressionBuilder; +using ::google::api::expr::runtime::CelFunction; +using ::google::api::expr::runtime::CelFunctionDescriptor; using ::google::api::expr::runtime::CelValue; using ::google::api::expr::runtime::ContainerBackedListImpl; using ::google::api::expr::runtime::CreateCelExpressionBuilder; using ::google::api::expr::runtime::InterpreterOptions; - +using ::google::api::expr::runtime::RegisterBuiltinFunctions; using ::google::api::expr::runtime::test::EqualsCelValue; +using ::google::protobuf::Arena; using testing::HasSubstr; using cel::internal::StatusIs; @@ -72,6 +80,37 @@ TestCase MaxCase(CelValue list, CelValue result) { return TestCase{kMathMax, list, absl::nullopt, result}; } +struct MacroTestCase { + absl::string_view expr; + absl::string_view err = ""; +}; + +class TestFunction : public CelFunction { + public: + explicit TestFunction(absl::string_view name) + : CelFunction(CelFunctionDescriptor( + name, true, + {CelValue::Type::kBool, CelValue::Type::kInt64, + CelValue::Type::kInt64})) {} + + absl::Status Evaluate(absl::Span args, CelValue* result, + Arena* arena) const override { + *result = CelValue::CreateBool(true); + return absl::OkStatus(); + } +}; + +// Test function used to test macro collision and non-expansion. +constexpr absl::string_view kGreatest = "greatest"; +std::unique_ptr CreateGreatestFunction() { + return std::make_unique(kGreatest); +} + +constexpr absl::string_view kLeast = "least"; +std::unique_ptr CreateLeastFunction() { + return std::make_unique(kLeast); +} + Expr CallExprOneArg(absl::string_view operation) { Expr expr; auto call = expr.mutable_call_expr(); @@ -235,5 +274,165 @@ TEST(MathExtTest, MinMaxList) { ExpectResult(MaxCase(CelValue::CreateList(&bad_middle_item), err_value)); } +using MathExtMacroParamsTest = testing::TestWithParam; +TEST_P(MathExtMacroParamsTest, MacroTests) { + const MacroTestCase& test_case = GetParam(); + auto result = ParseWithMacros(test_case.expr, cel::extensions::math_macros(), + ""); + if (!test_case.err.empty()) { + EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(test_case.err))); + return; + } + ASSERT_OK(result); + + ParsedExpr parsed_expr = *result; + Expr expr = parsed_expr.expr(); + SourceInfo source_info = parsed_expr.source_info(); + InterpreterOptions options; + std::unique_ptr builder = + CreateCelExpressionBuilder(options); + ASSERT_OK(builder->GetRegistry()->Register(CreateGreatestFunction())); + ASSERT_OK(builder->GetRegistry()->Register(CreateLeastFunction())); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); + ASSERT_OK(RegisterMathExtensionFunctions(builder->GetRegistry(), options)); + ASSERT_OK_AND_ASSIGN(auto cel_expression, + builder->CreateExpression(&expr, &source_info)); + + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, + cel_expression->Evaluate(activation, &arena)); + + ASSERT_TRUE(value.IsBool()); + EXPECT_EQ(value.BoolOrDie(), true); +} + +INSTANTIATE_TEST_SUITE_P( + MathExtMacrosParamsTest, MathExtMacroParamsTest, + testing::ValuesIn({ + // Tests for math.least + {"math.least(-0.5) == -0.5"}, + {"math.least(-1) == -1"}, + {"math.least(1u) == 1u"}, + {"math.least(42.0, -0.5) == -0.5"}, + {"math.least(-1, 0) == -1"}, + {"math.least(-1, -1) == -1"}, + {"math.least(1u, 42u) == 1u"}, + {"math.least(42.0, -0.5, -0.25) == -0.5"}, + {"math.least(-1, 0, 1) == -1"}, + {"math.least(-1, -1, -1) == -1"}, + {"math.least(1u, 42u, 0u) == 0u"}, + // math.least two arg overloads across type. + {"math.least(1, 1.0) == 1"}, + {"math.least(1, -2.0) == -2.0"}, + {"math.least(2, 1u) == 1u"}, + {"math.least(1.5, 2) == 1.5"}, + {"math.least(1.5, -2) == -2"}, + {"math.least(2.5, 1u) == 1u"}, + {"math.least(1u, 2) == 1u"}, + {"math.least(1u, -2) == -2"}, + {"math.least(2u, 2.5) == 2u"}, + // math.least with dynamic values across type. + {"math.least(1u, dyn(42)) == 1"}, + {"math.least(1u, dyn(42), dyn(0.0)) == 0u"}, + // math.least with a list literal. + {"math.least([1u, 42u, 0u]) == 0u"}, + // math.least errors + { + "math.least()", + "math.least() requires at least one argument.", + }, + { + "math.least('hello')", + "math.least() invalid single argument value.", + }, + { + "math.least({})", + "math.least() invalid single argument value", + }, + { + "math.least([])", + "math.least() invalid single argument value", + }, + { + "math.least([1, true])", + "math.least() invalid single argument value", + }, + { + "math.least(1, true)", + "math.least() simple literal arguments must be numeric", + }, + { + "math.least(1, 2, true)", + "math.least() simple literal arguments must be numeric", + }, + + // Tests for math.greatest + {"math.greatest(-0.5) == -0.5"}, + {"math.greatest(-1) == -1"}, + {"math.greatest(1u) == 1u"}, + {"math.greatest(42.0, -0.5) == 42.0"}, + {"math.greatest(-1, 0) == 0"}, + {"math.greatest(-1, -1) == -1"}, + {"math.greatest(1u, 42u) == 42u"}, + {"math.greatest(42.0, -0.5, -0.25) == 42.0"}, + {"math.greatest(-1, 0, 1) == 1"}, + {"math.greatest(-1, -1, -1) == -1"}, + {"math.greatest(1u, 42u, 0u) == 42u"}, + // math.least two arg overloads across type. + {"math.greatest(1, 1.0) == 1"}, + {"math.greatest(1, -2.0) == 1"}, + {"math.greatest(2, 1u) == 2"}, + {"math.greatest(1.5, 2) == 2"}, + {"math.greatest(1.5, -2) == 1.5"}, + {"math.greatest(2.5, 1u) == 2.5"}, + {"math.greatest(1u, 2) == 2"}, + {"math.greatest(1u, -2) == 1u"}, + {"math.greatest(2u, 2.5) == 2.5"}, + // math.greatest with dynamic values across type. + {"math.greatest(1u, dyn(42)) == 42.0"}, + {"math.greatest(1u, dyn(0.0), 0u) == 1"}, + // math.greatest with a list literal + {"math.greatest([1u, dyn(0.0), 0u]) == 1"}, + // math.greatest errors + { + "math.greatest()", + "math.greatest() requires at least one argument.", + }, + { + "math.greatest('hello')", + "math.greatest() invalid single argument value.", + }, + { + "math.greatest({})", + "math.greatest() invalid single argument value", + }, + { + "math.greatest([])", + "math.greatest() invalid single argument value", + }, + { + "math.greatest([1, true])", + "math.greatest() invalid single argument value", + }, + { + "math.greatest(1, true)", + "math.greatest() simple literal arguments must be numeric", + }, + { + "math.greatest(1, 2, true)", + "math.greatest() simple literal arguments must be numeric", + }, + // Call signatures which trigger macro expansion, but which do not + // get expanded. The function just returns true. + { + "false.greatest(1,2)", + }, + { + "true.least(1,2)", + }, + })); + } // namespace } // namespace cel::extensions