diff --git a/.bazelversion b/.bazelversion new file mode 100644 index 000000000..25939d35c --- /dev/null +++ b/.bazelversion @@ -0,0 +1 @@ +0.29.1 diff --git a/.travis.yml b/.travis.yml index 9bb6919e7..8f7651b4b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,8 +11,8 @@ addons: - clang-8 before_install: - - wget https://github.com/bazelbuild/bazel/releases/download/0.28.1/bazel_0.28.1-linux-x86_64.deb - - sudo dpkg -i bazel_0.28.1-linux-x86_64.deb + - wget https://github.com/bazelbuild/bazel/releases/download/0.29.1/bazel_0.29.1-linux-x86_64.deb + - sudo dpkg -i bazel_0.29.1-linux-x86_64.deb script: - export CC=clang-8 diff --git a/WORKSPACE b/WORKSPACE index 07b1e34f0..c9893a040 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -63,11 +63,13 @@ http_archive( urls = ["https://github.com/google/benchmark/archive/master.zip"], ) -CEL_SPEC_SHA="dd75cc98926a52975d303c9a635f18ab0aa1f2b8" +CEL_SPEC_GIT_SHA="b154461b3a037f9654852087ef96be2b756871a0" # 10/16/2019 +CEL_SPEC_SHA="a88cf903fc890cb8e53048365d05a5c0c03e35148b03812de7a471d7d2ff8744" http_archive( name = "com_google_cel_spec", - strip_prefix = "cel-spec-" + CEL_SPEC_SHA, - urls = ["https://github.com/google/cel-spec/archive/" + CEL_SPEC_SHA + ".zip"], + sha256 = CEL_SPEC_SHA, + strip_prefix = "cel-spec-" + CEL_SPEC_GIT_SHA, + urls = ["https://github.com/google/cel-spec/archive/" + CEL_SPEC_GIT_SHA + ".zip"], ) # Google RE2 (Regular Expression) C++ Library @@ -77,12 +79,14 @@ http_archive( urls = ["https://github.com/google/re2/archive/master.zip"], ) -GOOGLEAPIS_SHA = "184ab77f4cee62332f8f9a689c70c9bea441f836" +GOOGLEAPIS_GIT_SHA = "be480e391cc88a75cf2a81960ef79c80d5012068" # Jul 24, 2019 +GOOGLEAPIS_SHA = "c1969e5b72eab6d9b6cfcff748e45ba57294aeea1d96fd04cd081995de0605c2" + http_archive( name = "com_google_googleapis", - sha256 = "a3a8c83314e5a431473659cb342a11e5520c6de4790eee70633d578f278b1e73", - strip_prefix = "googleapis-" + GOOGLEAPIS_SHA, - urls = ["https://github.com/googleapis/googleapis/archive/" + GOOGLEAPIS_SHA + ".tar.gz"], + sha256 = GOOGLEAPIS_SHA, + strip_prefix = "googleapis-" + GOOGLEAPIS_GIT_SHA, + urls = ["https://github.com/googleapis/googleapis/archive/" + GOOGLEAPIS_GIT_SHA + ".tar.gz"], ) load("@com_google_googleapis//:repository_rules.bzl", "switched_rules_by_language") diff --git a/common/escaping.cc b/common/escaping.cc new file mode 100644 index 000000000..b3c7cdf5e --- /dev/null +++ b/common/escaping.cc @@ -0,0 +1,361 @@ +#include "common/escaping.h" + +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" +#include "util/utf8/public/unicodetext.h" + +namespace google { +namespace api { +namespace expr { +namespace parser { + +inline std::pair unhex(char c) { + if ('0' <= c && c <= '9') { + return std::make_pair(c - '0', true); + } + if ('a' <= c && c <= 'f') { + return std::make_pair(c - 'a' + 10, true); + } + if ('A' <= c && c <= 'F') { + return std::make_pair(c - 'A' + 10, true); + } + return std::make_pair(0, false); +} + +// unescape_char takes a string input and returns the following info: +// +// value - the escaped unicode rune at the front of the string. +// encode - the value should be unicode-encoded +// tail - the remainder of the input string. +// err - error value, if the character could not be unescaped. +// +// When encode is true the return value may still fit within a single byte, +// but unicode encoding is attempted which is more expensive than when the +// value is known to self-represent as a single byte. +// +// If is_bytes is set, unescape as a bytes literal so octal and hex escapes +// represent byte values, not unicode code points. +inline std::tuple unescape_char( + std::string_view s, bool is_bytes) { + char c = s[0]; + + // 1. Character is not an escape sequence. + if (c >= 0x80 && !is_bytes) { + UnicodeText ut; + ut.PointToUTF8(s.data(), s.size()); + auto r = ut.begin(); + char tmp[5]; + int len = r.get_utf8(tmp); + tmp[len] = '\0'; + return std::make_tuple(std::string(tmp), s.substr(len), ""); + } else if (c != '\\') { + char tmp[2] = {c, '\0'}; + return std::make_tuple(std::string(tmp), s.substr(1), ""); + } + + // 2. Last character is the start of an escape sequence. + if (s.size() <= 1) { + return std::make_tuple("", s, + "unable to unescape string, " + "found '\\' as last character"); + } + + c = s[1]; + s = s.substr(2); + + char32_t value; + bool encode = false; + + // 3. Common escape sequences shared with Google SQL + switch (c) { + case 'a': + value = '\a'; + break; + case 'b': + value = '\b'; + break; + case 'f': + value = '\f'; + break; + case 'n': + value = '\n'; + break; + case 'r': + value = '\r'; + break; + case 't': + value = '\t'; + break; + case 'v': + value = '\v'; + break; + case '\\': + value = '\\'; + break; + case '\'': + value = '\''; + break; + case '"': + value = '"'; + break; + case '`': + value = '`'; + break; + case '?': + value = '?'; + break; + + // 4. Unicode escape sequences, reproduced from `strconv/quote.go` + case 'x': + [[fallthrough]]; + case 'X': + [[fallthrough]]; + case 'u': + [[fallthrough]]; + case 'U': { + int n = 0; + encode = true; + switch (c) { + case 'x': + [[fallthrough]]; + case 'X': + n = 2; + encode = !is_bytes; + break; + case 'u': + n = 4; + if (is_bytes) { + return std::make_tuple("", s, + "unable to unescape string " + "(\\u in bytes)"); + } + break; + case 'U': + n = 8; + if (is_bytes) { + return std::make_tuple("", s, + "unable to unescape string " + "(\\U in bytes)"); + } + break; + } + char32_t v = 0; + if (s.size() < n) { + return std::make_tuple("", s, + "unable to unescape string " + "(string too short after \\xXuU)"); + } + for (int j = 0; j < n; ++j) { + auto x = unhex(s[j]); + if (!x.second) { + return std::make_tuple("", s, + "unable to unescape string " + "(invalid hex)"); + } + v = v << 4 | x.first; + } + s = s.substr(n); + if (!is_bytes && v > 0x0010FFFF) { + return std::make_tuple("", s, + "unable to unescape string" + "(value out of bounds)"); + } + value = v; + break; + } + + // 5. Octal escape sequences, must be three digits \[0-3][0-7][0-7] + case '0': + [[fallthrough]]; + case '1': + [[fallthrough]]; + case '2': + [[fallthrough]]; + case '3': { + if (s.size() < 2) { + return std::make_tuple("", s, + "unable to unescape octal sequence in string"); + } + char32_t v = c - '0'; + for (int j = 0; j < 2; ++j) { + char x = s[j]; + if (x < '0' || x > '7') { + return std::make_tuple("", s, + "unable to unescape octal sequence " + "in string"); + } + v = v * 8 + (x - '0'); + } + if (!is_bytes && v > 0x0010FFFF) { + return std::make_tuple("", s, "unable to unescape string"); + } + value = v; + s = s.substr(2); + encode = !is_bytes; + } break; + + // Unknown escape sequence. + default: + return std::make_tuple("", s, "unable to unescape string"); + } + + if (value < 0x80 || !encode) { + char tmp[2] = {(char)value, '\0'}; + return std::make_tuple(std::string(tmp), s, ""); + } else { + UnicodeText ut; + ut.push_back(value); + return std::make_tuple(ut.begin().get_utf8_string(), s, ""); + } +} + +// Unescape takes a quoted string, unquotes, and unescapes it. +std::optional unescape(const std::string& s, bool is_bytes) { + // All strings normalize newlines to the \n representation. + std::string value = absl::StrReplaceAll(s, {{"\r\n", "\n"}, {"\r", "\n"}}); + + size_t n = value.size(); + + // Nothing to unescape / decode. + if (n < 2) { + return std::make_optional(value); + } + + // Raw string preceded by the 'r|R' prefix. + bool is_raw_literal = false; + if (value[0] == 'r' || value[0] == 'R') { + value.resize(value.size() - 1); + n = value.size(); + is_raw_literal = true; + } + + // Quoted string of some form, must have same first and last char. + if (value[0] != value[n - 1] || (value[0] != '"' && value[0] != '\'')) { + return std::optional(); + } + + // Normalize the multi-line CEL string representation to a standard + // Google SQL or Go quoted string, as accepted by CEL. + if (n >= 6) { + if (absl::StartsWith(value, "'''")) { + if (!absl::EndsWith(value, "'''")) { + return std::optional(); + } + value = "\"" + value.substr(3, n - 6) + "\""; + } else if (absl::StartsWith(value, "\"\"\"")) { + if (!absl::EndsWith(value, "\"\"\"")) { + return std::optional(); + } + value = "\"" + value.substr(3, n - 6) + "\""; + } + n = value.size(); + } + value = value.substr(1, n - 2); + // If there is nothing to escape, then return. + if (is_raw_literal || (value.find("\\") == std::string::npos)) { + return value; + } + + if (is_bytes) { + // first convert byte values the non-UTF8 way + std::string new_value; + for (std::string::size_type i = 0; i < value.size() - 1; ++i) { + if (value[i] == '\\') { + if (value[i + 1] == 'x' || value[i + 1] == 'X') { + if (i > (std::numeric_limits::max() - 3) || + i + 3 >= value.size()) { + return std::optional(); + } + char v = 0; + for (int j = 2; j <= 3; ++j) { + auto x = unhex(value[i + j]); + v = v << 4 | x.first; + } + i += 3; + new_value += v; + } else if (value[i + 1] == '0' || value[i + 1] == '1' || + value[i + 1] == '2' || value[i + 1] == '3') { + if (i > (std::numeric_limits::max() - 3) || + i + 3 >= value.size()) { + return std::optional(); + } + char v = value[i + 1] - '0'; + for (int j = 1; j <= 3; ++j) { + char x = value[i + j]; + if (x < '0' || x > '7') { + return std::optional(); + } + v = v * 8 + (x - '0'); + } + i += 3; + new_value += v; + } else { + return std::optional(); + } + } else { + new_value += value[i]; + } + } + value = std::move(new_value); + } + + std::string unescaped; + unescaped.reserve(3 * value.size() / 2); + std::string_view value_sv(value); + while (!value_sv.empty()) { + std::tuple c = + unescape_char(value_sv, is_bytes); + if (!std::get<2>(c).empty()) { + return std::optional(); + } + + unescaped.append(std::get<0>(c)); + value_sv = std::get<1>(c); + } + return std::make_optional(unescaped); +} + +std::string escapeAndQuote(std::string_view str) { + const std::string lowerhex = "0123456789abcdef"; + + std::string s; + for (auto c : str) { + switch (c) { + case '\a': + s.append("\\a"); + break; + case '\b': + s.append("\\b"); + break; + case '\f': + s.append("\\f"); + break; + case '\n': + s.append("\\n"); + break; + case '\r': + s.append("\\r"); + break; + case '\t': + s.append("\\t"); + break; + case '\v': + s.append("\\v"); + break; + case '"': + s.append("\\\""); + break; + default: + s += c; + break; + } + } + return absl::StrFormat("\"%s\"", s); +} + +} // namespace parser +} // namespace expr +} // namespace api +} // namespace google diff --git a/common/escaping.h b/common/escaping.h new file mode 100644 index 000000000..327591b72 --- /dev/null +++ b/common/escaping.h @@ -0,0 +1,23 @@ +#ifndef THIRD_PARTY_CEL_CPP_PARSER_UNESCAPE_H_ +#define THIRD_PARTY_CEL_CPP_PARSER_UNESCAPE_H_ + +#include +#include + +namespace google { +namespace api { +namespace expr { +namespace parser { + +// Unescape takes a quoted string, unquotes, and unescapes it. +std::optional unescape(const std::string& s, bool is_bytes); + +// Takes a string, and escapes values according to CEL and quotes +std::string escapeAndQuote(std::string_view str); + +} // namespace parser +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_PARSER_UNESCAPE_H_ diff --git a/common/escaping_test.cc b/common/escaping_test.cc new file mode 100644 index 000000000..dc58e8a02 --- /dev/null +++ b/common/escaping_test.cc @@ -0,0 +1,101 @@ +#include "common/escaping.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace google { +namespace api { +namespace expr { +namespace parser { +namespace { + +using testing::Eq; +using testing::Ne; + +struct TestInfo { + static constexpr char EXPECT_ERROR[] = "--ERROR--"; + + TestInfo(const std::string& I, const std::string& O, bool is_bytes = false) + : I(I), O(O), is_bytes(is_bytes) {} + + // Input string + std::string I; + + // Expected output string + std::string O; + + // Indicator whether this is a byte or text string + bool is_bytes; +}; + +std::vector test_cases = { + {"'hello'", "hello"}, + {R"("")", ""}, + {R"("\\\"")", R"(\")"}, + {R"("\\")", "\\"}, + {"'''x''x'''", "x''x"}, + {R"("""x""x""")", R"(x""x)"}, + // Octal 303 -> Code point 195 (Ã) + // Octal 277 -> Code point 191 (¿) + {R"("\303\277")", "ÿ"}, + // Octal 377 -> Code point 255 (ÿ) + {R"("\377")", "ÿ"}, + {R"("\u263A\u263A")", "☺☺"}, + {R"("\a\b\f\n\r\t\v\'\"\\\? Legal escapes")", + "\a\b\f\n\r\t\v'\"\\? Legal escapes"}, + // Illegal escape, expect error + {R"("\a\b\f\n\r\t\v\'\\"\\\? Illegal escape \>")", TestInfo::EXPECT_ERROR}, + {R"("\u1")", TestInfo::EXPECT_ERROR}, + + // The following are interpreted as byte sequences, hence "true" + {"\"abc\"", "\x61\x62\x63", true}, + {"\"ÿ\"", "\xc3\xbf", true}, + {R"("\303\277")", "\xc3\xbf", true}, + {R"("\377")", "\xff", true}, + {R"("\xc3\xbf")", "\xc3\xbf", true}, + {R"("\xff")", "\xff", true}, + // Bytes unicode escape, expect error + {R"("\u00ff")", TestInfo::EXPECT_ERROR, true}, + {R"("\z")", TestInfo::EXPECT_ERROR, true}, + {R"("\x1")", TestInfo::EXPECT_ERROR, true}, + {R"("\u1")", TestInfo::EXPECT_ERROR, true}, +}; + +class UnescapeTest : public testing::TestWithParam {}; + +TEST_P(UnescapeTest, Unescape) { + const TestInfo& test_info = GetParam(); + ::testing::internal::ColoredPrintf(::testing::internal::COLOR_GREEN, + "[ ]"); + ::testing::internal::ColoredPrintf(::testing::internal::COLOR_DEFAULT, + " Input: "); + ::testing::internal::ColoredPrintf(::testing::internal::COLOR_YELLOW, "%s%s", + test_info.I.c_str(), + test_info.is_bytes ? " BYTES" : ""); + if (test_info.O != TestInfo::EXPECT_ERROR) { + ::testing::internal::ColoredPrintf(::testing::internal::COLOR_DEFAULT, + " Expected Output: "); + ::testing::internal::ColoredPrintf(::testing::internal::COLOR_YELLOW, + "%s\n", test_info.O.c_str()); + } else { + ::testing::internal::ColoredPrintf(::testing::internal::COLOR_YELLOW, + " Expecting ERROR\n"); + } + + auto result = unescape(test_info.I, test_info.is_bytes); + if (test_info.O == TestInfo::EXPECT_ERROR) { + EXPECT_THAT(result, Eq(std::nullopt)); + } else { + ASSERT_THAT(result, Ne(std::nullopt)); + EXPECT_EQ(*result, test_info.O); + } +} + +INSTANTIATE_TEST_SUITE_P(UnescapeSuite, UnescapeTest, + testing::ValuesIn(test_cases)); + +} // namespace +} // namespace parser +} // namespace expr +} // namespace api +} // namespace google diff --git a/common/operators.cc b/common/operators.cc new file mode 100644 index 000000000..89fa3b10e --- /dev/null +++ b/common/operators.cc @@ -0,0 +1,202 @@ +#include "common/operators.h" + +#include + +namespace google { +namespace api { +namespace expr { +namespace common { + +namespace { +// These functions provide access to reverse mappings for operators. +// Functions generally map from text expression to Expr representation, +// e.g., from "&&" to "_&&_". Reverse operators provides a mapping from +// Expr to textual mapping, e.g., from "_&&_" to "&&". + +const std::map& UnaryOperators() { + static std::shared_ptr> unaries_map = [&]() { + auto u = + std::make_shared>(std::map{ + {CelOperator::NEGATE, "-"}, {CelOperator::LOGICAL_NOT, "!"}}); + return u; + }(); + return *unaries_map; +} + +const std::map& BinaryOperators() { + static std::shared_ptr> binops_map = [&]() { + auto c = std::make_shared>( + std::map{{CelOperator::LOGICAL_OR, "||"}, + {CelOperator::LOGICAL_AND, "&&"}, + {CelOperator::LESS_EQUALS, "<="}, + {CelOperator::LESS, "<"}, + {CelOperator::GREATER_EQUALS, ">="}, + {CelOperator::GREATER, ">"}, + {CelOperator::EQUALS, "=="}, + {CelOperator::NOT_EQUALS, "!="}, + {CelOperator::IN_DEPRECATED, "in"}, + {CelOperator::IN, "in"}, + {CelOperator::ADD, "+"}, + {CelOperator::SUBTRACT, "-"}, + {CelOperator::MULTIPLY, "*"}, + {CelOperator::DIVIDE, "/"}, + {CelOperator::MODULO, "%"}}); + return c; + }(); + return *binops_map; +} + +const std::map& ReverseOperators() { + static std::shared_ptr> operators_map = [&]() { + auto c = + std::make_shared>(std::map{ + {"+", CelOperator::ADD}, + {"-", CelOperator::SUBTRACT}, + {"*", CelOperator::MULTIPLY}, + {"/", CelOperator::DIVIDE}, + {"%", CelOperator::MODULO}, + {"==", CelOperator::EQUALS}, + {"!=", CelOperator::NOT_EQUALS}, + {">", CelOperator::GREATER}, + {">=", CelOperator::GREATER_EQUALS}, + {"<", CelOperator::LESS}, + {"<=", CelOperator::LESS_EQUALS}, + {"&&", CelOperator::LOGICAL_AND}, + {"!", CelOperator::LOGICAL_NOT}, + {"||", CelOperator::LOGICAL_OR}, + {"in", CelOperator::IN}, + }); + return c; + }(); + return *operators_map; +} + +const std::map& Operators() { + static std::shared_ptr> operators_map = [&]() { + auto c = std::make_shared>( + std::map{{CelOperator::ADD, "+"}, + {CelOperator::SUBTRACT, "-"}, + {CelOperator::MULTIPLY, "*"}, + {CelOperator::DIVIDE, "/"}, + {CelOperator::MODULO, "%"}, + {CelOperator::EQUALS, "=="}, + {CelOperator::NOT_EQUALS, "!="}, + {CelOperator::GREATER, ">"}, + {CelOperator::GREATER_EQUALS, ">="}, + {CelOperator::LESS, "<"}, + {CelOperator::LESS_EQUALS, "<="}, + {CelOperator::LOGICAL_AND, "&&"}, + {CelOperator::LOGICAL_NOT, "!"}, + {CelOperator::LOGICAL_OR, "||"}, + {CelOperator::IN, "in"}, + {CelOperator::IN_DEPRECATED, "in"}, + {CelOperator::NEGATE, "-"}}); + return c; + }(); + return *operators_map; +} + +// precedence of the operator, where the higher value means higher. +const std::map& Precedences() { + static std::shared_ptr> precedence_map = [&]() { + auto c = std::make_shared>( + std::map{{CelOperator::CONDITIONAL, 8}, + + {CelOperator::LOGICAL_OR, 7}, + + {CelOperator::LOGICAL_AND, 6}, + + {CelOperator::EQUALS, 5}, + {CelOperator::GREATER, 5}, + {CelOperator::GREATER_EQUALS, 5}, + {CelOperator::IN, 5}, + {CelOperator::LESS, 5}, + {CelOperator::LESS_EQUALS, 5}, + {CelOperator::NOT_EQUALS, 5}, + {CelOperator::IN_DEPRECATED, 5}, + + {CelOperator::ADD, 4}, + {CelOperator::SUBTRACT, 4}, + + {CelOperator::DIVIDE, 3}, + {CelOperator::MODULO, 3}, + {CelOperator::MULTIPLY, 3}, + + {CelOperator::LOGICAL_NOT, 2}, + {CelOperator::NEGATE, 2}, + + {CelOperator::INDEX, 1}}); + return c; + }(); + return *precedence_map; +} + +} // namespace + +int LookupPrecedence(const std::string& op) { + auto precs = Precedences(); + auto p = precs.find(op); + if (p != precs.end()) { + return p->second; + } + return 0; +} + +std::optional LookupUnaryOperator(const std::string& op) { + auto unary_ops = UnaryOperators(); + auto o = unary_ops.find(op); + if (o == unary_ops.end()) { + return std::optional(); + } + return o->second; +} + +std::optional LookupBinaryOperator(const std::string& op) { + auto bin_ops = BinaryOperators(); + auto o = bin_ops.find(op); + if (o == bin_ops.end()) { + return std::optional(); + } + return o->second; +} + +std::optional LookupOperator(const std::string& op) { + auto ops = Operators(); + auto o = ops.find(op); + if (o == ops.end()) { + return std::optional(); + } + return o->second; +} + +std::optional ReverseLookupOperator(const std::string& op) { + auto rev_ops = ReverseOperators(); + auto o = rev_ops.find(op); + if (o == rev_ops.end()) { + return std::optional(); + } + return o->second; +} + +bool IsOperatorSamePrecedence(const std::string& op, const Expr& expr) { + if (!expr.has_call_expr()) { + return false; + } + return LookupPrecedence(op) == LookupPrecedence(expr.call_expr().function()); +} + +bool IsOperatorLowerPrecedence(const std::string& op, const Expr& expr) { + if (!expr.has_call_expr()) { + return false; + } + return LookupPrecedence(op) < LookupPrecedence(expr.call_expr().function()); +} + +bool IsOperatorLeftRecursive(const std::string& op) { + return op != CelOperator::LOGICAL_AND && op != CelOperator::LOGICAL_OR; +} + +} // namespace common +} // namespace expr +} // namespace api +} // namespace google diff --git a/common/operators.h b/common/operators.h new file mode 100644 index 000000000..2ad0233f2 --- /dev/null +++ b/common/operators.h @@ -0,0 +1,70 @@ +#ifndef THIRD_PARTY_CEL_CPP_COMMON_OPERATORS_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_OPERATORS_H_ + +#include +#include + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "absl/strings/string_view.h" + +namespace google { +namespace api { +namespace expr { +namespace common { + +// Operator function names. +struct CelOperator { + static constexpr const char* CONDITIONAL = "_?_:_"; + static constexpr const char* LOGICAL_AND = "_&&_"; + static constexpr const char* LOGICAL_OR = "_||_"; + static constexpr const char* LOGICAL_NOT = "!_"; + static constexpr const char* IN_DEPRECATED = "_in_"; + static constexpr const char* EQUALS = "_==_"; + static constexpr const char* NOT_EQUALS = "_!=_"; + static constexpr const char* LESS = "_<_"; + static constexpr const char* LESS_EQUALS = "_<=_"; + static constexpr const char* GREATER = "_>_"; + static constexpr const char* GREATER_EQUALS = "_>=_"; + static constexpr const char* ADD = "_+_"; + static constexpr const char* SUBTRACT = "_-_"; + static constexpr const char* MULTIPLY = "_*_"; + static constexpr const char* DIVIDE = "_/_"; + static constexpr const char* MODULO = "_%_"; + static constexpr const char* NEGATE = "-_"; + static constexpr const char* INDEX = "_[_]"; + // Macros + static constexpr const char* HAS = "has"; + static constexpr const char* ALL = "all"; + static constexpr const char* EXISTS = "exists"; + static constexpr const char* EXISTS_ONE = "exists_one"; + static constexpr const char* MAP = "map"; + static constexpr const char* FILTER = "filter"; + + // Named operators, must not have be valid identifiers. + static constexpr const char* NOT_STRICTLY_FALSE = "@not_strictly_false"; + static constexpr const char* IN = "@in"; +}; + +// These give access to all or some specific precedence value. +// Higher value means higher precedence, 0 means no precedence, i.e., +// custom function and not builtin operator. +int LookupPrecedence(const std::string& op); + +std::optional LookupUnaryOperator(const std::string& op); +std::optional LookupBinaryOperator(const std::string& op); +std::optional LookupOperator(const std::string& op); +std::optional ReverseLookupOperator(const std::string& op); + +// returns true if op has a lower precedence than the one expressed in expr +bool IsOperatorLowerPrecedence(const std::string& op, const Expr& expr); +// returns true if op has the same precedence as the one expressed in expr +bool IsOperatorSamePrecedence(const std::string& op, const Expr& expr); +// return true if operator is left recursive, i.e., neither && nor ||. +bool IsOperatorLeftRecursive(const std::string& op); + +} // namespace common +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_COMMON_OPERATORS_H_ diff --git a/conformance/BUILD b/conformance/BUILD index 8c8a5366a..64a8a03ba 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -6,18 +6,14 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 ALL_TESTS = [ - # Requires container support - #"@com_google_cel_spec//tests/simple:testdata/basic.textproto", + "@com_google_cel_spec//tests/simple:testdata/basic.textproto", "@com_google_cel_spec//tests/simple:testdata/fp_math.textproto", - # Overflow is not handled - # "@com_google_cel_spec//tests/simple:testdata/integer_math.textproto", + "@com_google_cel_spec//tests/simple:testdata/integer_math.textproto", "@com_google_cel_spec//tests/simple:testdata/logic.textproto", "@com_google_cel_spec//tests/simple:testdata/plumbing.textproto", "@com_google_cel_spec//tests/simple:testdata/string.textproto", - # Requires heteregenous equality spec clarification - #"@com_google_cel_spec//tests/simple:testdata/comparisons.textproto", - # Requires qualified bindings error message relaxation - #"@com_google_cel_spec//tests/simple:testdata/fields.textproto", + "@com_google_cel_spec//tests/simple:testdata/comparisons.textproto", + "@com_google_cel_spec//tests/simple:testdata/fields.textproto", ] DASHBOARD_TESTS = [ @@ -58,6 +54,16 @@ cc_binary( "$(location @com_google_cel_spec//tests/simple:simple_test)", "--server=$(location @com_google_cel_go//server/main:cel_server)", "--eval_server=$(location :server)", + # Requires container support + "--skip_test=basic/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked", + # Requires heteregenous equality spec clarification + "--skip_test=comparisons/eq_literal/eq_bytes", + "--skip_test=comparisons/ne_literal/not_ne_bytes", + "--skip_test=comparisons/in_list_literal/elem_in_mixed_type_list_error", + "--skip_test=comparisons/in_map_literal/key_in_mixed_key_type_map_error", + # Requires qualified bindings error message relaxation + "--skip_test=fields/qualified_identifier_resolution/ident_with_longest_prefix_check,int64_field_select_unsupported,list_field_select_unsupported,map_key_null,qualified_identifier_resolution_unchecked", + "--skip_test=integer_math/int64_math/int64_overflow_positive,int64_overflow_negative,uint64_overflow_positive,uint64_overflow_negative", ] + ["$(location " + test + ")" for test in ALL_TESTS], data = [ ":server", diff --git a/conformance/server.cc b/conformance/server.cc index 96334fc20..69cdd225a 100644 --- a/conformance/server.cc +++ b/conformance/server.cc @@ -250,7 +250,6 @@ class ConformanceServiceImpl final int RunServer(std::string server_address) { google::protobuf::Arena arena; InterpreterOptions options; - options.partial_string_match = true; const char* enable_constant_folding = getenv("CEL_CPP_ENABLE_CONSTANT_FOLDING"); diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index a5b9d72a3..1652ceb4f 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -23,6 +23,7 @@ cc_library( "//base:status", "//eval/eval:comprehension_step", "//eval/eval:const_value_step", + "//eval/eval:container_access_step", "//eval/eval:create_list_step", "//eval/eval:create_struct_step", "//eval/eval:evaluator_core", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 5a1006dbb..91109ba6a 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -7,6 +7,7 @@ #include "eval/compiler/constant_folding.h" #include "eval/eval/comprehension_step.h" #include "eval/eval/const_value_step.h" +#include "eval/eval/container_access_step.h" #include "eval/eval/create_list_step.h" #include "eval/eval/create_struct_step.h" #include "eval/eval/evaluator_core.h" @@ -46,13 +47,15 @@ class FlatExprVisitor : public AstVisitor { bool shortcircuiting, const std::set& enums, absl::string_view container, - const absl::flat_hash_map& constant_idents) + const absl::flat_hash_map& constant_idents, + bool enable_comprehension) : flattened_path_(path), progress_status_(cel_base::OkStatus()), resolved_select_expr_(nullptr), function_registry_(function_registry), shortcircuiting_(shortcircuiting), - constant_idents_(constant_idents) { + constant_idents_(constant_idents), + enable_comprehension_(enable_comprehension) { auto container_elements = absl::StrSplit(container, '.'); // Build list of prefixes from container. Non-empty prefixes must end with @@ -257,6 +260,11 @@ class FlatExprVisitor : public AstVisitor { cond_visitor->PostVisit(expr); cond_visitor_stack_.pop(); } else { + // Special case for "_[_]". + if (call_expr->function() == builtin::kIndex) { + AddStep(CreateContainerAccessStep(call_expr, expr->id())); + return; + } // For regular functions, just create one based on registry. AddStep(CreateFunctionStep(call_expr, expr->id(), *function_registry_)); } @@ -268,6 +276,10 @@ class FlatExprVisitor : public AstVisitor { if (!progress_status_.ok()) { return; } + if (!enable_comprehension_) { + SetProgressStatusError(cel_base::Status(cel_base::StatusCode::kInvalidArgument, + "Comprehension support is disabled")); + } cond_visitor_stack_.emplace(expr, absl::make_unique(this)); auto cond_visitor = FindCondVisitor(expr); @@ -445,6 +457,8 @@ class FlatExprVisitor : public AstVisitor { bool shortcircuiting_; const absl::flat_hash_map& constant_idents_; + + bool enable_comprehension_; }; void FlatExprVisitor::BinaryCondVisitor::PreVisit(const Expr* expr) {} @@ -652,7 +666,7 @@ FlatExprBuilder::CreateExpression(const Expr* expr, FlatExprVisitor visitor(this->GetRegistry(), &execution_path, shortcircuiting_, resolvable_enums(), container(), - idents); + idents, enable_comprehension_); AstTraverse(constant_folding_ ? &out : expr, source_info, &visitor); @@ -661,7 +675,8 @@ FlatExprBuilder::CreateExpression(const Expr* expr, } std::unique_ptr expression_impl = - absl::make_unique(expr, std::move(execution_path)); + absl::make_unique(expr, std::move(execution_path), + comprehension_max_iterations_); return std::move(expression_impl); } diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 65ef91bd8..48beaee18 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -16,7 +16,9 @@ class FlatExprBuilder : public CelExpressionBuilder { FlatExprBuilder() : shortcircuiting_(true), constant_folding_(false), - constant_arena_(nullptr) {} + constant_arena_(nullptr), + enable_comprehension_(true), + comprehension_max_iterations_(0) {} // set_shortcircuiting regulates shortcircuiting of some expressions. // Be default shortcircuiting is enabled. @@ -29,6 +31,14 @@ class FlatExprBuilder : public CelExpressionBuilder { constant_arena_ = arena; } + void set_enable_comprehension(bool enabled) { + enable_comprehension_ = enabled; + } + + void set_comprehension_max_iterations(int max_iterations) { + comprehension_max_iterations_ = max_iterations; + } + cel_base::StatusOr> CreateExpression( const google::api::expr::v1alpha1::Expr* expr, const google::api::expr::v1alpha1::SourceInfo* source_info) const override; @@ -38,6 +48,8 @@ class FlatExprBuilder : public CelExpressionBuilder { bool constant_folding_; google::protobuf::Arena* constant_arena_; + bool enable_comprehension_; + int comprehension_max_iterations_; }; } // namespace runtime diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index ce8f45be3..48f6a725d 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -445,6 +445,58 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { Eq("No matching overloads found")); } +TEST(FlatExprBuilderTest, ComprehensionBudget) { + Expr expr; + // [1, 2].all(x, x > 0) + google::protobuf::TextFormat::ParseFromString(R"( + comprehension_expr { + iter_var: "k" + accu_var: "accu" + accu_init { + const_expr { bool_value: true } + } + loop_condition { ident_expr { name: "accu" } } + result { ident_expr { name: "accu" } } + loop_step { + call_expr { + function: "_&&_" + args { + ident_expr { name: "accu" } + } + args { + call_expr { + function: "_>_" + args { ident_expr { name: "k" } } + args { const_expr { int64_value: 0 } } + } + } + } + } + iter_range { + list_expr { + { const_expr { int64_value: 1 } } + { const_expr { int64_value: 2 } } + } + } + })", + &expr); + + FlatExprBuilder builder; + builder.set_comprehension_max_iterations(1); + ASSERT_TRUE(RegisterBuiltinFunctions(builder.GetRegistry()).ok()); + SourceInfo source_info; + auto build_status = builder.CreateExpression(&expr, &source_info); + ASSERT_TRUE(build_status.ok()); + + auto cel_expr = std::move(build_status.ValueOrDie()); + + Activation activation; + google::protobuf::Arena arena; + auto result_or = cel_expr->Evaluate(activation, &arena); + ASSERT_FALSE(result_or.ok()); + EXPECT_THAT(result_or.status().message(), Eq("Iteration budget exceeded")); +} + TEST(FlatExprBuilderTest, UnknownSupportTest) { TestMessage message; diff --git a/eval/eval/BUILD b/eval/eval/BUILD index e04f3d67e..e9652556c 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -59,6 +59,25 @@ cc_library( ], ) +cc_library( + name = "container_access_step", + srcs = [ + "container_access_step.cc", + ], + hdrs = [ + "container_access_step.h", + ], + deps = [ + ":evaluator_core", + ":expression_step_base", + "//base:status", + "//eval/public:activation", + "//eval/public:cel_value", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "ident_step", srcs = [ @@ -106,6 +125,7 @@ cc_library( deps = [ "//base:status", "//eval/public:cel_value", + "//internal:proto_util", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], @@ -328,6 +348,24 @@ cc_test( ], ) +cc_test( + name = "container_access_step_test", + size = "small", + srcs = [ + "container_access_step_test.cc", + ], + deps = [ + ":container_access_step", + ":container_backed_list_impl", + ":container_backed_map_impl", + ":ident_step", + "//eval/public:cel_builtins", + "//eval/public:cel_value", + "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( name = "ident_step_test", size = "small", diff --git a/eval/eval/comprehension_step.cc b/eval/eval/comprehension_step.cc index 8e80a0dbf..402ae3fb1 100644 --- a/eval/eval/comprehension_step.cc +++ b/eval/eval/comprehension_step.cc @@ -97,6 +97,10 @@ cel_base::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { ); return cel_base::Status(cel_base::StatusCode::kInternal, message); } + auto increment_status = frame->IncrementIterations(); + if (!increment_status.ok()) { + return increment_status; + } int64_t current_index = current_index_value.Int64OrDie(); CelValue loop_step = state[POS_LOOP_STEP]; frame->value_stack().Pop(5); diff --git a/eval/eval/const_value_step_test.cc b/eval/eval/const_value_step_test.cc index dda20daac..d2da6a773 100644 --- a/eval/eval/const_value_step_test.cc +++ b/eval/eval/const_value_step_test.cc @@ -31,7 +31,7 @@ cel_base::StatusOr RunConstantExpression(const Expr* expr, google::api::expr::v1alpha1::Expr dummy_expr; - CelExpressionFlatImpl impl(&dummy_expr, std::move(path)); + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0); Activation activation; diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc new file mode 100644 index 000000000..fa383f1fe --- /dev/null +++ b/eval/eval/container_access_step.cc @@ -0,0 +1,135 @@ +#include "eval/eval/container_access_step.h" + +#include "google/protobuf/arena.h" +#include "absl/strings/str_cat.h" +#include "eval/eval/expression_step_base.h" +#include "eval/public/cel_value.h" +#include "base/status.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +// ContainerAccessStep performs message field access specified by Expr::Select +// message. +class ContainerAccessStep : public ExpressionStepBase { + public: + ContainerAccessStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} + + cel_base::Status Evaluate(ExecutionFrame* frame) const override; + + private: + CelValue PerformLookup(const CelValue& container, const CelValue& key, + google::protobuf::Arena* arena) const; + CelValue LookupInMap(const CelMap* cel_map, const CelValue& key, + google::protobuf::Arena* arena) const; + CelValue LookupInList(const CelList* cel_list, const CelValue& key, + google::protobuf::Arena* arena) const; +}; + +inline CelValue ContainerAccessStep::LookupInMap(const CelMap* cel_map, + const CelValue& key, + google::protobuf::Arena* arena) const { + switch (key.type()) { + case CelValue::Type::kBool: + case CelValue::Type::kInt64: + case CelValue::Type::kUint64: + case CelValue::Type::kString: { + absl::optional maybe_value = (*cel_map)[key]; + if (maybe_value.has_value()) { + return maybe_value.value(); + } + break; + } + default: { + break; + } + } + return CreateNoSuchKeyError(arena, absl::StrCat("Key not found in map")); +} + +inline CelValue ContainerAccessStep::LookupInList(const CelList* cel_list, + const CelValue& key, + google::protobuf::Arena* arena) const { + switch (key.type()) { + case CelValue::Type::kInt64: { + int64_t idx = key.Int64OrDie(); + if (idx < 0 || idx >= cel_list->size()) { + return CreateErrorValue(arena, + absl::StrCat("Index error: index=", idx, + " size=", cel_list->size())); + } + return (*cel_list)[idx]; + } + default: { + return CreateErrorValue( + arena, absl::StrCat("Index error: expected integer type, got ", + CelValue::TypeName(key.type()))); + } + } +} + +CelValue ContainerAccessStep::PerformLookup(const CelValue& container, + const CelValue& key, + google::protobuf::Arena* arena) const { + if (container.IsError()) { + return container; + } + if (key.IsError()) { + return key; + } + // Select steps can be applied to either maps or messages + switch (container.type()) { + case CelValue::Type::kMap: { + const CelMap* cel_map = container.MapOrDie(); + return LookupInMap(cel_map, key, arena); + } + case CelValue::Type::kList: { + const CelList* cel_list = container.ListOrDie(); + return LookupInList(cel_list, key, arena); + } + default: { + return CreateErrorValue( + arena, absl::StrCat("Unexpected container type for [] operation: ", + CelValue::TypeName(key.type()))); + } + } +} + +cel_base::Status ContainerAccessStep::Evaluate(ExecutionFrame* frame) const { + const int NUM_ARGUMENTS = 2; + + if (!frame->value_stack().HasEnough(NUM_ARGUMENTS)) { + return cel_base::Status( + cel_base::StatusCode::kInternal, + "Insufficient arguments supplied for ContainerAccess-type expression"); + } + + auto input_args = frame->value_stack().GetSpan(NUM_ARGUMENTS); + + const CelValue& container = input_args[0]; + const CelValue& key = input_args[1]; + + CelValue result = PerformLookup(container, key, frame->arena()); + frame->value_stack().Pop(NUM_ARGUMENTS); + frame->value_stack().Push(result); + + return cel_base::OkStatus(); +} +} // namespace + +// Factory method for Select - based Execution step +cel_base::StatusOr> CreateContainerAccessStep( + const google::api::expr::v1alpha1::Expr::Call* call_expr, int64_t expr_id) { + std::unique_ptr step = + absl::make_unique(expr_id); + return std::move(step); +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/container_access_step.h b/eval/eval/container_access_step.h new file mode 100644 index 000000000..f0877fff8 --- /dev/null +++ b/eval/eval/container_access_step.h @@ -0,0 +1,22 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CONTAINER_ACCESS_STEP_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_CONTAINER_ACCESS_STEP_H_ + +#include "eval/eval/evaluator_core.h" +#include "eval/public/activation.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Factory method for Select - based Execution step +cel_base::StatusOr> CreateContainerAccessStep( + const google::api::expr::v1alpha1::Expr::Call* call, int64_t expr_id); + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CONTAINER_ACCESS_STEP_H_ diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc new file mode 100644 index 000000000..22a13fd62 --- /dev/null +++ b/eval/eval/container_access_step_test.cc @@ -0,0 +1,149 @@ +#include "eval/eval/container_access_step.h" + +#include +#include +#include + +#include "google/protobuf/struct.pb.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "eval/eval/container_backed_list_impl.h" +#include "eval/eval/container_backed_map_impl.h" +#include "eval/eval/ident_step.h" +#include "eval/public/cel_builtins.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +using ::google::protobuf::Struct; + +using google::api::expr::v1alpha1::Expr; +using google::api::expr::v1alpha1::SourceInfo; + +class ContainerAccessStepTest : public ::testing::Test { + protected: + ContainerAccessStepTest() {} + + void SetUp() override {} + + // Helper method. Looks up in registry and tests comparison operation. + CelValue PerformRun(CelValue container, CelValue key, bool receiver_style) { + ExecutionPath path; + + Expr expr; + SourceInfo source_info; + auto call = expr.mutable_call_expr(); + + call->set_function(builtin::kIndex); + + Expr* container_expr = + (receiver_style) ? call->mutable_target() : call->add_args(); + Expr* key_expr = call->add_args(); + + container_expr->mutable_ident_expr()->set_name("container"); + key_expr->mutable_ident_expr()->set_name("key"); + + path.push_back(std::move( + CreateIdentStep(&container_expr->ident_expr(), 1).ValueOrDie())); + path.push_back( + std::move(CreateIdentStep(&key_expr->ident_expr(), 2).ValueOrDie())); + path.push_back(std::move(CreateContainerAccessStep(call, 3).ValueOrDie())); + + CelExpressionFlatImpl cel_expr(&expr, std::move(path), 0); + Activation activation; + + activation.InsertValue("container", container); + activation.InsertValue("key", key); + auto eval_status = cel_expr.Evaluate(activation, &arena_); + + EXPECT_TRUE(eval_status.ok()); + return eval_status.ValueOrDie(); + } + google::protobuf::Arena arena_; +}; + +TEST_F(ContainerAccessStepTest, TestListIndexAccess) { + ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), + CelValue::CreateInt64(2), + CelValue::CreateInt64(3)}); + + CelValue result = PerformRun(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(1), true); + + ASSERT_TRUE(result.IsInt64()); + ASSERT_EQ(result.Int64OrDie(), 2); +} + +TEST_F(ContainerAccessStepTest, TestListIndexAccessOutOfBounds) { + ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), + CelValue::CreateInt64(2), + CelValue::CreateInt64(3)}); + + CelValue result = PerformRun(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(0), true); + + ASSERT_TRUE(result.IsInt64()); + result = PerformRun(CelValue::CreateList(&cel_list), CelValue::CreateInt64(2), + true); + + ASSERT_TRUE(result.IsInt64()); + result = PerformRun(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(-1), true); + + ASSERT_TRUE(result.IsError()); + result = PerformRun(CelValue::CreateList(&cel_list), CelValue::CreateInt64(3), + true); + + ASSERT_TRUE(result.IsError()); +} + +TEST_F(ContainerAccessStepTest, TestListIndexAccessNotAnInt) { + ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), + CelValue::CreateInt64(2), + CelValue::CreateInt64(3)}); + + CelValue result = PerformRun(CelValue::CreateList(&cel_list), + CelValue::CreateUint64(1), true); + + ASSERT_TRUE(result.IsError()); +} + +TEST_F(ContainerAccessStepTest, TestMapKeyAccess) { + const std::string kKey0 = "testkey0"; + const std::string kKey1 = "testkey1"; + const std::string kKey2 = "testkey2"; + Struct cel_struct; + (*cel_struct.mutable_fields())[kKey0].set_string_value("value0"); + (*cel_struct.mutable_fields())[kKey1].set_string_value("value1"); + (*cel_struct.mutable_fields())[kKey2].set_string_value("value2"); + + CelValue result = PerformRun(CelValue::CreateMessage(&cel_struct, &arena_), + CelValue::CreateString(&kKey0), true); + + ASSERT_TRUE(result.IsString()); + ASSERT_EQ(result.StringOrDie().value(), "value0"); +} + +TEST_F(ContainerAccessStepTest, TestMapKeyAccessNotFound) { + const std::string kKey0 = "testkey0"; + const std::string kKey1 = "testkey1"; + Struct cel_struct; + (*cel_struct.mutable_fields())[kKey0].set_string_value("value0"); + + CelValue result = PerformRun(CelValue::CreateMessage(&cel_struct, &arena_), + CelValue::CreateString(&kKey1), true); + + ASSERT_TRUE(result.IsError()); +} + +} // namespace + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/create_list_step_test.cc b/eval/eval/create_list_step_test.cc index 7af1b4949..1b0ce3660 100644 --- a/eval/eval/create_list_step_test.cc +++ b/eval/eval/create_list_step_test.cc @@ -42,7 +42,7 @@ cel_base::StatusOr RunExpression(const std::vector& values, path.push_back(std::move(step0_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path)); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0); Activation activation; return cel_expr.Evaluate(activation, arena); @@ -64,7 +64,7 @@ TEST(CreateListStepTest, TestCreateListStackUndeflow) { path.push_back(std::move(step0_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path)); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index 9108161c8..c8fef5922 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -61,7 +61,7 @@ cel_base::StatusOr RunExpression(absl::string_view field, path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&expr1, std::move(path)); + CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0); Activation activation; activation.InsertValue("message", value); @@ -158,7 +158,7 @@ cel_base::StatusOr RunCreateMapExpression( path.push_back(std::move(step1_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&expr1, std::move(path)); + CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0); return cel_expr.Evaluate(activation, arena); } @@ -176,7 +176,7 @@ TEST(CreateCreateStructStepTest, TestEmptyMessageCreation) { path.push_back(std::move(step_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&expr1, std::move(path)); + CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0); Activation activation; google::protobuf::Arena arena; @@ -293,6 +293,36 @@ TEST(CreateCreateStructStepTest, TestSetBytesField) { EXPECT_EQ(test_msg.bytes_value(), kTestStr); } +// Test that fields of type duration are set correctly. +TEST(CreateCreateStructStepTest, TestSetDurationField) { + Arena arena; + + google::protobuf::Duration test_duration; + test_duration.set_seconds(2); + test_duration.set_nanos(3); + TestMessage test_msg; + + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "duration_value", CelValue::CreateDuration(&test_duration), &arena, + &test_msg)); + EXPECT_THAT(test_msg.duration_value(), EqualsProto(test_duration)); +} + +// Test that fields of type timestamp are set correctly. +TEST(CreateCreateStructStepTest, TestSetTimestampField) { + Arena arena; + + google::protobuf::Timestamp test_timestamp; + test_timestamp.set_seconds(2); + test_timestamp.set_nanos(3); + TestMessage test_msg; + + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "timestamp_value", CelValue::CreateTimestamp(&test_timestamp), &arena, + &test_msg)); + EXPECT_THAT(test_msg.timestamp_value(), EqualsProto(test_timestamp)); +} + // Test that fields of type Message are set correctly. TEST(CreateCreateStructStepTest, TestSetMessageField) { Arena arena; diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index 2edc7994d..60500a397 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -25,7 +25,7 @@ cel_base::StatusOr CelExpressionFlatImpl::Evaluate( cel_base::StatusOr CelExpressionFlatImpl::Trace( const Activation& activation, google::protobuf::Arena* arena, CelEvaluationListener callback) const { - ExecutionFrame frame(&path_, activation, arena); + ExecutionFrame frame(&path_, activation, arena, max_iterations_); ValueStack* stack = &frame.value_stack(); size_t initial_stack_size = stack->size(); diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index 016bb92b7..c7b189fd5 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -96,8 +96,13 @@ class ExecutionFrame { // activation provides bindings between parameter names and values. // arena serves as allocation manager during the expression evaluation. ExecutionFrame(const ExecutionPath* flat, const Activation& activation, - google::protobuf::Arena* arena) - : pc_(0), execution_path_(flat), activation_(activation), arena_(arena) { + google::protobuf::Arena* arena, int max_iterations) + : pc_(0), + execution_path_(flat), + activation_(activation), + arena_(arena), + max_iterations_(max_iterations), + iterations_(0) { // Reserve space on stack to minimize reallocations // on stack resize. value_stack_.Reserve(flat->size()); @@ -129,12 +134,28 @@ class ExecutionFrame { // Returns reference to iter_vars std::map& iter_vars() { return iter_vars_; } + // Increment iterations and return an error if the iteration budget is + // exceeded + cel_base::Status IncrementIterations() { + if (max_iterations_ == 0) { + return cel_base::OkStatus(); + } + iterations_++; + if (iterations_ >= max_iterations_) { + return cel_base::Status(cel_base::StatusCode::kInternal, + "Iteration budget exceeded"); + } + return cel_base::OkStatus(); + } + private: int pc_; // pc_ - Program Counter. Current position on execution path. const ExecutionPath* execution_path_; const Activation& activation_; ValueStack value_stack_; google::protobuf::Arena* arena_; + const int max_iterations_; + int iterations_; std::map iter_vars_; // variables declared in the frame. }; @@ -145,10 +166,12 @@ class CelExpressionFlatImpl : public CelExpression { // Constructs CelExpressionFlatImpl instance. // root_expr represents the root of AST tree; // path is flat execution path that is based upon - // flattened AST tree. + // flattened AST tree. Max iterations dictates the maximum number of + // iterations in the comprehension expressions (use 0 to disable the upper + // bound). CelExpressionFlatImpl(const google::api::expr::v1alpha1::Expr* root_expr, - ExecutionPath path) - : path_(std::move(path)) {} + ExecutionPath path, int max_iterations) + : path_(std::move(path)), max_iterations_(max_iterations) {} // Implementation of CelExpression evaluate method. cel_base::StatusOr Evaluate(const Activation& activation, @@ -161,6 +184,7 @@ class CelExpressionFlatImpl : public CelExpression { private: const ExecutionPath path_; + const int max_iterations_; }; } // namespace runtime diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index ce51df9a1..d8eae3cf5 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -60,7 +60,7 @@ TEST(EvaluatorCoreTest, ExecutionFrameNext) { auto dummy_expr = absl::make_unique(); Activation activation; - ExecutionFrame frame(&path, activation, nullptr); + ExecutionFrame frame(&path, activation, nullptr, 0); EXPECT_THAT(frame.Next(), Eq(path[0].get())); EXPECT_THAT(frame.Next(), Eq(path[1].get())); @@ -80,7 +80,7 @@ TEST(EvaluatorCoreTest, SimpleEvaluatorTest) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path)); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/field_access.cc b/eval/eval/field_access.cc index 0dcf4555c..e7921a4f4 100644 --- a/eval/eval/field_access.cc +++ b/eval/eval/field_access.cc @@ -5,6 +5,7 @@ #include "google/protobuf/map_field.h" #include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" +#include "internal/proto_util.h" #include "base/canonical_errors.h" namespace google { @@ -20,6 +21,10 @@ using ::google::protobuf::MapValueRef; using ::google::protobuf::Message; using ::google::protobuf::Reflection; +// Well-known type protobuf type names which require special get / set behavior. +constexpr const char kProtobufDuration[] = "google.protobuf.Duration"; +constexpr const char kProtobufTimestamp[] = "google.protobuf.Timestamp"; + // Singular message fields and repeated message fields have similar access model // To provide common approach, we implement accessor classes, based on CRTP. // FieldAccessor is CRTP base class, specifying Get.. method family. @@ -454,6 +459,30 @@ class FieldSetter { return true; } + bool AssignDuration(const CelValue& cel_value) const { + absl::Duration d; + if (!cel_value.GetValue(&d)) { + GOOGLE_LOG(ERROR) << "Unable to retrieve duration"; + return false; + } + google::protobuf::Duration duration; + google::api::expr::internal::EncodeDuration(d, &duration); + static_cast(this)->SetMessage(&duration); + return true; + } + + bool AssignTimestamp(const CelValue& cel_value) const { + absl::Time t; + if (!cel_value.GetValue(&t)) { + GOOGLE_LOG(ERROR) << "Unable to retrieve timestamp"; + return false; + } + google::protobuf::Timestamp timestamp; + google::api::expr::internal::EncodeTime(t, ×tamp); + static_cast(this)->SetMessage(×tamp); + return true; + } + // This method provides message field content, wrapped in CelValue. // If value provided successfully, returns Ok. // arena Arena to use for allocations if needed. @@ -494,6 +523,15 @@ class FieldSetter { break; } case FieldDescriptor::CPPTYPE_MESSAGE: { + const std::string& type_name = field_desc_->message_type()->full_name(); + // When the field is a message, it might be a well-known type with a + // non-proto representation that requires special handling before it + // can be set on the field. + if (type_name == kProtobufTimestamp) { + return AssignTimestamp(value); + } else if (type_name == kProtobufDuration) { + return AssignDuration(value); + } return AssignMessage(value); } case FieldDescriptor::CPPTYPE_ENUM: { diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index defa6bfb2..e61fd6f7a 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -92,7 +92,7 @@ TEST(FunctionStepTest, SimpleFunctionTest) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path)); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); Activation activation; google::protobuf::Arena arena; @@ -127,7 +127,7 @@ TEST(FunctionStepTest, TestStackUnderflow) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path)); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); Activation activation; google::protobuf::Arena arena; @@ -163,7 +163,7 @@ TEST(FunctionStepTest, TestMultipleOverloads) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path)); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); Activation activation; google::protobuf::Arena arena; @@ -200,7 +200,7 @@ TEST(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path)); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); Activation activation; google::protobuf::Arena arena; @@ -247,7 +247,7 @@ TEST(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluationErrorForwarding) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path)); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); Activation activation; google::protobuf::Arena arena; diff --git a/eval/eval/ident_step_test.cc b/eval/eval/ident_step_test.cc index 9d99157c3..f2d898011 100644 --- a/eval/eval/ident_step_test.cc +++ b/eval/eval/ident_step_test.cc @@ -31,7 +31,7 @@ TEST(IdentStepTest, TestIdentStep) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path)); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); Activation activation; Arena arena; @@ -60,7 +60,7 @@ TEST(IdentStepTest, TestIdentStepNameNotFound) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path)); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); Activation activation; Arena arena; @@ -86,7 +86,7 @@ TEST(IdentStepTest, TestIdentStepUnknownValue) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path)); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); Activation activation; Arena arena; diff --git a/eval/eval/select_step_test.cc b/eval/eval/select_step_test.cc index c2b6c0d38..4b679f5d0 100644 --- a/eval/eval/select_step_test.cc +++ b/eval/eval/select_step_test.cc @@ -50,7 +50,7 @@ cel_base::StatusOr RunExpression(const CelValue target, path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path)); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0); Activation activation; activation.InsertValue("target", target); @@ -416,7 +416,7 @@ TEST(SelectStepTest, CelErrorAsArgument) { CelError error; google::protobuf::Arena arena; - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path)); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0); Activation activation; activation.InsertValue("message", CelValue::CreateError(&error)); @@ -453,7 +453,7 @@ TEST(SelectStepTest, UnknownValueProducesError) { path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path)); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0); Activation activation; activation.InsertValue("message", CelValue::CreateMessage(&message, &arena)); diff --git a/eval/public/BUILD b/eval/public/BUILD index 92d4b7958..e8f62d667 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -88,6 +88,7 @@ cc_library( "cel_function.h", ], deps = [ + ":cel_options", ":cel_value", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings", @@ -132,6 +133,7 @@ cc_library( ":cel_builtins", ":cel_function", ":cel_function_adapter", + ":cel_options", "//base:status", "//eval/eval:container_backed_list_impl", "@com_google_absl//absl/strings", @@ -220,6 +222,16 @@ cc_library( ], ) +cc_library( + name = "cel_options", + hdrs = [ + "cel_options.h", + ], + deps = [ + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "cel_expr_builder_factory", srcs = [ @@ -230,10 +242,27 @@ cc_library( ], deps = [ ":cel_expression", + ":cel_options", "//eval/compiler:flat_expr_builder", ], ) +cc_library( + name = "value_export_util", + srcs = [ + "value_export_util.cc", + ], + hdrs = [ + "value_export_util.h", + ], + deps = [ + ":cel_value", + "//base:status", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( name = "cel_value_test", size = "small", @@ -347,3 +376,23 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_test( + name = "value_export_util_test", + size = "small", + srcs = [ + "value_export_util_test.cc", + ], + deps = [ + ":cel_value", + ":value_export_util", + "//base:status", + "//eval/eval:container_backed_list_impl", + "//eval/eval:container_backed_map_impl", + "//eval/testutil:test_message_cc_proto", + "//testutil:util", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/eval/public/ast_traverse.cc b/eval/public/ast_traverse.cc index a918e611e..dcc7d0268 100644 --- a/eval/public/ast_traverse.cc +++ b/eval/public/ast_traverse.cc @@ -66,6 +66,7 @@ struct StackRecord { void PreVisit(const StackRecord &record, AstVisitor *visitor) { const Expr *expr = record.expr; const SourcePosition position(expr->id(), record.source_info); + visitor->PreVisitExpr(expr, &position); switch (expr->expr_kind_case()) { case Expr::kSelectExpr: visitor->PreVisitSelect(&expr->select_expr(), expr, &position); @@ -117,6 +118,7 @@ void PostVisit(const StackRecord &record, AstVisitor *visitor) { record.calling_expr != nullptr) { visitor->PostVisitArg(record.call_arg, record.calling_expr, &position); } + visitor->PostVisitExpr(expr, &position); } void PushSelectDeps(const Select *select_expr, const StackRecord &record, diff --git a/eval/public/ast_traverse_test.cc b/eval/public/ast_traverse_test.cc index f33b9bb2a..d36954f42 100644 --- a/eval/public/ast_traverse_test.cc +++ b/eval/public/ast_traverse_test.cc @@ -37,6 +37,14 @@ using Comprehension = google::api::expr::v1alpha1::Expr::Comprehension; class MockAstVisitor : public AstVisitor { public: + // Expr handler. + MOCK_METHOD2(PreVisitExpr, + void(const Expr* expr, const SourcePosition* position)); + + // Expr handler. + MOCK_METHOD2(PostVisitExpr, + void(const Expr* expr, const SourcePosition* position)); + MOCK_METHOD3(PostVisitConst, void(const Constant* const_expr, const Expr* expr, const SourcePosition* position)); @@ -247,6 +255,24 @@ TEST(AstCrawlerTest, CheckCreateStruct) { AstTraverse(&expr, &source_info, &handler); } +// Test generic Expr handlers. +TEST(AstCrawlerTest, CheckExprHandlers) { + SourceInfo source_info; + MockAstVisitor handler; + + Expr expr; + auto struct_expr = expr.mutable_struct_expr(); + auto entry0 = struct_expr->add_entries(); + + entry0->mutable_map_key()->mutable_const_expr(); + entry0->mutable_value()->mutable_ident_expr(); + + EXPECT_CALL(handler, PreVisitExpr(_, _)).Times(3); + EXPECT_CALL(handler, PostVisitExpr(_, _)).Times(3); + + AstTraverse(&expr, &source_info, &handler); +} + } // namespace } // namespace runtime diff --git a/eval/public/ast_visitor.h b/eval/public/ast_visitor.h index 5f2d3f6bd..5bed85373 100644 --- a/eval/public/ast_visitor.h +++ b/eval/public/ast_visitor.h @@ -45,6 +45,20 @@ class AstVisitor { public: virtual ~AstVisitor() {} + // Expr node handler method. Called for all Expr nodes. + // Is invoked before child Expr nodes being processed. + // TODO(issues/22): this method is not pure virtual to avoid dependencies + // breakage. Change it in subsequent CLs. + virtual void PreVisitExpr(const google::api::expr::v1alpha1::Expr* expr, + const SourcePosition* position) {} + + // Expr node handler method. Called for all Expr nodes. + // Is invoked after child Expr nodes are processed. + // TODO(issues/22): this method is not pure virtual to avoid dependencies + // breakage. Change it in subsequent CLs. + virtual void PostVisitExpr(const google::api::expr::v1alpha1::Expr* expr, + const SourcePosition* position) {} + // Const node handler. // Invoked after child nodes are processed. virtual void PostVisitConst(const google::api::expr::v1alpha1::Constant* const_expr, diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 2ee4f5cc9..0b5a5a6b4 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -25,26 +25,26 @@ namespace { // Comparison template functions template -CelValue Inequal(Arena* arena, Type t1, Type t2) { +CelValue Inequal(Arena*, Type t1, Type t2) { return CelValue::CreateBool(t1 != t2); } template -CelValue Equal(Arena* arena, Type t1, Type t2) { +CelValue Equal(Arena*, Type t1, Type t2) { return CelValue::CreateBool(t1 == t2); } // Forward declaration of the generic equality operator template <> -CelValue Equal(Arena* arena, CelValue t1, CelValue t2); +CelValue Equal(Arena*, CelValue t1, CelValue t2); template -bool LessThan(Arena* arena, Type t1, Type t2) { +bool LessThan(Arena*, Type t1, Type t2) { return (t1 < t2); } template -bool LessThanOrEqual(Arena* arena, Type t1, Type t2) { +bool LessThanOrEqual(Arena*, Type t1, Type t2) { return (t1 <= t2); } @@ -60,63 +60,63 @@ bool GreaterThanOrEqual(Arena* arena, Type t1, Type t2) { // Duration comparison specializations template <> -CelValue Inequal(Arena* arena, absl::Duration t1, absl::Duration t2) { +CelValue Inequal(Arena*, absl::Duration t1, absl::Duration t2) { return CelValue::CreateBool(operator!=(t1, t2)); } template <> -CelValue Equal(Arena* arena, absl::Duration t1, absl::Duration t2) { +CelValue Equal(Arena*, absl::Duration t1, absl::Duration t2) { return CelValue::CreateBool(operator==(t1, t2)); } template <> -bool LessThan(Arena* arena, absl::Duration t1, absl::Duration t2) { +bool LessThan(Arena*, absl::Duration t1, absl::Duration t2) { return operator<(t1, t2); } template <> -bool LessThanOrEqual(Arena* arena, absl::Duration t1, absl::Duration t2) { +bool LessThanOrEqual(Arena*, absl::Duration t1, absl::Duration t2) { return operator<=(t1, t2); } template <> -bool GreaterThan(Arena* arena, absl::Duration t1, absl::Duration t2) { +bool GreaterThan(Arena*, absl::Duration t1, absl::Duration t2) { return operator>(t1, t2); } template <> -bool GreaterThanOrEqual(Arena* arena, absl::Duration t1, absl::Duration t2) { +bool GreaterThanOrEqual(Arena*, absl::Duration t1, absl::Duration t2) { return operator>=(t1, t2); } // Timestamp comparison specializations template <> -CelValue Inequal(Arena* arena, absl::Time t1, absl::Time t2) { +CelValue Inequal(Arena*, absl::Time t1, absl::Time t2) { return CelValue::CreateBool(operator!=(t1, t2)); } template <> -CelValue Equal(Arena* arena, absl::Time t1, absl::Time t2) { +CelValue Equal(Arena*, absl::Time t1, absl::Time t2) { return CelValue::CreateBool(operator==(t1, t2)); } template <> -bool LessThan(Arena* arena, absl::Time t1, absl::Time t2) { +bool LessThan(Arena*, absl::Time t1, absl::Time t2) { return operator<(t1, t2); } template <> -bool LessThanOrEqual(Arena* arena, absl::Time t1, absl::Time t2) { +bool LessThanOrEqual(Arena*, absl::Time t1, absl::Time t2) { return operator<=(t1, t2); } template <> -bool GreaterThan(Arena* arena, absl::Time t1, absl::Time t2) { +bool GreaterThan(Arena*, absl::Time t1, absl::Time t2) { return operator>(t1, t2); } template <> -bool GreaterThanOrEqual(Arena* arena, absl::Time t1, absl::Time t2) { +bool GreaterThanOrEqual(Arena*, absl::Time t1, absl::Time t2) { return operator>=(t1, t2); } @@ -518,16 +518,13 @@ CelValue GetTimeBreakdownPart( CelValue CreateTimestampFromString(Arena* arena, CelValue::StringHolder time_str) { - Timestamp ts; - auto result = - google::protobuf::util::TimeUtil::FromString(std::string(time_str.value()), &ts); - if (!result) { + absl::Time ts; + if (!absl::ParseTime(absl::RFC3339_full, std::string(time_str.value()), &ts, + nullptr)) { return CreateErrorValue(arena, "String to Timestamp conversion failed", cel_base::StatusCode::kInvalidArgument); } - - return CelValue::CreateTimestamp( - Arena::Create(arena, std::move(ts))); + return CelValue::CreateTimestamp(ts); } CelValue GetFullYear(Arena* arena, absl::Time timestamp, absl::string_view tz) { @@ -613,16 +610,14 @@ CelValue GetMilliseconds(Arena* arena, absl::Time timestamp, } CelValue CreateDurationFromString(Arena* arena, - CelValue::StringHolder time_str) { - Duration d; - auto result = - google::protobuf::util::TimeUtil::FromString(std::string(time_str.value()), &d); - if (!result) { + CelValue::StringHolder dur_str) { + absl::Duration d; + if (!absl::ParseDuration(std::string(dur_str.value()), &d)) { return CreateErrorValue(arena, "String to Duration conversion failed", cel_base::StatusCode::kInvalidArgument); } - return CelValue::CreateDuration(Arena::Create(arena, std::move(d))); + return CelValue::CreateDuration(d); } CelValue GetHours(Arena* arena, absl::Duration duration) { @@ -643,26 +638,6 @@ CelValue GetMilliseconds(Arena* arena, absl::Duration duration) { millis_per_second); } -CelValue RegexMatchesFull(Arena* arena, CelValue::StringHolder target, - CelValue::StringHolder regex) { - RE2 re2(regex.value().data()); - if (!re2.ok()) { - return CreateErrorValue(arena, "invalid_argument", - cel_base::StatusCode::kInvalidArgument); - } - return CelValue::CreateBool(RE2::FullMatch(re2::StringPiece(target.value().data(), target.value().size()), re2)); -} - -CelValue RegexMatchesPartial(Arena* arena, CelValue::StringHolder target, - CelValue::StringHolder regex) { - RE2 re2(regex.value().data()); - if (!re2.ok()) { - return CreateErrorValue(arena, "invalid_argument", - cel_base::StatusCode::kInvalidArgument); - } - return CelValue::CreateBool(RE2::PartialMatch(re2::StringPiece(target.value().data(), target.value().size()), re2)); -} - bool StringContains(Arena* arena, CelValue::StringHolder value, CelValue::StringHolder substr) { return absl::StrContains(value.value(), substr.value()); @@ -678,35 +653,10 @@ bool StringStartsWith(Arena* arena, CelValue::StringHolder value, return absl::StartsWith(value.value(), prefix.value()); } -// Creates and registers a map index function. -template -::cel_base::Status RegisterMapIndexFunction(CelFunctionRegistry* registry, - const CreateCelValue& create_cel_value, - const ToAlphaNum& to_alpha_num) { - return FunctionAdapter::CreateAndRegister( - builtin::kIndex, false, - [&create_cel_value, &to_alpha_num](Arena* arena, const CelMap* cel_map, - T key) -> CelValue { - auto maybe_value = (*cel_map)[create_cel_value(key)]; - if (!maybe_value.has_value()) { - // TODO(issues/25) Which code? - return CreateNoSuchKeyError(arena, absl::StrCat(to_alpha_num(key))); - } - return maybe_value.value(); - }, - registry); -} - -template -::cel_base::Status RegisterMapIndexFunction( - CelFunctionRegistry* registry, const CreateCelValue& create_cel_value) { - return RegisterMapIndexFunction(registry, create_cel_value, - [](T v) { return absl::StrCat(v); }); -} - } // namespace -::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry) { +::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { // logical NOT cel_base::Status status = FunctionAdapter::CreateAndRegister( builtin::kNot, false, @@ -905,21 +855,6 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry) { builtin::kSize, false, bytes_size_func, registry); if (!status.ok()) return status; - // List Index - status = FunctionAdapter::CreateAndRegister( - builtin::kIndex, false, - [](Arena* arena, const CelList* cel_list, int64_t index) -> CelValue { - if (index < 0 || index >= cel_list->size()) { - // TODO(issues/25) Which code? - return CreateErrorValue(arena, - absl::StrCat("Index error: index=", index, - " size=", cel_list->size())); - } - return (*cel_list)[index]; - }, - registry); - if (!status.ok()) return status; - // List size auto list_size_func = [](Arena* arena, const CelList* cel_list) -> int64_t { return (*cel_list).size(); @@ -934,88 +869,74 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry) { if (!status.ok()) return status; // List in operator: @in - status = FunctionAdapter::CreateAndRegister( - builtin::kIn, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kIn, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kIn, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kIn, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister(builtin::kIn, false, In, - registry); - if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister(builtin::kIn, false, In, - registry); - if (!status.ok()) return status; - - // List in operator: _in_ (deprecated) - // Bindings preserved for backward compatibility with stored expressions. - status = FunctionAdapter::CreateAndRegister( - builtin::kInDeprecated, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kInDeprecated, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kInDeprecated, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kInDeprecated, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister(builtin::kInDeprecated, false, - In, registry); - if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister(builtin::kInDeprecated, false, - In, registry); - if (!status.ok()) return status; - - // List in() function (deprecated) - // Bindings preserved for backward compatibility with stored expressions. - status = FunctionAdapter::CreateAndRegister( - builtin::kInFunction, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kInFunction, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kInFunction, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kInFunction, false, In, registry); - if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister(builtin::kInFunction, false, In, - registry); - if (!status.ok()) return status; - status = FunctionAdapter:: - CreateAndRegister(builtin::kInFunction, false, In, - registry); - if (!status.ok()) return status; - - // Map Index - status = RegisterMapIndexFunction( - registry, - [](CelValue::StringHolder v) { return CelValue::CreateString(v); }, - [](CelValue::StringHolder v) { return v.value(); }); - if (!status.ok()) return status; - - status = RegisterMapIndexFunction(registry, CelValue::CreateInt64); - if (!status.ok()) return status; - - status = RegisterMapIndexFunction(registry, CelValue::CreateUint64); - if (!status.ok()) return status; - - status = RegisterMapIndexFunction(registry, CelValue::CreateBool); - if (!status.ok()) return status; + if (options.enable_list_contains) { + status = FunctionAdapter::CreateAndRegister( + builtin::kIn, false, In, registry); + if (!status.ok()) return status; + status = FunctionAdapter::CreateAndRegister( + builtin::kIn, false, In, registry); + if (!status.ok()) return status; + status = FunctionAdapter::CreateAndRegister( + builtin::kIn, false, In, registry); + if (!status.ok()) return status; + status = FunctionAdapter::CreateAndRegister( + builtin::kIn, false, In, registry); + if (!status.ok()) return status; + status = FunctionAdapter:: + CreateAndRegister(builtin::kIn, false, In, + registry); + if (!status.ok()) return status; + status = FunctionAdapter:: + CreateAndRegister(builtin::kIn, false, In, + registry); + if (!status.ok()) return status; + + // List in operator: _in_ (deprecated) + // Bindings preserved for backward compatibility with stored expressions. + status = FunctionAdapter::CreateAndRegister( + builtin::kInDeprecated, false, In, registry); + if (!status.ok()) return status; + status = FunctionAdapter::CreateAndRegister( + builtin::kInDeprecated, false, In, registry); + if (!status.ok()) return status; + status = FunctionAdapter::CreateAndRegister( + builtin::kInDeprecated, false, In, registry); + if (!status.ok()) return status; + status = FunctionAdapter::CreateAndRegister( + builtin::kInDeprecated, false, In, registry); + if (!status.ok()) return status; + status = FunctionAdapter:: + CreateAndRegister(builtin::kInDeprecated, false, + In, registry); + if (!status.ok()) return status; + status = FunctionAdapter:: + CreateAndRegister(builtin::kInDeprecated, false, + In, registry); + if (!status.ok()) return status; + + // List in() function (deprecated) + // Bindings preserved for backward compatibility with stored expressions. + status = FunctionAdapter::CreateAndRegister( + builtin::kInFunction, false, In, registry); + if (!status.ok()) return status; + status = FunctionAdapter::CreateAndRegister( + builtin::kInFunction, false, In, registry); + if (!status.ok()) return status; + status = FunctionAdapter::CreateAndRegister( + builtin::kInFunction, false, In, registry); + if (!status.ok()) return status; + status = FunctionAdapter::CreateAndRegister( + builtin::kInFunction, false, In, registry); + if (!status.ok()) return status; + status = FunctionAdapter:: + CreateAndRegister(builtin::kInFunction, false, + In, registry); + if (!status.ok()) return status; + status = FunctionAdapter:: + CreateAndRegister(builtin::kInFunction, false, + In, registry); + if (!status.ok()) return status; + } // Map size auto map_size_func = [](Arena* arena, const CelMap* cel_map) -> int64_t { @@ -1176,45 +1097,61 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry) { if (!status.ok()) return status; // Concat group - status = - FunctionAdapter::CreateAndRegister(builtin::kAdd, - false, - ConcatString, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter::CreateAndRegister(builtin::kAdd, - false, - ConcatBytes, - registry); - if (!status.ok()) return status; + if (options.enable_string_concat) { + status = FunctionAdapter< + CelValue::StringHolder, CelValue::StringHolder, + CelValue::StringHolder>::CreateAndRegister(builtin::kAdd, false, + ConcatString, registry); + if (!status.ok()) return status; + + status = + FunctionAdapter::CreateAndRegister(builtin::kAdd, + false, + ConcatBytes, + registry); + if (!status.ok()) return status; + } - status = - FunctionAdapter::CreateAndRegister(builtin::kAdd, false, - ConcatList, registry); - if (!status.ok()) return status; + if (options.enable_list_concat) { + status = + FunctionAdapter::CreateAndRegister(builtin::kAdd, false, + ConcatList, + registry); + if (!status.ok()) return status; + } // Global matches function. - status = FunctionAdapter:: - CreateAndRegister(builtin::kRegexMatch, false, - registry->partial_string_match() ? RegexMatchesPartial - : RegexMatchesFull, - registry); - if (!status.ok()) return status; - - // Receiver-style matches function. - status = FunctionAdapter:: - CreateAndRegister(builtin::kRegexMatch, true, - registry->partial_string_match() ? RegexMatchesPartial - : RegexMatchesFull, - registry); - if (!status.ok()) return status; + if (options.enable_regex) { + auto regex_matches = [max_size = options.regex_max_program_size]( + Arena* arena, CelValue::StringHolder target, + CelValue::StringHolder regex) -> CelValue { + RE2 re2(regex.value().data()); + if (max_size > 0 && re2.ProgramSize() > max_size) { + return CreateErrorValue(arena, "exceeded RE2 max program size", + cel_base::StatusCode::kInvalidArgument); + } + if (!re2.ok()) { + return CreateErrorValue(arena, "invalid_argument", + cel_base::StatusCode::kInvalidArgument); + } + return CelValue::CreateBool(RE2::PartialMatch(re2::StringPiece(target.value().data(), target.value().size()), re2)); + }; + + status = FunctionAdapter< + CelValue, CelValue::StringHolder, + CelValue::StringHolder>::CreateAndRegister(builtin::kRegexMatch, false, + regex_matches, registry); + if (!status.ok()) return status; + + // Receiver-style matches function. + status = FunctionAdapter< + CelValue, CelValue::StringHolder, + CelValue::StringHolder>::CreateAndRegister(builtin::kRegexMatch, true, + regex_matches, registry); + if (!status.ok()) return status; + } status = FunctionAdapter:: @@ -1491,51 +1428,53 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry) { registry); if (!status.ok()) return status; - status = FunctionAdapter::CreateAndRegister( - builtin::kString, false, - [](Arena* arena, int64_t value) -> CelValue::StringHolder { - return CelValue::StringHolder( - Arena::Create(arena, absl::StrCat(value))); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kString, false, - [](Arena* arena, uint64_t value) -> CelValue::StringHolder { - return CelValue::StringHolder( - Arena::Create(arena, absl::StrCat(value))); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kString, false, - [](Arena* arena, double value) -> CelValue::StringHolder { - return CelValue::StringHolder( - Arena::Create(arena, absl::StrCat(value))); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kString, false, - [](Arena* arena, - CelValue::BytesHolder value) -> CelValue::StringHolder { - return CelValue::StringHolder( - Arena::Create(arena, std::string(value.value()))); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kString, false, - [](Arena* arena, CelValue::StringHolder value) - -> CelValue::StringHolder { return value; }, - registry); - if (!status.ok()) return status; + if (options.enable_string_conversion) { + status = FunctionAdapter::CreateAndRegister( + builtin::kString, false, + [](Arena* arena, int64_t value) -> CelValue::StringHolder { + return CelValue::StringHolder( + Arena::Create(arena, absl::StrCat(value))); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kString, false, + [](Arena* arena, uint64_t value) -> CelValue::StringHolder { + return CelValue::StringHolder( + Arena::Create(arena, absl::StrCat(value))); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kString, false, + [](Arena* arena, double value) -> CelValue::StringHolder { + return CelValue::StringHolder( + Arena::Create(arena, absl::StrCat(value))); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kString, false, + [](Arena* arena, + CelValue::BytesHolder value) -> CelValue::StringHolder { + return CelValue::StringHolder( + Arena::Create(arena, std::string(value.value()))); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kString, false, + [](Arena* arena, CelValue::StringHolder value) + -> CelValue::StringHolder { return value; }, + registry); + if (!status.ok()) return status; + } return ::cel_base::OkStatus(); } diff --git a/eval/public/builtin_func_registrar.h b/eval/public/builtin_func_registrar.h index d29831417..97c603c93 100644 --- a/eval/public/builtin_func_registrar.h +++ b/eval/public/builtin_func_registrar.h @@ -2,13 +2,16 @@ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_BUILTIN_FUNC_REGISTRAR_H_ #include "eval/public/cel_function.h" +#include "eval/public/cel_options.h" namespace google { namespace api { namespace expr { namespace runtime { -cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry); +cel_base::Status RegisterBuiltinFunctions( + CelFunctionRegistry* registry, + const InterpreterOptions& options = InterpreterOptions()); } // namespace runtime } // namespace expr diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 5b194a295..090a699c5 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -67,7 +67,7 @@ class BuiltinsTest : public ::testing::Test { CreateCelExpressionBuilder(options); // Builtin registration. - ASSERT_TRUE(RegisterBuiltinFunctions(builder->GetRegistry()).ok()); + ASSERT_TRUE(RegisterBuiltinFunctions(builder->GetRegistry(), options).ok()); // Create CelExpression from AST (Expr object). auto cel_expression_status = builder->CreateExpression(&expr, &source_info); @@ -1315,45 +1315,17 @@ TEST_F(BuiltinsTest, TestConcatList) { } } -TEST_F(BuiltinsTest, MatchesTrue) { - std::string target = "haystack"; - std::string regex = "hay\\w{2}ack"; - std::vector args = {CelValue::CreateString(&target), - CelValue::CreateString(®ex)}; - - CelValue result_value; - ASSERT_NO_FATAL_FAILURE( - PerformRun(builtin::kRegexMatch, {}, args, &result_value)); - ASSERT_TRUE(result_value.IsBool()); - EXPECT_TRUE(result_value.BoolOrDie()); -} - TEST_F(BuiltinsTest, MatchesPartialTrue) { std::string target = "haystack"; std::string regex = "\\w{2}ack"; std::vector args = {CelValue::CreateString(&target), CelValue::CreateString(®ex)}; - InterpreterOptions options; - options.partial_string_match = true; - CelValue result_value; - ASSERT_NO_FATAL_FAILURE( - PerformRun(builtin::kRegexMatch, {}, args, &result_value, options)); - ASSERT_TRUE(result_value.IsBool()); - EXPECT_TRUE(result_value.BoolOrDie()); -} - -TEST_F(BuiltinsTest, MatchesFalse) { - std::string target = "haystack"; - std::string regex = "hay"; - std::vector args = {CelValue::CreateString(&target), - CelValue::CreateString(®ex)}; - CelValue result_value; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kRegexMatch, {}, args, &result_value)); ASSERT_TRUE(result_value.IsBool()); - EXPECT_FALSE(result_value.BoolOrDie()); + EXPECT_TRUE(result_value.BoolOrDie()); } TEST_F(BuiltinsTest, MatchesPartialFalse) { @@ -1362,16 +1334,14 @@ TEST_F(BuiltinsTest, MatchesPartialFalse) { std::vector args = {CelValue::CreateString(&target), CelValue::CreateString(®ex)}; - InterpreterOptions options; - options.partial_string_match = true; CelValue result_value; ASSERT_NO_FATAL_FAILURE( - PerformRun(builtin::kRegexMatch, {}, args, &result_value, options)); + PerformRun(builtin::kRegexMatch, {}, args, &result_value)); ASSERT_TRUE(result_value.IsBool()); EXPECT_FALSE(result_value.BoolOrDie()); } -TEST_F(BuiltinsTest, MatchesError) { +TEST_F(BuiltinsTest, MatchesPartialError) { std::string target = "haystack"; std::string invalid_regex = "("; std::vector args = {CelValue::CreateString(&target), @@ -1383,15 +1353,15 @@ TEST_F(BuiltinsTest, MatchesError) { EXPECT_TRUE(result_value.IsError()); } -TEST_F(BuiltinsTest, MatchesPartialError) { +TEST_F(BuiltinsTest, MatchesMaxSize) { std::string target = "haystack"; - std::string invalid_regex = "("; + std::string large_regex = "[hj][ab][yt][st][tv][ac]"; std::vector args = {CelValue::CreateString(&target), - CelValue::CreateString(&invalid_regex)}; + CelValue::CreateString(&large_regex)}; - InterpreterOptions options; - options.partial_string_match = true; CelValue result_value; + InterpreterOptions options; + options.regex_max_program_size = 1; ASSERT_NO_FATAL_FAILURE( PerformRun(builtin::kRegexMatch, {}, args, &result_value, options)); EXPECT_TRUE(result_value.IsError()); diff --git a/eval/public/cel_attribute.h b/eval/public/cel_attribute.h new file mode 100644 index 000000000..98ef5e848 --- /dev/null +++ b/eval/public/cel_attribute.h @@ -0,0 +1,226 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_ATTRIBUTE_PATTERN_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_ATTRIBUTE_PATTERN_H_ + +#include +#include + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "eval/public/cel_value.h" +#include "eval/public/cel_value_internal.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// CelAttributeQualifier represents a segment in +// attribute resolutuion path. A segment can be qualified by values of +// following types: string/int64_t/uint64/bool. +class CelAttributeQualifier { + private: + // Helper class, used to implement CelAttributeQualifier::operator==. + class EqualVisitor { + public: + template + class NestedEqualVisitor { + public: + explicit NestedEqualVisitor(const T& arg) : arg_(arg) {} + + template + bool operator()(const U& other) const { + return false; + } + + bool operator()(const T& other) const { return other == arg_; } + + private: + const T& arg_; + }; + + explicit EqualVisitor(const CelValue& other) : other_(other) {} + + template + bool operator()(const Type& arg) { + return other_.template Visit(NestedEqualVisitor(arg)); + } + + private: + const CelValue& other_; + }; + + CelValue value_; + + explicit CelAttributeQualifier(CelValue value) : value_(value) {} + + public: + // Factory method. + static CelAttributeQualifier Create(CelValue value) { + return CelAttributeQualifier(value); + } + + // Family of Get... methods. Return values if requested type matches the + // stored one. + absl::optional GetInt64Key() const { + return (value_.IsInt64()) ? absl::optional(value_.Int64OrDie()) + : absl::nullopt; + } + + absl::optional GetUint64Key() const { + return (value_.IsUint64()) ? absl::optional(value_.Uint64OrDie()) + : absl::nullopt; + } + + absl::optional GetStringKey() const { + return (value_.IsString()) ? absl::optional(value_.StringOrDie().value()) + : absl::nullopt; + } + + absl::optional GetBoolKey() const { + return (value_.IsBool()) ? absl::optional(value_.BoolOrDie()) + : absl::nullopt; + } + + bool operator==(const CelAttributeQualifier& other) const { + return IsMatch(other.value_); + } + + bool IsMatch(const CelValue& cel_value) const { + return value_.template Visit(EqualVisitor(cel_value)); + } + + bool IsMatch(absl::string_view other_key) { + absl::optional key = GetStringKey(); + return (key.has_value() && key.value() == other_key); + } +}; + +// CelAttributeQualifierPattern matches a segment in +// attribute resolutuion path. CelAttributeQualifierPattern is capable of +// matching path elements of types string/int64_t/uint64/bool. +class CelAttributeQualifierPattern { + private: + // Qualifier value. If not set, treated as wildcard. + absl::optional value_; + + CelAttributeQualifierPattern(absl::optional value) + : value_(value) {} + + public: + // Factory method. + static CelAttributeQualifierPattern Create(CelValue value) { + return CelAttributeQualifierPattern(CelAttributeQualifier::Create(value)); + } + + static CelAttributeQualifierPattern CreateWildcard() { + return CelAttributeQualifierPattern(absl::nullopt); + } + + bool IsWildcard() const { return !value_.has_value(); } + + bool IsMatch(const CelAttributeQualifier& qualifier) const { + if (IsWildcard()) return true; + return value_.value() == qualifier; + } + + bool IsMatch(const CelValue& cel_value) const { + if (!value_.has_value()) { + switch (cel_value.type()) { + case CelValue::Type::kInt64: + case CelValue::Type::kUint64: + case CelValue::Type::kString: + case CelValue::Type::kBool: { + return true; + } + default: { + return false; + } + } + } + return value_.value().IsMatch(cel_value); + } + + bool IsMatch(absl::string_view other_key) { + if (!value_.has_value()) return true; + return value_.value().IsMatch(other_key); + } +}; + +// CelAttribute represents resolved attribute path. +class CelAttribute { + public: + CelAttribute(Expr variable, std::vector qualifier_path) + : variable_(std::move(variable)), + qualifier_path_(std::move(qualifier_path)) {} + + const Expr& variable() const { return variable_; } + + const std::vector& qualifier_path() const { + return qualifier_path_; + } + + private: + Expr variable_; + std::vector qualifier_path_; +}; + +// CelAttributePattern is a fully-qualified absolute attribute path pattern. +// Supported segments steps in the path are: +// - field selection; +// - map lookup by key; +// - list access by index. +class CelAttributePattern { + public: + // MatchType enum specifies how closely pattern is matching the attribute: + enum class MatchType { + NONE, // Pattern does not match attribute itself nor its children + PARTIAL, // Pattern matches an entity nested within attribute; + FULL // Pattern matches an attribute itself. + }; + + CelAttributePattern(std::string variable, + std::vector qualifier_path) + : variable_(std::move(variable)), + qualifier_path_(std::move(qualifier_path)) {} + + absl::string_view variable() const { return variable_; } + + const std::vector& qualifier_path() const { + return qualifier_path_; + } + + // Matches the pattern to an attribute. + // Distinguishes between no-match, partial match and full match cases. + MatchType IsMatch(const CelAttribute& attribute) const { + MatchType result = MatchType::NONE; + if (attribute.variable().ident_expr().name() != variable_) { + return result; + } + + auto max_index = qualifier_path().size(); + result = MatchType::FULL; + if (qualifier_path().size() > attribute.qualifier_path().size()) { + max_index = attribute.qualifier_path().size(); + result = MatchType::PARTIAL; + } + + for (int i = 0; i < max_index; i++) { + if (!(qualifier_path()[i].IsMatch(attribute.qualifier_path()[i]))) { + return MatchType::NONE; + } + } + return result; + } + + private: + std::string variable_; + std::vector qualifier_path_; +}; + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_ATTRIBUTE_PATTERN_H_ diff --git a/eval/public/cel_attribute_test.cc b/eval/public/cel_attribute_test.cc new file mode 100644 index 000000000..3735399c8 --- /dev/null +++ b/eval/public/cel_attribute_test.cc @@ -0,0 +1,282 @@ +#include "eval/public/cel_attribute.h" + +#include "google/protobuf/arena.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/strings/string_view.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +using ::google::protobuf::Duration; +using ::google::protobuf::Timestamp; + +using testing::Eq; + +namespace { + +class DummyMap : public CelMap { + public: + absl::optional operator[](CelValue value) const override { + return CelValue::CreateNull(); + } + const CelList* ListKeys() const override { return nullptr; } + + int size() const override { return 0; } +}; + +class DummyList : public CelList { + public: + int size() const override { return 0; } + + CelValue operator[](int index) const override { + return CelValue::CreateNull(); + } +}; + +TEST(CelAttributeQualifierTest, TestBoolAccess) { + auto qualifier = CelAttributeQualifier::Create(CelValue::CreateBool(true)); + + EXPECT_FALSE(qualifier.GetStringKey().has_value()); + EXPECT_FALSE(qualifier.GetInt64Key().has_value()); + EXPECT_FALSE(qualifier.GetUint64Key().has_value()); + EXPECT_TRUE(qualifier.GetBoolKey().has_value()); + EXPECT_THAT(qualifier.GetBoolKey().value(), Eq(true)); +} + +TEST(CelAttributeQualifierTest, TestInt64Access) { + auto qualifier = CelAttributeQualifier::Create(CelValue::CreateInt64(1)); + + EXPECT_FALSE(qualifier.GetBoolKey().has_value()); + EXPECT_FALSE(qualifier.GetStringKey().has_value()); + EXPECT_FALSE(qualifier.GetUint64Key().has_value()); + + EXPECT_TRUE(qualifier.GetInt64Key().has_value()); + EXPECT_THAT(qualifier.GetInt64Key().value(), Eq(1)); +} + +TEST(CelAttributeQualifierTest, TestUint64Access) { + auto qualifier = CelAttributeQualifier::Create(CelValue::CreateUint64(1)); + + EXPECT_FALSE(qualifier.GetBoolKey().has_value()); + EXPECT_FALSE(qualifier.GetStringKey().has_value()); + EXPECT_FALSE(qualifier.GetInt64Key().has_value()); + + EXPECT_TRUE(qualifier.GetUint64Key().has_value()); + EXPECT_THAT(qualifier.GetUint64Key().value(), Eq(1)); +} + +TEST(CelAttributeQualifierTest, TestStringAccess) { + const std::string test = "test"; + auto qualifier = CelAttributeQualifier::Create(CelValue::CreateString(&test)); + + EXPECT_FALSE(qualifier.GetBoolKey().has_value()); + EXPECT_FALSE(qualifier.GetInt64Key().has_value()); + EXPECT_FALSE(qualifier.GetUint64Key().has_value()); + + EXPECT_TRUE(qualifier.GetStringKey().has_value()); + EXPECT_THAT(qualifier.GetStringKey().value(), Eq("test")); +} + +void TestAllInequalities(const CelAttributeQualifier& qualifier) { + EXPECT_FALSE(qualifier == + CelAttributeQualifier::Create(CelValue::CreateBool(false))); + EXPECT_FALSE(qualifier == + CelAttributeQualifier::Create(CelValue::CreateInt64(0))); + EXPECT_FALSE(qualifier == + CelAttributeQualifier::Create(CelValue::CreateUint64(0))); + const std::string test = "Those are not the droids you are looking for."; + EXPECT_FALSE(qualifier == + CelAttributeQualifier::Create(CelValue::CreateString(&test))); +} + +TEST(CelAttributeQualifierTest, TestBoolComparison) { + auto qualifier = CelAttributeQualifier::Create(CelValue::CreateBool(true)); + TestAllInequalities(qualifier); + EXPECT_TRUE(qualifier == + CelAttributeQualifier::Create(CelValue::CreateBool(true))); +} + +TEST(CelAttributeQualifierTest, TestInt64Comparison) { + auto qualifier = CelAttributeQualifier::Create(CelValue::CreateInt64(true)); + TestAllInequalities(qualifier); + EXPECT_TRUE(qualifier == + CelAttributeQualifier::Create(CelValue::CreateInt64(true))); +} + +TEST(CelAttributeQualifierTest, TestUint64Comparison) { + auto qualifier = CelAttributeQualifier::Create(CelValue::CreateUint64(true)); + TestAllInequalities(qualifier); + EXPECT_TRUE(qualifier == + CelAttributeQualifier::Create(CelValue::CreateUint64(true))); +} + +TEST(CelAttributeQualifierTest, TestStringComparison) { + const std::string kTest = "test"; + auto qualifier = + CelAttributeQualifier::Create(CelValue::CreateString(&kTest)); + TestAllInequalities(qualifier); + EXPECT_TRUE(qualifier == + CelAttributeQualifier::Create(CelValue::CreateString(&kTest))); +} + +void TestAllCelValueMismatches(const CelAttributeQualifierPattern& qualifier) { + EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateNull())); + EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateBool(false))); + EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateInt64(0))); + EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateUint64(0))); + EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateDouble(0.))); + + const std::string kStr = "Those are not the droids you are looking for."; + EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateString(&kStr))); + EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateBytes(&kStr))); + + Duration msg_duration; + msg_duration.set_seconds(0); + msg_duration.set_nanos(0); + EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateDuration(&msg_duration))); + + Timestamp msg_timestamp; + msg_timestamp.set_seconds(0); + msg_timestamp.set_nanos(0); + EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateTimestamp(&msg_timestamp))); + + DummyList dummy_list; + EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateList(&dummy_list))); + + DummyMap dummy_map; + EXPECT_FALSE(qualifier.IsMatch(CelValue::CreateMap(&dummy_map))); + + google::protobuf::Arena arena; + EXPECT_FALSE(qualifier.IsMatch(CreateErrorValue(&arena, kStr))); +} + +void TestAllQualifierMismatches(const CelAttributeQualifierPattern& qualifier) { + const std::string test = "Those are not the droids you are looking for."; + EXPECT_FALSE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateBool(false)))); + EXPECT_FALSE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateInt64(0)))); + EXPECT_FALSE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateUint64(0)))); + EXPECT_FALSE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateString(&test)))); +} + +TEST(CelAttributeQualifierPatternTest, TestCelValueBoolMatch) { + auto qualifier = + CelAttributeQualifierPattern::Create(CelValue::CreateBool(true)); + + TestAllCelValueMismatches(qualifier); + + CelValue value_match = CelValue::CreateBool(true); + + EXPECT_TRUE(qualifier.IsMatch(value_match)); +} + +TEST(CelAttributeQualifierPatternTest, TestCelValueInt64Match) { + auto qualifier = + CelAttributeQualifierPattern::Create(CelValue::CreateInt64(1)); + + TestAllCelValueMismatches(qualifier); + + CelValue value_match = CelValue::CreateInt64(1); + + EXPECT_TRUE(qualifier.IsMatch(value_match)); +} + +TEST(CelAttributeQualifierPatternTest, TestCelValueUint64Match) { + auto qualifier = + CelAttributeQualifierPattern::Create(CelValue::CreateUint64(1)); + + TestAllCelValueMismatches(qualifier); + + CelValue value_match = CelValue::CreateUint64(1); + + EXPECT_TRUE(qualifier.IsMatch(value_match)); +} + +TEST(CelAttributeQualifierPatternTest, TestCelValueStringMatch) { + std::string kTest = "test"; + auto qualifier = + CelAttributeQualifierPattern::Create(CelValue::CreateString(&kTest)); + + TestAllCelValueMismatches(qualifier); + + CelValue value_match = CelValue::CreateString(&kTest); + + EXPECT_TRUE(qualifier.IsMatch(value_match)); +} + +TEST(CelAttributeQualifierPatternTest, TestQualifierBoolMatch) { + auto qualifier = + CelAttributeQualifierPattern::Create(CelValue::CreateBool(true)); + + TestAllQualifierMismatches(qualifier); + + EXPECT_TRUE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateBool(true)))); +} + +TEST(CelAttributeQualifierPatternTest, TestQualifierInt64Match) { + auto qualifier = + CelAttributeQualifierPattern::Create(CelValue::CreateInt64(1)); + + TestAllQualifierMismatches(qualifier); + EXPECT_TRUE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateInt64(1)))); +} + +TEST(CelAttributeQualifierPatternTest, TestQualifierUint64Match) { + auto qualifier = + CelAttributeQualifierPattern::Create(CelValue::CreateUint64(1)); + + TestAllQualifierMismatches(qualifier); + EXPECT_TRUE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateUint64(1)))); +} + +TEST(CelAttributeQualifierPatternTest, TestQualifierStringMatch) { + const std::string test = "test"; + auto qualifier = + CelAttributeQualifierPattern::Create(CelValue::CreateString(&test)); + + TestAllQualifierMismatches(qualifier); + + EXPECT_TRUE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateString(&test)))); +} + +TEST(CelAttributeQualifierPatternTest, TestQualifierWildcardMatch) { + auto qualifier = CelAttributeQualifierPattern::CreateWildcard(); + EXPECT_TRUE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateBool(false)))); + EXPECT_TRUE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateBool(true)))); + EXPECT_TRUE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateInt64(0)))); + EXPECT_TRUE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateInt64(1)))); + EXPECT_TRUE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateUint64(0)))); + EXPECT_TRUE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateUint64(1)))); + + const std::string kTest1 = "test1"; + const std::string kTest2 = "test2"; + + EXPECT_TRUE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateString(&kTest1)))); + EXPECT_TRUE(qualifier.IsMatch( + CelAttributeQualifier::Create(CelValue::CreateString(&kTest2)))); +} + +} // namespace + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index 79b008f7a..21be23507 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -10,10 +10,12 @@ std::unique_ptr CreateCelExpressionBuilder( const InterpreterOptions& options) { auto builder = absl::make_unique(); builder->set_shortcircuiting(options.short_circuiting); - builder->GetRegistry()->set_partial_string_match( - options.partial_string_match); builder->set_constant_folding(options.constant_folding, options.constant_arena); + builder->set_enable_comprehension(options.enable_comprehension); + builder->set_comprehension_max_iterations( + options.comprehension_max_iterations); + return std::move(builder); } diff --git a/eval/public/cel_expr_builder_factory.h b/eval/public/cel_expr_builder_factory.h index be16be9d3..f3f08d991 100644 --- a/eval/public/cel_expr_builder_factory.h +++ b/eval/public/cel_expr_builder_factory.h @@ -2,35 +2,13 @@ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPR_BUILDER_FACTORY_H_ #include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" namespace google { namespace api { namespace expr { namespace runtime { -// Interpreter options for controlling evaluation and builtin functions. -struct InterpreterOptions { - InterpreterOptions() - : short_circuiting(true), - partial_string_match(false), - constant_folding(false), - constant_arena(nullptr) {} - - // Enable short-circuiting of the logical operator evaluation. If enabled, - // AND, OR, and TERNARY do not evaluate the entire expression once the the - // resulting value is known from the left-hand side. - bool short_circuiting = true; - - // Indicate whether to use partial or full string regex matching. - // Should be enabled to conform with the CEL specification. - bool partial_string_match = false; - - // Enable constant folding during the expression creation. If enabled, - // an arena must be provided for constant generation. - bool constant_folding = false; - google::protobuf::Arena* constant_arena; -}; - // Factory creates CelExpressionBuilder implementation for public use. std::unique_ptr CreateCelExpressionBuilder( const InterpreterOptions& options = InterpreterOptions()); diff --git a/eval/public/cel_function.h b/eval/public/cel_function.h index 84c08a4da..918115e71 100644 --- a/eval/public/cel_function.h +++ b/eval/public/cel_function.h @@ -3,6 +3,7 @@ #include "absl/container/node_hash_map.h" #include "absl/types/span.h" +#include "eval/public/cel_options.h" #include "eval/public/cel_value.h" namespace google { @@ -64,6 +65,9 @@ class CelFunction { // CelFunction descriptor const Descriptor& descriptor() const { return descriptor_; } + // Configuration extension to customize the overload behavior. + virtual void Configure(const InterpreterOptions& options) {} + private: Descriptor descriptor_; }; @@ -73,7 +77,7 @@ class CelFunction { // CelExpression objects from Expr ASTs. class CelFunctionRegistry { public: - CelFunctionRegistry() : partial_string_match_(false) {} + CelFunctionRegistry() {} ~CelFunctionRegistry() {} @@ -98,20 +102,10 @@ class CelFunctionRegistry { absl::node_hash_map> ListFunctions() const; - // Select partial or full regex match for match() built-in function. - void set_partial_string_match(bool enabled) { - partial_string_match_ = enabled; - } - - // Use partial regex match for match() built-in function. - bool partial_string_match() { return partial_string_match_; } - private: using Overloads = std::vector>; absl::node_hash_map functions_; - - bool partial_string_match_; }; } // namespace runtime diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h new file mode 100644 index 000000000..46cbfb8d2 --- /dev/null +++ b/eval/public/cel_options.h @@ -0,0 +1,60 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_OPTIONS_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_OPTIONS_H_ + +#include "google/protobuf/arena.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Interpreter options for controlling evaluation and builtin functions. +struct InterpreterOptions { + // Enable short-circuiting of the logical operator evaluation. If enabled, + // AND, OR, and TERNARY do not evaluate the entire expression once the the + // resulting value is known from the left-hand side. + bool short_circuiting = true; + + // DEPRECATED. This option has no effect. + bool partial_string_match = true; + + // Enable constant folding during the expression creation. If enabled, + // an arena must be provided for constant generation. + bool constant_folding = false; + google::protobuf::Arena* constant_arena = nullptr; + + // Enable comprehension expressions (e.g. exists, all) + bool enable_comprehension = true; + + // Set maximum number of iterations in the comprehension expressions if + // comprehensions are enabled. The limit applies globally per an evaluation, + // including the nested loops as well. Use value 0 to disable the upper bound. + int comprehension_max_iterations = 0; + + // Enable RE2 match() overload. + bool enable_regex = true; + + // Set maximum program size for RE2 regex if regex overload is enabled. + // Evaluates to an error if a regex exceeds it. Use value 0 to disable the + // upper bound. + int regex_max_program_size = 0; + + // Enable string() overloads. + bool enable_string_conversion = true; + + // Enable string concatenation overload. + bool enable_string_concat = true; + + // Enable list concatenation overload. + bool enable_list_concat = true; + + // Enable list membership overload. + bool enable_list_contains = true; +}; + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_OPTIONS_H_ diff --git a/eval/public/value_export_util.cc b/eval/public/value_export_util.cc new file mode 100644 index 000000000..b557d6804 --- /dev/null +++ b/eval/public/value_export_util.cc @@ -0,0 +1,146 @@ +#include "eval/public/value_export_util.h" + +#include "google/protobuf/util/json_util.h" +#include "google/protobuf/util/time_util.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "base/canonical_errors.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +using google::protobuf::Duration; +using google::protobuf::ListValue; +using google::protobuf::Struct; +using google::protobuf::Timestamp; +using google::protobuf::Value; +using google::protobuf::FieldDescriptor; +using google::protobuf::Message; +using google::protobuf::util::TimeUtil; + +cel_base::Status KeyAsString(const CelValue& value, std::string* key) { + switch (value.type()) { + case CelValue::Type::kInt64: { + *key = absl::StrCat(value.Int64OrDie()); + break; + } + case CelValue::Type::kUint64: { + *key = absl::StrCat(value.Uint64OrDie()); + break; + } + case CelValue::Type::kString: { + key->assign(value.StringOrDie().value().data(), + value.StringOrDie().value().size()); + break; + } + default: { return cel_base::InvalidArgumentError("Unsupported map type"); } + } + return cel_base::OkStatus(); +} + +// Export content of CelValue as google.protobuf.Value. +cel_base::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) { + if (in_value.IsNull()) { + out_value->set_null_value(google::protobuf::NULL_VALUE); + return cel_base::OkStatus(); + } + switch (in_value.type()) { + case CelValue::Type::kBool: { + out_value->set_bool_value(in_value.BoolOrDie()); + break; + } + case CelValue::Type::kInt64: { + out_value->set_number_value(static_cast(in_value.Int64OrDie())); + break; + } + case CelValue::Type::kUint64: { + out_value->set_number_value(static_cast(in_value.Uint64OrDie())); + break; + } + case CelValue::Type::kDouble: { + out_value->set_number_value(in_value.DoubleOrDie()); + break; + } + case CelValue::Type::kString: { + auto value = in_value.StringOrDie().value(); + out_value->set_string_value(value.data(), value.size()); + break; + } + case CelValue::Type::kBytes: { + absl::Base64Escape(in_value.BytesOrDie().value(), + out_value->mutable_string_value()); + break; + } + case CelValue::Type::kDuration: { + Duration duration; + expr::internal::EncodeDuration(in_value.DurationOrDie(), &duration); + out_value->set_string_value(TimeUtil::ToString(duration)); + break; + } + case CelValue::Type::kTimestamp: { + Timestamp timestamp; + expr::internal::EncodeTime(in_value.TimestampOrDie(), ×tamp); + out_value->set_string_value(TimeUtil::ToString(timestamp)); + break; + } + case CelValue::Type::kMessage: { + google::protobuf::util::JsonPrintOptions json_options; + json_options.preserve_proto_field_names = true; + std::string json; + auto status = google::protobuf::util::MessageToJsonString(*in_value.MessageOrDie(), + &json, json_options); + if (!status.ok()) { + return cel_base::InternalError(status.ToString()); + } + google::protobuf::util::JsonParseOptions json_parse_options; + status = google::protobuf::util::JsonStringToMessage(json, out_value, + json_parse_options); + if (!status.ok()) { + return cel_base::InternalError(status.ToString()); + } + break; + } + case CelValue::Type::kList: { + const CelList* cel_list = in_value.ListOrDie(); + auto out_values = out_value->mutable_list_value(); + for (int i = 0; i < cel_list->size(); i++) { + auto status = + ExportAsProtoValue((*cel_list)[i], out_values->add_values()); + if (!status.ok()) { + return status; + } + } + break; + } + case CelValue::Type::kMap: { + const CelMap* cel_map = in_value.MapOrDie(); + auto keys_list = cel_map->ListKeys(); + auto out_values = out_value->mutable_struct_value()->mutable_fields(); + for (int i = 0; i < keys_list->size(); i++) { + std::string key; + CelValue map_key = (*keys_list)[i]; + auto status = KeyAsString(map_key, &key); + if (!status.ok()) { + return status; + } + auto map_value_ref = (*cel_map)[map_key]; + CelValue map_value = + (map_value_ref) ? map_value_ref.value() : CelValue(); + status = ExportAsProtoValue(map_value, &((*out_values)[key])); + if (!status.ok()) { + return status; + } + } + break; + } + default: { return cel_base::InvalidArgumentError("Unsupported value type"); } + } + return cel_base::OkStatus(); +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/value_export_util.h b/eval/public/value_export_util.h new file mode 100644 index 000000000..39018b3d7 --- /dev/null +++ b/eval/public/value_export_util.h @@ -0,0 +1,27 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_VALUE_EXPORT_UTIL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_VALUE_EXPORT_UTIL_H_ + +#include "eval/public/cel_value.h" + +#include "google/protobuf/struct.pb.h" +#include "base/status.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Exports content of CelValue as google.protobuf.Value. +// Current limitations: +// - exports integer values as doubles (Value.number_value); +// - exports integer keys in maps as strings; +// - handles Duration and Timestamp as generic messages. +cel_base::Status ExportAsProtoValue(const CelValue &in_value, + google::protobuf::Value *out_value); + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_VALUE_EXPORT_UTIL_H_ diff --git a/eval/public/value_export_util_test.cc b/eval/public/value_export_util_test.cc new file mode 100644 index 000000000..e988ebe2d --- /dev/null +++ b/eval/public/value_export_util_test.cc @@ -0,0 +1,343 @@ +#include "eval/public/value_export_util.h" + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/strings/str_cat.h" +#include "eval/eval/container_backed_list_impl.h" +#include "eval/eval/container_backed_map_impl.h" +#include "eval/testutil/test_message.pb.h" +#include "testutil/util.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +using google::protobuf::Duration; +using google::protobuf::ListValue; +using google::protobuf::Struct; +using google::protobuf::Timestamp; +using google::protobuf::Value; +using google::protobuf::Arena; + +TEST(ValueExportUtilTest, ConvertBoolValue) { + CelValue cel_value = CelValue::CreateBool(true); + Value value; + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kBoolValue); + EXPECT_EQ(value.bool_value(), true); +} + +TEST(ValueExportUtilTest, ConvertInt64Value) { + CelValue cel_value = CelValue::CreateInt64(-1); + Value value; + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kNumberValue); + EXPECT_DOUBLE_EQ(value.number_value(), -1); +} + +TEST(ValueExportUtilTest, ConvertUint64Value) { + CelValue cel_value = CelValue::CreateUint64(1); + Value value; + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kNumberValue); + EXPECT_DOUBLE_EQ(value.number_value(), 1); +} + +TEST(ValueExportUtilTest, ConvertDoubleValue) { + CelValue cel_value = CelValue::CreateDouble(1.3); + Value value; + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kNumberValue); + EXPECT_DOUBLE_EQ(value.number_value(), 1.3); +} + +TEST(ValueExportUtilTest, ConvertStringValue) { + std::string test = "test"; + CelValue cel_value = CelValue::CreateString(&test); + Value value; + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); + EXPECT_EQ(value.string_value(), "test"); +} + +TEST(ValueExportUtilTest, ConvertBytesValue) { + std::string test = "test"; + CelValue cel_value = CelValue::CreateBytes(&test); + Value value; + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); + // Check that the result is BASE64 encoded. + EXPECT_EQ(value.string_value(), "dGVzdA=="); +} + +TEST(ValueExportUtilTest, ConvertDurationValue) { + Duration duration; + duration.set_seconds(2); + duration.set_nanos(3); + CelValue cel_value = CelValue::CreateDuration(&duration); + Value value; + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); + EXPECT_EQ(value.string_value(), "2.000000003s"); +} + +TEST(ValueExportUtilTest, ConvertTimestampValue) { + Timestamp timestamp; + timestamp.set_seconds(1000000000); + timestamp.set_nanos(3); + CelValue cel_value = CelValue::CreateTimestamp(×tamp); + Value value; + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); + EXPECT_EQ(value.string_value(), "2001-09-09T01:46:40.000000003Z"); +} + +TEST(ValueExportUtilTest, ConvertStructMessage) { + Struct struct_msg; + (*struct_msg.mutable_fields())["string_value"].set_string_value("test"); + Arena arena; + CelValue cel_value = CelValue::CreateMessage(&struct_msg, &arena); + Value value; + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); + EXPECT_THAT(value.struct_value(), testutil::EqualsProto(struct_msg)); +} + +TEST(ValueExportUtilTest, ConvertValueMessage) { + Value value_in; + // key-based access forces value to be a struct. + (*value_in.mutable_struct_value()->mutable_fields())["boolean_value"] + .set_bool_value(true); + Arena arena; + CelValue cel_value = CelValue::CreateMessage(&value_in, &arena); + Value value_out; + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value_out).ok()); + EXPECT_THAT(value_in, testutil::EqualsProto(value_out)); +} + +TEST(ValueExportUtilTest, ConvertListValueMessage) { + ListValue list_value; + list_value.add_values()->set_string_value("test"); + list_value.add_values()->set_bool_value(true); + Arena arena; + CelValue cel_value = CelValue::CreateMessage(&list_value, &arena); + Value value_out; + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value_out).ok()); + EXPECT_THAT(list_value, testutil::EqualsProto(value_out.list_value())); +} + +TEST(ValueExportUtilTest, ConvertRepeatedBoolValue) { + Arena arena; + Value value; + + TestMessage *msg = Arena::CreateMessage(&arena); + msg->add_bool_list(true); + msg->add_bool_list(false); + CelValue cel_value = CelValue::CreateMessage(msg, &arena); + + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); + + Value list_value = value.struct_value().fields().at("bool_list"); + + EXPECT_TRUE(list_value.has_list_value()); + EXPECT_EQ(list_value.list_value().values(0).bool_value(), true); + EXPECT_EQ(list_value.list_value().values(1).bool_value(), false); +} + +TEST(ValueExportUtilTest, ConvertRepeatedInt32Value) { + Arena arena; + Value value; + + TestMessage *msg = Arena::CreateMessage(&arena); + msg->add_int32_list(2); + msg->add_int32_list(3); + CelValue cel_value = CelValue::CreateMessage(msg, &arena); + + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); + + Value list_value = value.struct_value().fields().at("int32_list"); + + EXPECT_TRUE(list_value.has_list_value()); + EXPECT_DOUBLE_EQ(list_value.list_value().values(0).number_value(), 2); + EXPECT_DOUBLE_EQ(list_value.list_value().values(1).number_value(), 3); +} + +TEST(ValueExportUtilTest, ConvertRepeatedInt64Value) { + Arena arena; + Value value; + + TestMessage *msg = Arena::CreateMessage(&arena); + msg->add_int64_list(2); + msg->add_int64_list(3); + CelValue cel_value = CelValue::CreateMessage(msg, &arena); + + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); + + Value list_value = value.struct_value().fields().at("int64_list"); + + EXPECT_TRUE(list_value.has_list_value()); + EXPECT_EQ(list_value.list_value().values(0).string_value(), "2"); + EXPECT_EQ(list_value.list_value().values(1).string_value(), "3"); +} + +TEST(ValueExportUtilTest, ConvertRepeatedUint64Value) { + Arena arena; + Value value; + + TestMessage *msg = Arena::CreateMessage(&arena); + msg->add_uint64_list(2); + msg->add_uint64_list(3); + CelValue cel_value = CelValue::CreateMessage(msg, &arena); + + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); + + Value list_value = value.struct_value().fields().at("uint64_list"); + + EXPECT_TRUE(list_value.has_list_value()); + EXPECT_EQ(list_value.list_value().values(0).string_value(), "2"); + EXPECT_EQ(list_value.list_value().values(1).string_value(), "3"); +} + +TEST(ValueExportUtilTest, ConvertRepeatedDoubleValue) { + Arena arena; + Value value; + + TestMessage *msg = Arena::CreateMessage(&arena); + msg->add_double_list(2); + msg->add_double_list(3); + CelValue cel_value = CelValue::CreateMessage(msg, &arena); + + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); + + Value list_value = value.struct_value().fields().at("double_list"); + + EXPECT_TRUE(list_value.has_list_value()); + EXPECT_DOUBLE_EQ(list_value.list_value().values(0).number_value(), 2); + EXPECT_DOUBLE_EQ(list_value.list_value().values(1).number_value(), 3); +} + +TEST(ValueExportUtilTest, ConvertRepeatedStringValue) { + Arena arena; + Value value; + + TestMessage *msg = Arena::CreateMessage(&arena); + msg->add_string_list("test1"); + msg->add_string_list("test2"); + CelValue cel_value = CelValue::CreateMessage(msg, &arena); + + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); + + Value list_value = value.struct_value().fields().at("string_list"); + + EXPECT_TRUE(list_value.has_list_value()); + EXPECT_EQ(list_value.list_value().values(0).string_value(), "test1"); + EXPECT_EQ(list_value.list_value().values(1).string_value(), "test2"); +} + +TEST(ValueExportUtilTest, ConvertRepeatedBytesValue) { + Arena arena; + Value value; + + TestMessage *msg = Arena::CreateMessage(&arena); + msg->add_bytes_list("test1"); + msg->add_bytes_list("test2"); + CelValue cel_value = CelValue::CreateMessage(msg, &arena); + + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); + + Value list_value = value.struct_value().fields().at("bytes_list"); + + EXPECT_TRUE(list_value.has_list_value()); + EXPECT_EQ(list_value.list_value().values(0).string_value(), "dGVzdDE="); + EXPECT_EQ(list_value.list_value().values(1).string_value(), "dGVzdDI="); +} + +TEST(ValueExportUtilTest, ConvertCelList) { + Arena arena; + Value value; + + std::vector values; + values.push_back(CelValue::CreateInt64(2)); + values.push_back(CelValue::CreateInt64(3)); + CelList *cel_list = Arena::Create(&arena, values); + CelValue cel_value = CelValue::CreateList(cel_list); + + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kListValue); + + EXPECT_DOUBLE_EQ(value.list_value().values(0).number_value(), 2); + EXPECT_DOUBLE_EQ(value.list_value().values(1).number_value(), 3); +} + +TEST(ValueExportUtilTest, ConvertCelMapWithStringKey) { + Value value; + std::vector> map_entries; + + std::string key1 = "key1"; + std::string key2 = "key2"; + std::string value1 = "value1"; + std::string value2 = "value2"; + + map_entries.push_back( + {CelValue::CreateString(&key1), CelValue::CreateString(&value1)}); + map_entries.push_back( + {CelValue::CreateString(&key2), CelValue::CreateString(&value2)}); + + auto cel_map = CreateContainerBackedMap( + absl::Span>(map_entries)); + CelValue cel_value = CelValue::CreateMap(cel_map.get()); + + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); + + const auto &fields = value.struct_value().fields(); + + EXPECT_EQ(fields.at(key1).string_value(), value1); + EXPECT_EQ(fields.at(key2).string_value(), value2); +} + +TEST(ValueExportUtilTest, ConvertCelMapWithInt64Key) { + Value value; + std::vector> map_entries; + + int key1 = -1; + int key2 = 2; + std::string value1 = "value1"; + std::string value2 = "value2"; + + map_entries.push_back( + {CelValue::CreateInt64(key1), CelValue::CreateString(&value1)}); + map_entries.push_back( + {CelValue::CreateInt64(key2), CelValue::CreateString(&value2)}); + + auto cel_map = CreateContainerBackedMap( + absl::Span>(map_entries)); + CelValue cel_value = CelValue::CreateMap(cel_map.get()); + + EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); + + const auto &fields = value.struct_value().fields(); + + EXPECT_EQ(fields.at(absl::StrCat(key1)).string_value(), value1); + EXPECT_EQ(fields.at(absl::StrCat(key2)).string_value(), value2); +} + +} // namespace + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/testutil/BUILD b/eval/testutil/BUILD index 4e2f179fb..5753c6622 100644 --- a/eval/testutil/BUILD +++ b/eval/testutil/BUILD @@ -11,6 +11,11 @@ proto_library( "test_message.proto", ], deps = [ + "@com_google_protobuf//:any_proto", + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:timestamp_proto", + "@com_google_protobuf//:wrappers_proto", ], ) diff --git a/eval/testutil/test_message.proto b/eval/testutil/test_message.proto index bd89c01da..22fb71c70 100644 --- a/eval/testutil/test_message.proto +++ b/eval/testutil/test_message.proto @@ -2,6 +2,12 @@ syntax = "proto3"; package google.api.expr.runtime; +import "google/protobuf/any.proto"; +import "google/protobuf/duration.proto"; +import "google/protobuf/struct.proto"; +import "google/protobuf/timestamp.proto"; +import "google/protobuf/wrappers.proto"; + option cc_enable_arenas = true; enum TestEnum { @@ -11,8 +17,6 @@ enum TestEnum { TEST_ENUM_3 = 30; } -// Message representing errors -// during CEL evaluation. message TestMessage { int32 int32_value = 1; int64 int64_value = 2; @@ -60,4 +64,20 @@ message TestMessage { map int64_int32_map = 201; map uint64_int32_map = 202; map string_int32_map = 203; + + // Well-known types. + google.protobuf.Any any_value = 300; + google.protobuf.Duration duration_value = 301; + google.protobuf.Timestamp timestamp_value = 302; + google.protobuf.Struct struct_value = 303; + google.protobuf.Value value_value = 304; + google.protobuf.Int64Value int64_wrapper_value = 305; + google.protobuf.Int32Value int32_wrapper_value = 306; + google.protobuf.DoubleValue double_wrapper_value = 307; + google.protobuf.FloatValue float_wrapper_value = 308; + google.protobuf.UInt64Value uint64_wrapper_value = 309; + google.protobuf.UInt32Value uint32_wrapper_value = 310; + google.protobuf.StringValue string_wrapper_value = 311; + google.protobuf.BoolValue bool_wrapper_value = 312; + google.protobuf.BytesValue bytes_wrapper_value = 313; }