diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 3bf5ebb9b90..201b37eca12 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -402,6 +402,7 @@ if(ARROW_COMPUTE) compute/exec/order_by_impl.cc compute/exec/partition_util.cc compute/exec/options.cc + compute/exec/parsing.cc compute/exec/project_node.cc compute/exec/query_context.cc compute/exec/sink_node.cc diff --git a/cpp/src/arrow/compute/exec/expression.h b/cpp/src/arrow/compute/exec/expression.h index c9c7b0e605f..fa6600ed92f 100644 --- a/cpp/src/arrow/compute/exec/expression.h +++ b/cpp/src/arrow/compute/exec/expression.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -128,6 +129,22 @@ class ARROW_EXPORT Expression { explicit Expression(Datum literal); explicit Expression(Parameter parameter); + /* + Grammar: + Expr -> FieldRef | Literal | Call + + FieldRef -> Field | Field FieldRef + Field -> . Name | [ Number ] + + Literal -> $ TypeName : Value + + Call -> Name ( ExprList ) + ExprList -> Expr | Expr , ExprList + + Name, TypeName, Value, and Number take the obvious values + */ + static Result FromString(std::string_view expr); + private: using Impl = std::variant; std::shared_ptr impl_; diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc index 6dc48b3be4e..6f7f038310f 100644 --- a/cpp/src/arrow/compute/exec/expression_test.cc +++ b/cpp/src/arrow/compute/exec/expression_test.cc @@ -28,6 +28,7 @@ #include #include "arrow/compute/exec/expression_internal.h" +#include "arrow/compute/exec/test_util.h" #include "arrow/compute/function_internal.h" #include "arrow/compute/registry.h" #include "arrow/testing/gtest_util.h" @@ -1562,6 +1563,65 @@ TEST(Expression, SerializationRoundTrips) { equal(field_ref("beta"), literal(3.25f))})); } +TEST(Expression, ParseBasic) { + const char* expr_str = "add($int32:1, .i32_0)"; + ASSERT_OK_AND_ASSIGN(Expression expr, Expression::FromString(expr_str)); + ExecBatch batch = ExecBatchFromJSON({int32(), int32()}, "[[1, 2], [1, 2]]"); + std::shared_ptr sch = + schema({field("i32_0", int32()), field("i32_1", int32())}); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*sch.get())); + ASSERT_OK_AND_ASSIGN(Datum result, ExecuteScalarExpression(expr, batch)); + const int32_t* vals = + reinterpret_cast(result.array()->buffers[1]->data()); + ASSERT_EQ(result.array()->length, 2); + ASSERT_EQ(vals[0], 2); + ASSERT_EQ(vals[1], 2); +} + +TEST(Expression, ParseComplexExpr) { + const char* expr_str = "add(multiply(.m, .x), .b)"; + ASSERT_OK_AND_ASSIGN(Expression expr, Expression::FromString(expr_str)); + ExecBatch batch = + ExecBatchFromJSON({int32(), int32(), int32()}, "[[3, 1, 1], [1, 1, 0], [3, 3, 1]]"); + std::shared_ptr sch = + schema({field("m", int32()), field("x", int32()), field("b", int32())}); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*sch.get())); + ASSERT_OK_AND_ASSIGN(Datum result, ExecuteScalarExpression(expr, batch)); + const int32_t* vals = + reinterpret_cast(result.array()->buffers[1]->data()); + ASSERT_EQ(result.array()->length, 3); + ASSERT_EQ(vals[0], 4); + ASSERT_EQ(vals[1], 1); + ASSERT_EQ(vals[2], 10); +} + +TEST(Expression, ParseComplexScalar) { + const char* expr_str = "add($duration(MILLI):10, $duration(MILLI):20)"; + ASSERT_OK_AND_ASSIGN(Expression expr, Expression::FromString(expr_str)); + std::shared_ptr empty_schema = schema({}); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*empty_schema.get())); + ASSERT_OK_AND_ASSIGN(Datum result, ExecuteScalarExpression(expr, {})); + DurationScalar expected(30, TimeUnit::MILLI); + ASSERT_TRUE(result.scalar()->Equals(expected)); +} + +TEST(Expression, ParseEscaped) { + const char* expr_str = "$utf8:hello\\, \\(world\\)"; + ASSERT_OK_AND_ASSIGN(Expression expr, Expression::FromString(expr_str)); + std::shared_ptr empty_schema = schema({}); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*empty_schema.get())); + ASSERT_OK_AND_ASSIGN(Datum result, ExecuteScalarExpression(expr, {})); + StringScalar expected("hello, (world)"); + ASSERT_TRUE(result.scalar()->Equals(expected)); +} + +TEST(Expression, ParseErrorMessage) { + const char* expr_str = "$asdfasdf:horoshaya_kartoshka"; + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + testing::HasSubstr("...asdf:horoshaya_karto..."), + Expression::FromString(expr_str)); +} + TEST(Projection, AugmentWithNull) { // NB: input contains *no columns* except i32 auto input = ArrayFromJSON(struct_({kBoringSchema->GetFieldByName("i32")}), diff --git a/cpp/src/arrow/compute/exec/parsing.cc b/cpp/src/arrow/compute/exec/parsing.cc new file mode 100644 index 00000000000..3f44e4071ae --- /dev/null +++ b/cpp/src/arrow/compute/exec/parsing.cc @@ -0,0 +1,365 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include + +#include "arrow/compute/exec/expression.h" +#include "arrow/type_fwd.h" +#include "arrow/util/logging.h" + +namespace arrow { +namespace compute { +static void ConsumeWhitespace(std::string_view* view) { + constexpr const char* kWhitespaces = " \f\n\r\t\v"; + size_t first_nonwhitespace = view->find_first_not_of(kWhitespaces); + view->remove_prefix(first_nonwhitespace); +} + +static std::string ProcessEscapes(std::string_view view) { + std::string processed = ""; + bool prev_was_backslash = false; + for (size_t i = 0; i < view.size(); i++) { + if (prev_was_backslash) { + processed += view[i]; + prev_was_backslash = false; + } else { + if (view[i] == '\\') + prev_was_backslash = true; + else + processed += view[i]; + } + } + return processed; +} + +static std::string_view ExtractUntil(std::string_view* view, + const std::string_view separators) { + size_t separator = 0; + do { + std::string_view after_last_separator = view->substr(separator); + size_t next_separator = after_last_separator.find_first_of(separators); + if (next_separator == after_last_separator.npos) + next_separator = after_last_separator.size(); + separator += next_separator; // after_last_separator is offset by separator + if (separator > 0 && view->at(separator - 1) == '\\') + separator++; + else + break; + } while (separator < view->size()); + std::string_view prefix = view->substr(0, separator); + view->remove_prefix(separator); + return prefix; +} + +static std::string_view TrimUntilNextSeparator(std::string_view* view) { + constexpr const char* separators = "\f\n\r\t\v),"; + return ExtractUntil(view, separators); +} + +static std::string_view ExtractArgument(std::string_view* view) { + constexpr const char* separators = ",)"; + return ExtractUntil(view, separators); +} + +static const std::unordered_map> + kNameToSimpleType = { + {"null", null()}, + {"boolean", boolean()}, + {"int8", int8()}, + {"int16", int16()}, + {"int32", int32()}, + {"int64", int64()}, + {"uint8", uint8()}, + {"uint16", uint16()}, + {"uint32", uint32()}, + {"uint64", uint64()}, + {"float16", float16()}, + {"float32", float32()}, + {"utf8", utf8()}, + {"large_utf8", large_utf8()}, + {"binary", binary()}, + {"large_binary", large_binary()}, + {"date32", date32()}, + {"date64", date64()}, + {"day_time_interval", day_time_interval()}, + {"month_interval", month_interval()}, + {"month_day_nano_interval", month_day_nano_interval()}, +}; + +static Result> ParseDataType(std::string_view& type); + +// Takes the args list not including the enclosing parentheses +using InstantiateTypeFn = + std::add_pointer_t>(std::string_view&)>; + +static Result ParseInt32(std::string_view& args) { + ConsumeWhitespace(&args); + int32_t result; + auto [finish, ec] = std::from_chars(args.data(), args.data() + args.size(), result); + if (ec == std::errc::invalid_argument) + return Status::Invalid("Could not parse ", args, " as an int32!"); + + args.remove_prefix(finish - args.data()); + return result; +} + +static Status ParseComma(std::string_view& args) { + ConsumeWhitespace(&args); + if (args.empty() || args[0] != ',') + return Status::Invalid("Expected comma-separated args list near ", args); + args.remove_prefix(1); + return Status::OK(); +} + +static Result> ParseFixedSizeBinary(std::string_view& args) { + ARROW_ASSIGN_OR_RAISE(int32_t byte_width, ParseInt32(args)); + return fixed_size_binary(byte_width); +} + +static Result> ParseDecimalArgs(std::string_view& args) { + ARROW_ASSIGN_OR_RAISE(int32_t precision, ParseInt32(args)); + RETURN_NOT_OK(ParseComma(args)); + ARROW_ASSIGN_OR_RAISE(int32_t scale, ParseInt32(args)); + return std::pair(precision, scale); +} + +static Result> ParseDecimal(std::string_view& args) { + ARROW_ASSIGN_OR_RAISE(auto ps, ParseDecimalArgs(args)); + return decimal(ps.first, ps.second); +} + +static Result> ParseDecimal128(std::string_view& args) { + ARROW_ASSIGN_OR_RAISE(auto ps, ParseDecimalArgs(args)); + return decimal128(ps.first, ps.second); +} + +static Result> ParseDecimal256(std::string_view& args) { + ARROW_ASSIGN_OR_RAISE(auto ps, ParseDecimalArgs(args)); + return decimal256(ps.first, ps.second); +} + +static Result> ParseList(std::string_view& args) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr list_type, ParseDataType(args)); + return list(std::move(list_type)); +} + +static Result> ParseLargeList(std::string_view& args) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr list_type, ParseDataType(args)); + return large_list(std::move(list_type)); +} + +static Result> ParseMap(std::string_view& args) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr key_type, ParseDataType(args)); + RETURN_NOT_OK(ParseComma(args)); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr value_type, ParseDataType(args)); + return map(std::move(key_type), std::move(value_type)); +} + +static Result> ParseFixedSizeList(std::string_view& args) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr list_type, ParseDataType(args)); + RETURN_NOT_OK(ParseComma(args)); + ARROW_ASSIGN_OR_RAISE(int32_t size, ParseInt32(args)); + return fixed_size_list(std::move(list_type), size); +} + +static Result ParseTimeUnit(std::string_view& args) { + ConsumeWhitespace(&args); + if (args.empty()) return Status::Invalid("Expected a time unit near ", args); + + const std::string_view options[4] = {"SECOND", "MILLI", "MICRO", "NANO"}; + for (size_t i = 0; i < 4; i++) { + if (args.find(options[i]) == 0) { + args.remove_prefix(options[i].size()); + return TimeUnit::values()[i]; + } + } + return Status::Invalid("Unrecognized TimeUnit ", args); +} + +static Result> ParseDuration(std::string_view& args) { + ARROW_ASSIGN_OR_RAISE(TimeUnit::type unit, ParseTimeUnit(args)); + return duration(unit); +} + +static Result> ParseTimestamp(std::string_view& args) { + ARROW_ASSIGN_OR_RAISE(TimeUnit::type unit, ParseTimeUnit(args)); + return timestamp(unit); +} + +static Result> ParseTime32(std::string_view& args) { + ARROW_ASSIGN_OR_RAISE(TimeUnit::type unit, ParseTimeUnit(args)); + return time32(unit); +} + +static Result> ParseTime64(std::string_view& args) { + ARROW_ASSIGN_OR_RAISE(TimeUnit::type unit, ParseTimeUnit(args)); + return time64(unit); +} + +static Result> ParseDictionary(std::string_view& args) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr key_type, ParseDataType(args)); + RETURN_NOT_OK(ParseComma(args)); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr value_type, ParseDataType(args)); + return dictionary(std::move(key_type), std::move(value_type)); +} + +static const std::unordered_map + kNameToParameterizedType = { + {"fixed_size_binary", ParseFixedSizeBinary}, + {"decimal", ParseDecimal}, + {"decimal128", ParseDecimal128}, + {"decimal256", ParseDecimal256}, + {"list", ParseList}, + {"large_list", ParseLargeList}, + {"map", ParseMap}, + {"fixed_size_list", ParseFixedSizeList}, + {"duration", ParseDuration}, + {"timestamp", ParseTimestamp}, + {"time32", ParseTime32}, + {"time64", ParseTime64}, + {"dictionary", ParseDictionary}, +}; + +static Result ParseExpr(std::string_view& expr); + +static Result ParseCall(std::string_view& expr) { + ConsumeWhitespace(&expr); + if (expr.empty()) return Status::Invalid("Found empty expression"); + + std::string_view function_name = ExtractUntil(&expr, "("); + if (expr.empty()) + return Status::Invalid("Expected argument list after function name", function_name); + expr.remove_prefix(1); // Remove the open paren + + std::vector args; + do { + ConsumeWhitespace(&expr); + if (expr.empty()) + return Status::Invalid("Found unterminated expression argument list"); + if (expr[0] == ')') break; + if (!args.empty()) RETURN_NOT_OK(ParseComma(expr)); + + ARROW_ASSIGN_OR_RAISE(Expression arg, ParseExpr(expr)); + args.emplace_back(std::move(arg)); + } while (true); + + expr.remove_prefix(1); // Remove the close paren + return call(std::string(function_name), std::move(args)); +} + +static Result ParseFieldRef(std::string_view& expr) { + if (expr.empty()) return Status::Invalid("Found an empty named fieldref"); + + std::string_view dot_path = ExtractArgument(&expr); + ARROW_ASSIGN_OR_RAISE(FieldRef field, FieldRef::FromDotPath(dot_path)); + return field_ref(std::move(field)); +} + +static Result> ParseParameterizedDataType( + std::string_view& type) { + size_t lparen = type.find_first_of("("); + if (lparen == std::string_view::npos) return Status::Invalid("Unknown type ", type); + + std::string_view base_type_name = type.substr(0, lparen); + type.remove_prefix(lparen + 1); + auto it = kNameToParameterizedType.find(base_type_name); + if (it == kNameToParameterizedType.end()) + return Status::Invalid("Unknown base type name ", base_type_name); + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr parsed_type, it->second(type)); + ConsumeWhitespace(&type); + if (type.empty() || type[0] != ')') + return Status::Invalid("Unterminated data type arg list!"); + type.remove_prefix(1); + return parsed_type; +} + +static Result> ParseDataType(std::string_view& type) { + auto it = kNameToSimpleType.find(type); + if (it == kNameToSimpleType.end()) return ParseParameterizedDataType(type); + return it->second; +} + +static Result ParseLiteral(std::string_view& expr) { + ARROW_DCHECK(expr[0] == '$'); + expr.remove_prefix(1); + size_t colon = expr.find_first_of(":"); + std::string_view type_name = expr.substr(0, colon); + expr.remove_prefix(colon); + if (expr.empty()) return Status::Invalid("Found an unterminated literal!"); + + ARROW_DCHECK_EQ(expr[0], ':'); + expr.remove_prefix(1); + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr type, ParseDataType(type_name)); + + std::string_view value = TrimUntilNextSeparator(&expr); + + std::shared_ptr scalar; + if (value.find('\\') == value.npos) { + ARROW_ASSIGN_OR_RAISE(scalar, Scalar::Parse(type, value)); + } else { + std::string escaped = ProcessEscapes(value); + ARROW_ASSIGN_OR_RAISE(scalar, Scalar::Parse(type, escaped)); + } + return literal(std::move(scalar)); +} + +static Result ParseExpr(std::string_view& expr) { + ConsumeWhitespace(&expr); + if (expr.empty()) return Status::Invalid("Expression is empty!"); + switch (expr[0]) { + case '.': + case '[': + return ParseFieldRef(expr); + case '$': + return ParseLiteral(expr); + default: + return ParseCall(expr); + } +} + +static Status FormatErrorMessage(std::string_view original_expr, std::string_view parsed, + const Status& st) { + ssize_t error_idx = static_cast(original_expr.size() - parsed.size()); + + constexpr ssize_t kCharactersBehindPreview = 5; + constexpr ssize_t kCharactersPreview = 15; + ssize_t preview_start = + std::max(static_cast(0), error_idx - kCharactersBehindPreview); + ssize_t preview_end = std::min(preview_start + kCharactersPreview, + static_cast(original_expr.size())); + std::string new_error = + "Error at index " + std::to_string(error_idx) + ": " + st.message() + "\n"; + if (preview_start > 0) new_error += "..."; + new_error += original_expr.substr(preview_start, preview_end); + if (preview_end < static_cast(original_expr.size())) new_error += "..."; + return st.WithMessage(std::move(new_error)); +} + +Result Expression::FromString(std::string_view expr) { + std::string_view original = expr; + Result parsed = ParseExpr(expr); + if (!parsed.ok()) return FormatErrorMessage(original, expr, parsed.status()); + return parsed; +} +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index cc31735512b..224580e542b 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -1235,15 +1235,14 @@ void FieldRef::Flatten(std::vector children) { } } -Result FieldRef::FromDotPath(const std::string& dot_path_arg) { - if (dot_path_arg.empty()) { - return FieldRef(); +Result FieldRef::FromDotPath(std::string_view dot_path) { + std::string_view original_dot_path = dot_path; + if (dot_path.empty()) { + return Status::Invalid("Dot path was empty"); } std::vector children; - std::string_view dot_path = dot_path_arg; - auto parse_name = [&] { std::string name; for (;;) { @@ -1289,7 +1288,7 @@ Result FieldRef::FromDotPath(const std::string& dot_path_arg) { case '[': { auto subscript_end = dot_path.find_first_not_of("0123456789"); if (subscript_end == std::string_view::npos || dot_path[subscript_end] != ']') { - return Status::Invalid("Dot path '", dot_path_arg, + return Status::Invalid("Dot path '", original_dot_path, "' contained an unterminated index"); } children.emplace_back(std::atoi(dot_path.data())); @@ -1297,8 +1296,8 @@ Result FieldRef::FromDotPath(const std::string& dot_path_arg) { continue; } default: - return Status::Invalid("Dot path must begin with '[' or '.', got '", dot_path_arg, - "'"); + return Status::Invalid("Dot path must begin with '[' or '.', got '", + original_dot_path, "'"); } } diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 4bf8fe7fabb..9898399b783 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -1723,7 +1723,7 @@ class ARROW_EXPORT FieldRef : public util::EqualityComparable { /// Note: When parsing a name, a '\' preceding any other character will be dropped from /// the resulting name. Therefore if a name must contain the characters '.', '\', or '[' /// those must be escaped with a preceding '\'. - static Result FromDotPath(const std::string& dot_path); + static Result FromDotPath(const std::string_view dot_path); std::string ToDotPath() const; bool Equals(const FieldRef& other) const { return impl_ == other.impl_; }