Skip to content

Commit

Permalink
Add least and greatest macros for Math extension library
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 586060412
  • Loading branch information
l46kok authored and copybara-github committed Nov 28, 2023
1 parent 0d158a3 commit ef81f13
Show file tree
Hide file tree
Showing 4 changed files with 402 additions and 5 deletions.
8 changes: 8 additions & 0 deletions extensions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@ cc_library(
"//eval/public:cel_value",
"//eval/public:portable_cel_function_adapter",
"//internal:status_macros",
"//parser:macro",
"//parser:source_factory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto",
"@com_google_protobuf//:protobuf",
],
)
Expand All @@ -37,15 +41,19 @@ cc_test(
deps = [
":math_ext",
"//eval/public:activation",
"//eval/public:builtin_func_registrar",
"//eval/public:cel_expr_builder_factory",
"//eval/public:cel_expression",
"//eval/public:cel_function",
"//eval/public:cel_options",
"//eval/public:cel_value",
"//eval/public/containers:container_backed_list_impl",
"//eval/public/testing:matchers",
"//internal:testing",
"//parser",
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto",
"@com_google_protobuf//:protobuf",
],
Expand Down
187 changes: 185 additions & 2 deletions extensions/math_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,32 @@

#include <algorithm>
#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

#include "google/api/expr/v1alpha1/syntax.pb.h"
#include "google/protobuf/arena.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_number.h"
#include "eval/public/cel_options.h"
#include "eval/public/cel_value.h"
#include "eval/public/portable_cel_function_adapter.h"
#include "internal/status_macros.h"
#include "parser/macro.h"
#include "parser/source_factory.h"

namespace cel::extensions {

namespace {

using ::google::api::expr::v1alpha1::Expr;
using ::google::api::expr::parser::SourceFactory;
using ::google::api::expr::runtime::CelFunctionRegistry;
using ::google::api::expr::runtime::CelList;
using ::google::api::expr::runtime::CelNumber;
Expand All @@ -43,8 +53,21 @@ using ::google::api::expr::runtime::PortableBinaryFunctionAdapter;
using ::google::api::expr::runtime::PortableUnaryFunctionAdapter;
using ::google::protobuf::Arena;

constexpr absl::string_view kMathMin = "math.@min";
constexpr absl::string_view kMathMax = "math.@max";
static constexpr absl::string_view kMathNamespace = "math";
static constexpr absl::string_view kLeast = "least";
static constexpr absl::string_view kGreatest = "greatest";

static constexpr char kMathMin[] = "math.@min";
static constexpr char kMathMax[] = "math.@max";

bool isTargetNamespace(const Expr &target) {
switch (target.expr_kind_case()) {
case Expr::kIdentExpr:
return target.ident_expr().name() == kMathNamespace;
default:
return false;
}
}

struct ToValueVisitor {
CelValue operator()(uint64_t v) const { return CelValue::CreateUint64(v); }
Expand Down Expand Up @@ -77,6 +100,11 @@ CelValue MinValue(CelNumber v1, CelNumber v2) {
return NumberToValue(MinNumber(v1, v2));
}

template <typename T>
CelValue Identity(Arena *arena, T v1) {
return NumberToValue(CelNumber(v1));
}

template <typename T, typename U>
CelValue Min(Arena *arena, T v1, U v2) {
return MinValue(CelNumber(v1), CelNumber(v2));
Expand Down Expand Up @@ -174,10 +202,71 @@ absl::Status RegisterCrossNumericMax(CelFunctionRegistry *registry) {
return absl::OkStatus();
}

bool isValidArgType(const Expr &arg) {
switch (arg.expr_kind_case()) {
case google::api::expr::v1alpha1::Expr::kConstExpr:
if (arg.const_expr().has_double_value() ||
arg.const_expr().has_int64_value() ||
arg.const_expr().has_uint64_value()) {
return true;
}
return false;
case google::api::expr::v1alpha1::Expr::kListExpr:
case google::api::expr::v1alpha1::Expr::kStructExpr: // fall through
return false;
default:
return true;
}
}

absl::optional<Expr> checkInvalidArgs(const std::shared_ptr<SourceFactory> &sf,
const absl::string_view macro,
const std::vector<Expr> &args) {
for (const auto &arg : args) {
if (!isValidArgType(arg)) {
return absl::optional<Expr>(sf->ReportError(
arg.id(),
absl::StrCat(macro, " simple literal arguments must be numeric")));
}
}

return absl::nullopt;
}

bool isListLiteralWithValidArgs(const Expr &arg) {
switch (arg.expr_kind_case()) {
case google::api::expr::v1alpha1::Expr::kListExpr: {
const auto &list_expr = arg.list_expr();
if (list_expr.elements().empty()) {
return false;
}

for (const auto &elem : list_expr.elements()) {
if (!isValidArgType(elem)) {
return false;
}
}
return true;
}
default: {
return false;
}
}
}

} // namespace

absl::Status RegisterMathExtensionFunctions(CelFunctionRegistry *registry,
const InterpreterOptions &options) {
CEL_RETURN_IF_ERROR(registry->Register(
PortableUnaryFunctionAdapter<CelValue, int64_t>::Create(
kMathMin, /*receiver_style=*/false, &Identity<int64_t>)));
CEL_RETURN_IF_ERROR(
registry->Register(PortableUnaryFunctionAdapter<CelValue, double>::Create(
kMathMin, /*receiver_style=*/false, &Identity<double>)));
CEL_RETURN_IF_ERROR(registry->Register(
PortableUnaryFunctionAdapter<CelValue, uint64_t>::Create(
kMathMin, /*receiver_style=*/false, &Identity<uint64_t>)));
CEL_RETURN_IF_ERROR(registry->Register(
PortableBinaryFunctionAdapter<CelValue, int64_t, int64_t>::Create(
kMathMin, /*receiver_style=*/false, &Min<int64_t, int64_t>)));
Expand All @@ -194,6 +283,15 @@ absl::Status RegisterMathExtensionFunctions(CelFunctionRegistry *registry,
PortableUnaryFunctionAdapter<CelValue, const CelList *>::Create(
kMathMin, false, MinList)));

CEL_RETURN_IF_ERROR(registry->Register(
PortableUnaryFunctionAdapter<CelValue, int64_t>::Create(
kMathMax, /*receiver_style=*/false, &Identity<int64_t>)));
CEL_RETURN_IF_ERROR(
registry->Register(PortableUnaryFunctionAdapter<CelValue, double>::Create(
kMathMax, /*receiver_style=*/false, &Identity<double>)));
CEL_RETURN_IF_ERROR(registry->Register(
PortableUnaryFunctionAdapter<CelValue, uint64_t>::Create(
kMathMax, /*receiver_style=*/false, &Identity<uint64_t>)));
CEL_RETURN_IF_ERROR(registry->Register(
PortableBinaryFunctionAdapter<CelValue, int64_t, int64_t>::Create(
kMathMax, /*receiver_style=*/false, &Max<int64_t, int64_t>)));
Expand All @@ -213,4 +311,89 @@ absl::Status RegisterMathExtensionFunctions(CelFunctionRegistry *registry,
return absl::OkStatus();
}

std::vector<Macro> math_macros() {
absl::StatusOr<Macro> least = Macro::ReceiverVarArg(
kLeast, [](const std::shared_ptr<SourceFactory> &sf, int64_t macro_id,
const Expr &target, const std::vector<Expr> &args) {
if (!isTargetNamespace(target)) {
return Expr();
}

switch (args.size()) {
case 0:
return sf->ReportError(
target.id(), "math.least() requires at least one argument.");
case 1: {
if (!isListLiteralWithValidArgs(args[0]) &&
!isValidArgType(args[0])) {
return sf->ReportError(
args[0].id(), "math.least() invalid single argument value.");
}

return sf->NewGlobalCallForMacro(target.id(), kMathMin, args);
}
case 2: {
auto error = checkInvalidArgs(sf, "math.least()", args);
if (error.has_value()) {
return *error;
}

return sf->NewGlobalCallForMacro(target.id(), kMathMin, args);
}
default:
auto error = checkInvalidArgs(sf, "math.least()", args);
if (error.has_value()) {
return *error;
}

return sf->NewGlobalCallForMacro(
target.id(), kMathMin,
{sf->NewList(sf->NextMacroId(macro_id), args)});
}
});
absl::StatusOr<Macro> greatest = Macro::ReceiverVarArg(
kGreatest, [](const std::shared_ptr<SourceFactory> &sf, int64_t macro_id,
const Expr &target, const std::vector<Expr> &args) {
if (!isTargetNamespace(target)) {
return Expr();
}

switch (args.size()) {
case 0: {
return sf->ReportError(
target.id(), "math.greatest() requires at least one argument.");
}
case 1: {
if (!isListLiteralWithValidArgs(args[0]) &&
!isValidArgType(args[0])) {
return sf->ReportError(
args[0].id(),
"math.greatest() invalid single argument value.");
}

return sf->NewGlobalCallForMacro(target.id(), kMathMax, args);
}
case 2: {
auto error = checkInvalidArgs(sf, "math.greatest()", args);
if (error.has_value()) {
return *error;
}
return sf->NewGlobalCallForMacro(target.id(), kMathMax, args);
}
default: {
auto error = checkInvalidArgs(sf, "math.greatest()", args);
if (error.has_value()) {
return *error;
}

return sf->NewGlobalCallForMacro(
target.id(), kMathMax,
{sf->NewList(sf->NextMacroId(macro_id), args)});
}
}
});

return {*least, *greatest};
}

} // namespace cel::extensions
7 changes: 7 additions & 0 deletions extensions/math_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,19 @@
#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_H_
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_MATH_EXT_H_

#include <vector>

#include "absl/status/status.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "parser/macro.h"

namespace cel::extensions {

// math_macros() returns the namespaced helper macros for math.least() and
// math.greatest().
std::vector<Macro> math_macros();

// Register extension functions for supporting mathematical operations above
// and beyond the set defined in the CEL standard environment.
absl::Status RegisterMathExtensionFunctions(
Expand Down
Loading

0 comments on commit ef81f13

Please sign in to comment.