diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 4422e17e85c..1767c05b5ee 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1785,7 +1785,12 @@ macro(build_substrait) # Note: not all protos in Substrait actually matter to plan # consumption. No need to build the ones we don't need. - set(SUBSTRAIT_PROTOS algebra extensions/extensions plan type) + set(SUBSTRAIT_PROTOS + algebra + extended_expression + extensions/extensions + plan + type) set(ARROW_SUBSTRAIT_PROTOS extension_rels) set(ARROW_SUBSTRAIT_PROTOS_DIR "${CMAKE_SOURCE_DIR}/proto") diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index 13bf6f85a48..232638b7fc7 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -223,6 +223,16 @@ CastOptions::CastOptions(bool safe) allow_float_truncate(!safe), allow_invalid_utf8(!safe) {} +bool CastOptions::is_safe() const { + return !allow_int_overflow && !allow_time_truncate && !allow_time_overflow && + !allow_decimal_truncate && !allow_float_truncate && !allow_invalid_utf8; +} + +bool CastOptions::is_unsafe() const { + return allow_int_overflow && allow_time_truncate && allow_time_overflow && + allow_decimal_truncate && allow_float_truncate && allow_invalid_utf8; +} + constexpr char CastOptions::kTypeName[]; Result Cast(const Datum& value, const CastOptions& options, ExecContext* ctx) { diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index 7432933a124..613e8a55add 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -69,6 +69,15 @@ class ARROW_EXPORT CastOptions : public FunctionOptions { // Indicate if conversions from Binary/FixedSizeBinary to string must // validate the utf8 payload. 
bool allow_invalid_utf8; + + /// true if the safety options all match CastOptions::Safe + /// + /// Note, if this returns false it does not mean is_unsafe will return true + bool is_safe() const; + /// true if the safety options all match CastOptions::Unsafe + /// + /// Note, if this returns false it does not mean is_safe will return true + bool is_unsafe() const; }; /// @} diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt index 7494be8ebb1..fcaa242b114 100644 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -21,6 +21,7 @@ arrow_install_all_headers("arrow/engine") set(ARROW_SUBSTRAIT_SRCS substrait/expression_internal.cc + substrait/extended_expression_internal.cc substrait/extension_set.cc substrait/extension_types.cc substrait/options.cc diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 5e214bdda4d..0df8425609f 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -41,6 +41,7 @@ #include "arrow/buffer.h" #include "arrow/builder.h" #include "arrow/compute/api_scalar.h" +#include "arrow/compute/cast.h" #include "arrow/compute/expression.h" #include "arrow/compute/expression_internal.h" #include "arrow/engine/substrait/extension_set.h" @@ -338,6 +339,22 @@ Result FromProto(const substrait::Expression& expr, return compute::call("case_when", std::move(args)); } + case substrait::Expression::kSingularOrList: { + const auto& or_list = expr.singular_or_list(); + + ARROW_ASSIGN_OR_RAISE(compute::Expression value, + FromProto(or_list.value(), ext_set, conversion_options)); + + std::vector option_eqs; + for (const auto& option : or_list.options()) { + ARROW_ASSIGN_OR_RAISE(compute::Expression arrow_option, + FromProto(option, ext_set, conversion_options)); + option_eqs.push_back(compute::call("equal", {value, arrow_option})); + } + + return 
compute::or_(option_eqs); + } + case substrait::Expression::kScalarFunction: { const auto& scalar_fn = expr.scalar_function(); @@ -1055,9 +1072,68 @@ Result> EncodeSubstraitCa " arguments but no argument could be found at index ", i); } } + + for (const auto& option : call.options()) { + substrait::FunctionOption* fn_option = scalar_fn->add_options(); + fn_option->set_name(option.first); + for (const auto& opt_val : option.second) { + std::string* pref = fn_option->add_preference(); + *pref = opt_val; + } + } + return std::move(scalar_fn); } +Result>> DatumToLiterals( + const Datum& datum, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + std::vector> literals; + + auto ScalarToLiteralExpr = [&](const std::shared_ptr& scalar) + -> Result> { + ARROW_ASSIGN_OR_RAISE(std::unique_ptr literal, + ToProto(scalar, ext_set, conversion_options)); + auto literal_expr = std::make_unique(); + literal_expr->set_allocated_literal(literal.release()); + return literal_expr; + }; + + switch (datum.kind()) { + case Datum::Kind::SCALAR: { + ARROW_ASSIGN_OR_RAISE(auto literal_expr, ScalarToLiteralExpr(datum.scalar())); + literals.push_back(std::move(literal_expr)); + break; + } + case Datum::Kind::ARRAY: { + std::shared_ptr values = datum.make_array(); + for (int64_t i = 0; i < values->length(); i++) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr scalar, values->GetScalar(i)); + ARROW_ASSIGN_OR_RAISE(auto literal_expr, ScalarToLiteralExpr(scalar)); + literals.push_back(std::move(literal_expr)); + } + break; + } + case Datum::Kind::CHUNKED_ARRAY: { + std::shared_ptr values = datum.chunked_array(); + for (const auto& chunk : values->chunks()) { + for (int64_t i = 0; i < chunk->length(); i++) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr scalar, chunk->GetScalar(i)); + ARROW_ASSIGN_OR_RAISE(auto literal_expr, ScalarToLiteralExpr(scalar)); + literals.push_back(std::move(literal_expr)); + } + } + break; + } + case Datum::Kind::RECORD_BATCH: + case Datum::Kind::TABLE: + case 
Datum::Kind::NONE: + return Status::Invalid("Expected a literal or an array of literals, got ", + datum.ToString()); + } + return literals; +} + Result> ToProto( const compute::Expression& expr, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { @@ -1164,15 +1240,89 @@ Result> ToProto( out->set_allocated_if_then(if_then.release()); return std::move(out); + } else if (call->function_name == "cast") { + auto cast = std::make_unique(); + + // Arrow's cast function does not have a "return null" option and so throw exception + // is the only behavior we can support. + cast->set_failure_behavior( + substrait::Expression::Cast::FAILURE_BEHAVIOR_THROW_EXCEPTION); + + std::shared_ptr cast_options = + internal::checked_pointer_cast(call->options); + if (!cast_options->is_unsafe()) { + return Status::Invalid("Substrait is only capable of representing unsafe casts"); + } + + if (arguments.size() != 1) { + return Status::Invalid( + "A call to the cast function must have exactly one argument"); + } + + cast->set_allocated_input(arguments[0].release()); + + ARROW_ASSIGN_OR_RAISE(std::unique_ptr to_type, + ToProto(*cast_options->to_type.type, /*nullable=*/true, ext_set, + conversion_options)); + + cast->set_allocated_type(to_type.release()); + + out->set_allocated_cast(cast.release()); + return std::move(out); + } else if (call->function_name == "is_in") { + auto or_list = std::make_unique(); + + if (arguments.size() != 1) { + return Status::Invalid( + "A call to the is_in function must have exactly one argument"); + } + + or_list->set_allocated_value(arguments[0].release()); + std::shared_ptr is_in_options = + internal::checked_pointer_cast(call->options); + + // TODO(GH-36420) Acero does not currently handle nulls correctly + ARROW_ASSIGN_OR_RAISE( + std::vector> options, + DatumToLiterals(is_in_options->value_set, ext_set, conversion_options)); + for (auto& option : options) { + or_list->mutable_options()->AddAllocated(option.release()); + } + 
out->set_allocated_singular_or_list(or_list.release()); + return std::move(out); } // other expression types dive into extensions immediately - ARROW_ASSIGN_OR_RAISE( - ExtensionIdRegistry::ArrowToSubstraitCall converter, - ext_set->registry()->GetArrowToSubstraitCall(call->function_name)); - ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call)); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr scalar_fn, - EncodeSubstraitCall(substrait_call, ext_set, conversion_options)); + Result maybe_converter = + ext_set->registry()->GetArrowToSubstraitCall(call->function_name); + + ExtensionIdRegistry::ArrowToSubstraitCall converter; + std::unique_ptr scalar_fn; + if (maybe_converter.ok()) { + converter = *maybe_converter; + ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call)); + ARROW_ASSIGN_OR_RAISE( + scalar_fn, EncodeSubstraitCall(substrait_call, ext_set, conversion_options)); + } else if (maybe_converter.status().IsNotImplemented() && + conversion_options.allow_arrow_extensions) { + if (call->options) { + return Status::NotImplemented( + "The function ", call->function_name, + " has no Substrait mapping. 
Arrow extensions are enabled but the call " + "contains function options and there is no current mechanism to encode those."); + } + Id persistent_id = ext_set->RegisterPlanSpecificId( + {kArrowSimpleExtensionFunctionsUri, call->function_name}); + SubstraitCall substrait_call(persistent_id, call->type.GetSharedPtr(), + /*nullable=*/true); + for (int i = 0; i < static_cast(call->arguments.size()); i++) { + substrait_call.SetValueArg(i, call->arguments[i]); + } + ARROW_ASSIGN_OR_RAISE( + scalar_fn, EncodeSubstraitCall(substrait_call, ext_set, conversion_options)); + } else { + return maybe_converter.status(); + } out->set_allocated_scalar_function(scalar_fn.release()); return std::move(out); } diff --git a/cpp/src/arrow/engine/substrait/extended_expression_internal.cc b/cpp/src/arrow/engine/substrait/extended_expression_internal.cc new file mode 100644 index 00000000000..a6401e1d0b3 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/extended_expression_internal.cc @@ -0,0 +1,210 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. 
+ +#include "arrow/engine/substrait/extended_expression_internal.h" + +#include "arrow/engine/substrait/expression_internal.h" +#include "arrow/engine/substrait/relation_internal.h" +#include "arrow/engine/substrait/type_internal.h" +#include "arrow/engine/substrait/util.h" +#include "arrow/engine/substrait/util_internal.h" +#include "arrow/status.h" +#include "arrow/util/iterator.h" +#include "arrow/util/string.h" + +namespace arrow { +namespace engine { + +namespace { +Result GetExtensionSetFromExtendedExpression( + const substrait::ExtendedExpression& expr, + const ConversionOptions& conversion_options, const ExtensionIdRegistry* registry) { + return GetExtensionSetFromMessage(expr, conversion_options, registry); +} + +Status AddExtensionSetToExtendedExpression(const ExtensionSet& ext_set, + substrait::ExtendedExpression* expr) { + return AddExtensionSetToMessage(ext_set, expr); +} + +Status VisitNestedFields(const DataType& type, + std::function visitor) { + if (!is_nested(type.id())) { + return Status::OK(); + } + for (const auto& field : type.fields()) { + ARROW_RETURN_NOT_OK(VisitNestedFields(*field->type(), visitor)); + ARROW_RETURN_NOT_OK(visitor(*field)); + } + return Status::OK(); +} + +Result ExpressionFromProto( + const substrait::ExpressionReference& expression, const Schema& input_schema, + const ExtensionSet& ext_set, const ConversionOptions& conversion_options, + const ExtensionIdRegistry* registry) { + NamedExpression named_expr; + switch (expression.expr_type_case()) { + case substrait::ExpressionReference::ExprTypeCase::kExpression: { + ARROW_ASSIGN_OR_RAISE( + named_expr.expression, + FromProto(expression.expression(), ext_set, conversion_options)); + break; + } + case substrait::ExpressionReference::ExprTypeCase::kMeasure: { + return Status::NotImplemented("ExtendedExpression containing aggregate functions"); + } + default: { + return Status::Invalid( + "Unrecognized substrait::ExpressionReference::ExprTypeCase: ", + 
expression.expr_type_case()); + } + } + + ARROW_ASSIGN_OR_RAISE(named_expr.expression, named_expr.expression.Bind(input_schema)); + const DataType& output_type = *named_expr.expression.type(); + + // An expression reference has the entire DFS tree of field names for the output type + // which is usually redundant. Then it has one extra name for the name of the + // expression which is not redundant. + // + // For example, if the base schema is [struct, i32] and the expression is + // field(0) then the extended expression output names might be ["foo", "my_expression"]. + // The "foo" is redundant but we can verify it matches and reject if it does not. + // + // The one exception is struct literals which have no field names. For example, if + // the base schema is [i32, i64] and the expression is {7, 3}_struct then the + // output type is struct and we do not know the names of the output type. + // + // TODO(weston) we could patch the names back in at this point using the output + // names field but this is rather complex and it might be easier to give names to + // struct literals in Substrait. + int output_name_idx = 0; + ARROW_RETURN_NOT_OK(VisitNestedFields(output_type, [&](const Field& field) { + if (output_name_idx >= expression.output_names_size()) { + return Status::Invalid("Ambiguous plan. Expression had ", + expression.output_names_size(), + " output names but the field in base_schema had type ", + output_type.ToString(), " which needs more output names"); + } + if (!field.name().empty() && + field.name() != expression.output_names(output_name_idx)) { + return Status::Invalid("Ambiguous plan. 
Expression had output type ", + output_type.ToString(), + " which contains a nested field named ", field.name(), + " but the output_names in the Substrait message contained ", + expression.output_names(output_name_idx)); + } + output_name_idx++; + return Status::OK(); + })); + // The last name is the actual field name that we can't verify but there should only + // be one extra name. + if (output_name_idx < expression.output_names_size() - 1) { + return Status::Invalid("Ambiguous plan. Expression had ", + expression.output_names_size(), + " output names but the field in base_schema had type ", + output_type.ToString(), " which doesn't have enough fields"); + } + if (expression.output_names_size() == 0) { + // This is potentially invalid substrait but we can handle it + named_expr.name = ""; + } else { + named_expr.name = expression.output_names(expression.output_names_size() - 1); + } + return named_expr; +} + +Result> CreateExpressionReference( + const std::string& name, const Expression& expr, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + auto expr_ref = std::make_unique(); + ARROW_RETURN_NOT_OK(VisitNestedFields(*expr.type(), [&](const Field& field) { + expr_ref->add_output_names(field.name()); + return Status::OK(); + })); + expr_ref->add_output_names(name); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr expression, + ToProto(expr, ext_set, conversion_options)); + expr_ref->set_allocated_expression(expression.release()); + return std::move(expr_ref); +} + +} // namespace + +Result FromProto(const substrait::ExtendedExpression& expression, + ExtensionSet* ext_set_out, + const ConversionOptions& conversion_options, + const ExtensionIdRegistry* registry) { + BoundExpressions bound_expressions; + ARROW_RETURN_NOT_OK(CheckVersion(expression.version().major_number(), + expression.version().minor_number())); + if (expression.has_advanced_extensions()) { + return Status::NotImplemented("Advanced extensions in ExtendedExpression"); + } + 
ARROW_ASSIGN_OR_RAISE( + ExtensionSet ext_set, + GetExtensionSetFromExtendedExpression(expression, conversion_options, registry)); + + ARROW_ASSIGN_OR_RAISE(bound_expressions.schema, + FromProto(expression.base_schema(), ext_set, conversion_options)); + + bound_expressions.named_expressions.reserve(expression.referred_expr_size()); + + for (const auto& referred_expr : expression.referred_expr()) { + ARROW_ASSIGN_OR_RAISE(NamedExpression named_expr, + ExpressionFromProto(referred_expr, *bound_expressions.schema, + ext_set, conversion_options, registry)); + bound_expressions.named_expressions.push_back(std::move(named_expr)); + } + + if (ext_set_out) { + *ext_set_out = std::move(ext_set); + } + + return std::move(bound_expressions); +} + +Result> ToProto( + const BoundExpressions& bound_expressions, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + auto expression = std::make_unique(); + expression->set_allocated_version(CreateVersion().release()); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr base_schema, + ToProto(*bound_expressions.schema, ext_set, conversion_options)); + expression->set_allocated_base_schema(base_schema.release()); + for (const auto& named_expression : bound_expressions.named_expressions) { + Expression bound_expr = named_expression.expression; + if (!bound_expr.IsBound()) { + // This will use the default function registry. Most of the time that will be fine. + // In the cases where this is not what the user wants then the user should make sure + // to pass in bound expressions. 
+ ARROW_ASSIGN_OR_RAISE(bound_expr, bound_expr.Bind(*bound_expressions.schema)); + } + ARROW_ASSIGN_OR_RAISE(std::unique_ptr expr_ref, + CreateExpressionReference(named_expression.name, bound_expr, + ext_set, conversion_options)); + expression->mutable_referred_expr()->AddAllocated(expr_ref.release()); + } + RETURN_NOT_OK(AddExtensionSetToExtendedExpression(*ext_set, expression.get())); + return std::move(expression); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extended_expression_internal.h b/cpp/src/arrow/engine/substrait/extended_expression_internal.h new file mode 100644 index 00000000000..81bc4b87451 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/extended_expression_internal.h @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. 
+ +#pragma once + +#include + +#include "arrow/compute/type_fwd.h" +#include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/substrait/options.h" +#include "arrow/engine/substrait/relation.h" +#include "arrow/engine/substrait/visibility.h" +#include "arrow/result.h" +#include "arrow/status.h" + +#include "substrait/extended_expression.pb.h" // IWYU pragma: export + +namespace arrow { +namespace engine { + +/// Convert a Substrait ExtendedExpression to a vector of expressions and output names +ARROW_ENGINE_EXPORT +Result FromProto(const substrait::ExtendedExpression& expression, + ExtensionSet* ext_set_out, + const ConversionOptions& conversion_options, + const ExtensionIdRegistry* extension_id_registry); + +/// Convert a vector of expressions to a Substrait ExtendedExpression +ARROW_ENGINE_EXPORT +Result> ToProto( + const BoundExpressions& bound_expressions, ExtensionSet* ext_set, + const ConversionOptions& conversion_options); + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index d89248383b7..b0dd6aeffbc 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -28,12 +28,15 @@ #include "arrow/engine/substrait/options.h" #include "arrow/type.h" #include "arrow/type_fwd.h" +#include "arrow/util/checked_cast.h" #include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" #include "arrow/util/logging.h" #include "arrow/util/string.h" namespace arrow { + +using internal::checked_pointer_cast; namespace engine { namespace { @@ -229,6 +232,8 @@ Status ExtensionSet::AddUri(Id id) { return Status::OK(); } +Id ExtensionSet::RegisterPlanSpecificId(Id id) { return plan_specific_ids_->Emplace(id); } + // Creates an extension set from the Substrait plan's top-level extensions block Result ExtensionSet::Make( std::unordered_map uris, @@ -873,11 +878,10 @@ 
ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessOverflowableArithmetic }; } -ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessComparison(Id substrait_fn_id) { +ExtensionIdRegistry::ArrowToSubstraitCall EncodeBasic(Id substrait_fn_id) { return [substrait_fn_id](const compute::Expression::Call& call) -> Result { - // nullable=true isn't quite correct but we don't know the nullability of - // the inputs + // nullable=true errs on the side of caution SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(), /*nullable=*/true); for (std::size_t i = 0; i < call.arguments.size(); i++) { @@ -887,11 +891,31 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessComparison(Id substrai }; } +ExtensionIdRegistry::ArrowToSubstraitCall EncodeIsNull(Id substrait_fn_id) { + return + [substrait_fn_id](const compute::Expression::Call& call) -> Result { + if (call.options != nullptr) { + auto null_opts = checked_pointer_cast(call.options); + if (null_opts->nan_is_null) { + return Status::Invalid( + "Substrait does not support is_null with nan_is_null=true. 
You can use " + "is_null || is_nan instead"); + } + } + SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(), + /*nullable=*/false); + for (std::size_t i = 0; i < call.arguments.size(); i++) { + substrait_call.SetValueArg(static_cast(i), call.arguments[i]); + } + return std::move(substrait_call); + }; +} + ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessBasicMapping( const std::string& function_name, int max_args) { return [function_name, max_args](const SubstraitCall& call) -> Result { - if (call.size() > max_args) { + if (max_args >= 0 && call.size() > max_args) { return Status::NotImplemented("Acero does not have a kernel for ", function_name, " that receives ", call.size(), " arguments"); } @@ -1033,14 +1057,15 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { // -------------- Substrait -> Arrow Functions ----------------- // Mappings with a _checked variant for (const auto& function_name : - {"add", "subtract", "multiply", "divide", "power", "sqrt", "abs"}) { + {"add", "subtract", "multiply", "divide", "negate", "power", "sqrt", "abs"}) { DCHECK_OK( AddSubstraitCallToArrow({kSubstraitArithmeticFunctionsUri, function_name}, DecodeOptionlessOverflowableArithmetic(function_name))); } - // Mappings without a _checked variant - for (const auto& function_name : {"exp", "sign"}) { + // Mappings either without a _checked variant or substrait has no overflow option + for (const auto& function_name : + {"exp", "sign", "cos", "sin", "tan", "acos", "asin", "atan", "atan2"}) { DCHECK_OK( AddSubstraitCallToArrow({kSubstraitArithmeticFunctionsUri, function_name}, DecodeOptionlessUncheckedArithmetic(function_name))); @@ -1096,6 +1121,21 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { DCHECK_OK( AddSubstraitCallToArrow({kSubstraitBooleanFunctionsUri, "not"}, DecodeOptionlessBasicMapping("invert", /*max_args=*/1))); + DCHECK_OK(AddSubstraitCallToArrow( + {kSubstraitArithmeticFunctionsUri, "bitwise_not"}, + 
DecodeOptionlessBasicMapping("bit_wise_not", /*max_args=*/1))); + DCHECK_OK(AddSubstraitCallToArrow( + {kSubstraitArithmeticFunctionsUri, "bitwise_or"}, + DecodeOptionlessBasicMapping("bit_wise_or", /*max_args=*/2))); + DCHECK_OK(AddSubstraitCallToArrow( + {kSubstraitArithmeticFunctionsUri, "bitwise_and"}, + DecodeOptionlessBasicMapping("bit_wise_and", /*max_args=*/2))); + DCHECK_OK(AddSubstraitCallToArrow( + {kSubstraitArithmeticFunctionsUri, "bitwise_xor"}, + DecodeOptionlessBasicMapping("bit_wise_xor", /*max_args=*/2))); + DCHECK_OK(AddSubstraitCallToArrow( + {kSubstraitComparisonFunctionsUri, "coalesce"}, + DecodeOptionlessBasicMapping("coalesce", /*max_args=*/-1))); DCHECK_OK(AddSubstraitCallToArrow({kSubstraitDatetimeFunctionsUri, "extract"}, DecodeTemporalExtractionMapping())); DCHECK_OK(AddSubstraitCallToArrow({kSubstraitStringFunctionsUri, "concat"}, @@ -1103,6 +1143,12 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { DCHECK_OK( AddSubstraitCallToArrow({kSubstraitComparisonFunctionsUri, "is_null"}, DecodeOptionlessBasicMapping("is_null", /*max_args=*/1))); + DCHECK_OK( + AddSubstraitCallToArrow({kSubstraitComparisonFunctionsUri, "is_nan"}, + DecodeOptionlessBasicMapping("is_nan", /*max_args=*/1))); + DCHECK_OK(AddSubstraitCallToArrow( + {kSubstraitComparisonFunctionsUri, "is_finite"}, + DecodeOptionlessBasicMapping("is_finite", /*max_args=*/1))); DCHECK_OK(AddSubstraitCallToArrow( {kSubstraitComparisonFunctionsUri, "is_not_null"}, DecodeOptionlessBasicMapping("is_valid", /*max_args=*/1))); @@ -1127,7 +1173,9 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { } // --------------- Arrow -> Substrait Functions --------------- - for (const auto& fn_name : {"add", "subtract", "multiply", "divide"}) { + // Functions with a _checked variant + for (const auto& fn_name : + {"add", "subtract", "multiply", "divide", "negate", "power", "abs"}) { Id fn_id{kSubstraitArithmeticFunctionsUri, fn_name}; DCHECK_OK(AddArrowToSubstraitCall( 
fn_name, EncodeOptionlessOverflowableArithmetic(fn_id))); @@ -1135,11 +1183,49 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { AddArrowToSubstraitCall(std::string(fn_name) + "_checked", EncodeOptionlessOverflowableArithmetic(fn_id))); } - // Comparison operators - for (const auto& fn_name : {"equal", "is_not_distinct_from"}) { - Id fn_id{kSubstraitComparisonFunctionsUri, fn_name}; - DCHECK_OK(AddArrowToSubstraitCall(fn_name, EncodeOptionlessComparison(fn_id))); - } + // Functions with no options... + // ...and the same name + for (const auto& fn_pair : std::vector>{ + {kSubstraitComparisonFunctionsUri, "equal"}, + {kSubstraitComparisonFunctionsUri, "not_equal"}, + {kSubstraitComparisonFunctionsUri, "is_not_distinct_from"}, + {kSubstraitComparisonFunctionsUri, "is_nan"}, + {kSubstraitComparisonFunctionsUri, "is_finite"}, + {kSubstraitComparisonFunctionsUri, "coalesce"}, + {kSubstraitArithmeticFunctionsUri, "sqrt"}, + {kSubstraitArithmeticFunctionsUri, "sign"}, + {kSubstraitArithmeticFunctionsUri, "exp"}, + {kSubstraitArithmeticFunctionsUri, "cos"}, + {kSubstraitArithmeticFunctionsUri, "sin"}, + {kSubstraitArithmeticFunctionsUri, "tan"}, + {kSubstraitArithmeticFunctionsUri, "acos"}, + {kSubstraitArithmeticFunctionsUri, "asin"}, + {kSubstraitArithmeticFunctionsUri, "atan"}, + {kSubstraitArithmeticFunctionsUri, "atan2"}}) { + Id fn_id{fn_pair.first, fn_pair.second}; + DCHECK_OK(AddArrowToSubstraitCall(std::string(fn_pair.second), EncodeBasic(fn_id))); + } + // ...and different names + for (const auto& fn_triple : + std::vector>{ + {kSubstraitComparisonFunctionsUri, "lt", "less"}, + {kSubstraitComparisonFunctionsUri, "gt", "greater"}, + {kSubstraitComparisonFunctionsUri, "lte", "less_equal"}, + {kSubstraitComparisonFunctionsUri, "gte", "greater_equal"}, + {kSubstraitComparisonFunctionsUri, "is_not_null", "is_valid"}, + {kSubstraitArithmeticFunctionsUri, "bitwise_and", "bit_wise_and"}, + {kSubstraitArithmeticFunctionsUri, "bitwise_not", "bit_wise_not"}, 
+ {kSubstraitArithmeticFunctionsUri, "bitwise_or", "bit_wise_or"}, + {kSubstraitArithmeticFunctionsUri, "bitwise_xor", "bit_wise_xor"}, + {kSubstraitBooleanFunctionsUri, "and", "and_kleene"}, + {kSubstraitBooleanFunctionsUri, "or", "or_kleene"}, + {kSubstraitBooleanFunctionsUri, "not", "invert"}}) { + Id fn_id{std::get<0>(fn_triple), std::get<1>(fn_triple)}; + DCHECK_OK(AddArrowToSubstraitCall(std::get<2>(fn_triple), EncodeBasic(fn_id))); + } + + DCHECK_OK(AddArrowToSubstraitCall( + "is_null", EncodeIsNull({kSubstraitComparisonFunctionsUri, "is_null"}))); } }; diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index 50e0b11943f..d9c0af081a5 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -131,6 +131,9 @@ class ARROW_ENGINE_EXPORT SubstraitCall { const std::shared_ptr& output_type() const { return output_type_; } bool output_nullable() const { return output_nullable_; } bool is_hash() const { return is_hash_; } + const std::unordered_map>& options() const { + return options_; + } bool HasEnumArg(int index) const; Result GetEnumArg(int index) const; @@ -427,6 +430,15 @@ class ARROW_ENGINE_EXPORT ExtensionSet { /// \return An anchor that can be used to refer to the function within a plan Result EncodeFunction(Id function_id); + /// \brief Stores a plan-specific id that is not known to the registry + /// + /// This is used when converting an Arrow execution plan to a Substrait plan. 
+ /// + /// If the function is a UDF, something that wasn't known to the registry, + /// then we need long term storage of the function name (the ids are just + /// views) + Id RegisterPlanSpecificId(Id id); + /// \brief Return the number of custom functions in this extension set std::size_t num_functions() const { return functions_.size(); } diff --git a/cpp/src/arrow/engine/substrait/options.h b/cpp/src/arrow/engine/substrait/options.h index 0d66c5eea43..1e6f6efb2c7 100644 --- a/cpp/src/arrow/engine/substrait/options.h +++ b/cpp/src/arrow/engine/substrait/options.h @@ -106,7 +106,8 @@ struct ARROW_ENGINE_EXPORT ConversionOptions { : strictness(ConversionStrictness::BEST_EFFORT), named_table_provider(kDefaultNamedTableProvider), named_tap_provider(default_named_tap_provider()), - extension_provider(default_extension_provider()) {} + extension_provider(default_extension_provider()), + allow_arrow_extensions(false) {} /// \brief How strictly the converter should adhere to the structure of the input. ConversionStrictness strictness; @@ -123,6 +124,11 @@ struct ARROW_ENGINE_EXPORT ConversionOptions { /// /// The default behavior will provide for relations known to Arrow. 
std::shared_ptr extension_provider; + /// \brief If true then Arrow-specific types and functions will be allowed + /// + /// Set to false to create plans that are more likely to be compatible with non-Arrow + /// engines + bool allow_arrow_extensions; }; } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index ecee81e25ff..cc4806878c4 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -25,9 +25,10 @@ #include #include "arrow/compute/type_fwd.h" -#include "arrow/config.h" #include "arrow/engine/substrait/relation_internal.h" #include "arrow/engine/substrait/type_fwd.h" +#include "arrow/engine/substrait/util.h" +#include "arrow/engine/substrait/util_internal.h" #include "arrow/result.h" #include "arrow/util/checked_cast.h" #include "arrow/util/hashing.h" @@ -43,122 +44,15 @@ using internal::checked_cast; namespace engine { Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) { - plan->clear_extension_uris(); - - std::unordered_map map; - - auto uris = plan->mutable_extension_uris(); - uris->Reserve(static_cast(ext_set.uris().size())); - for (uint32_t anchor = 0; anchor < ext_set.uris().size(); ++anchor) { - auto uri = ext_set.uris().at(anchor); - if (uri.empty()) continue; - - auto ext_uri = std::make_unique(); - ext_uri->set_uri(std::string(uri)); - ext_uri->set_extension_uri_anchor(anchor); - uris->AddAllocated(ext_uri.release()); - - map[uri] = anchor; - } - - auto extensions = plan->mutable_extensions(); - extensions->Reserve(static_cast(ext_set.num_types() + ext_set.num_functions())); - - using ExtDecl = substrait::extensions::SimpleExtensionDeclaration; - - for (uint32_t anchor = 0; anchor < ext_set.num_types(); ++anchor) { - ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(anchor)); - if (type_record.id.empty()) continue; - - auto ext_decl = std::make_unique(); - - auto type = 
std::make_unique(); - type->set_extension_uri_reference(map[type_record.id.uri]); - type->set_type_anchor(anchor); - type->set_name(std::string(type_record.id.name)); - ext_decl->set_allocated_extension_type(type.release()); - extensions->AddAllocated(ext_decl.release()); - } - - for (uint32_t anchor = 0; anchor < ext_set.num_functions(); ++anchor) { - ARROW_ASSIGN_OR_RAISE(Id function_id, ext_set.DecodeFunction(anchor)); - - auto fn = std::make_unique(); - fn->set_extension_uri_reference(map[function_id.uri]); - fn->set_function_anchor(anchor); - fn->set_name(std::string(function_id.name)); - - auto ext_decl = std::make_unique(); - ext_decl->set_allocated_extension_function(fn.release()); - extensions->AddAllocated(ext_decl.release()); - } - - return Status::OK(); + return AddExtensionSetToMessage(ext_set, plan); } Result GetExtensionSetFromPlan(const substrait::Plan& plan, const ConversionOptions& conversion_options, const ExtensionIdRegistry* registry) { - if (registry == NULLPTR) { - registry = default_extension_id_registry(); - } - std::unordered_map uris; - uris.reserve(plan.extension_uris_size()); - for (const auto& uri : plan.extension_uris()) { - uris[uri.extension_uri_anchor()] = uri.uri(); - } - - // NOTE: it's acceptable to use views to memory owned by plan; ExtensionSet::Make - // will only store views to memory owned by registry. 
- - std::unordered_map type_ids, function_ids; - for (const auto& ext : plan.extensions()) { - switch (ext.mapping_type_case()) { - case substrait::extensions::SimpleExtensionDeclaration::kExtensionTypeVariation: { - return Status::NotImplemented("Type Variations are not yet implemented"); - } - - case substrait::extensions::SimpleExtensionDeclaration::kExtensionType: { - const auto& type = ext.extension_type(); - std::string_view uri = uris[type.extension_uri_reference()]; - type_ids[type.type_anchor()] = Id{uri, type.name()}; - break; - } - - case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { - const auto& fn = ext.extension_function(); - std::string_view uri = uris[fn.extension_uri_reference()]; - function_ids[fn.function_anchor()] = Id{uri, fn.name()}; - break; - } - - default: - Unreachable(); - } - } - - return ExtensionSet::Make(std::move(uris), std::move(type_ids), std::move(function_ids), - conversion_options, registry); -} - -namespace { - -// TODO(ARROW-18145) Populate these from cmake files -constexpr uint32_t kSubstraitMajorVersion = 0; -constexpr uint32_t kSubstraitMinorVersion = 20; -constexpr uint32_t kSubstraitPatchVersion = 0; - -std::unique_ptr CreateVersion() { - auto version = std::make_unique(); - version->set_major_number(kSubstraitMajorVersion); - version->set_minor_number(kSubstraitMinorVersion); - version->set_patch_number(kSubstraitPatchVersion); - version->set_producer("Acero " + GetBuildInfo().version_string); - return version; + return GetExtensionSetFromMessage(plan, conversion_options, registry); } -} // namespace - Result> PlanToProto( const acero::Declaration& declr, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { diff --git a/cpp/src/arrow/engine/substrait/relation.h b/cpp/src/arrow/engine/substrait/relation.h index 0be4e03bb38..d0913b9ae02 100644 --- a/cpp/src/arrow/engine/substrait/relation.h +++ b/cpp/src/arrow/engine/substrait/relation.h @@ -20,6 +20,7 @@ #include #include 
"arrow/acero/exec_plan.h" +#include "arrow/compute/api_aggregate.h" #include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" @@ -50,5 +51,21 @@ struct ARROW_ENGINE_EXPORT PlanInfo { std::vector names; }; +/// An expression whose output has a name +struct ARROW_ENGINE_EXPORT NamedExpression { + /// An expression + compute::Expression expression; + // An optional name to assign to the output, may be the empty string + std::string name; +}; + +/// A collection of expressions bound to a common schema +struct ARROW_ENGINE_EXPORT BoundExpressions { + /// The expressions + std::vector named_expressions; + /// The schema that all the expressions are bound to + std::shared_ptr schema; +}; + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index b5effd78524..9e670f12177 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -35,12 +35,14 @@ #include "arrow/compute/expression.h" #include "arrow/dataset/file_base.h" #include "arrow/engine/substrait/expression_internal.h" +#include "arrow/engine/substrait/extended_expression_internal.h" #include "arrow/engine/substrait/extension_set.h" #include "arrow/engine/substrait/plan_internal.h" #include "arrow/engine/substrait/relation.h" #include "arrow/engine/substrait/relation_internal.h" #include "arrow/engine/substrait/type_fwd.h" #include "arrow/engine/substrait/type_internal.h" +#include "arrow/engine/substrait/util.h" #include "arrow/type.h" namespace arrow { @@ -74,6 +76,20 @@ Result> SerializePlan( return Buffer::FromString(std::move(serialized)); } +Result> SerializeExpressions( + const BoundExpressions& bound_expressions, + const ConversionOptions& conversion_options, ExtensionSet* ext_set) { + ExtensionSet throwaway_ext_set; + if (ext_set == nullptr) { + ext_set = &throwaway_ext_set; + } + ARROW_ASSIGN_OR_RAISE( + std::unique_ptr extended_expression, + 
ToProto(bound_expressions, ext_set, conversion_options)); + std::string serialized = extended_expression->SerializeAsString(); + return Buffer::FromString(std::move(serialized)); +} + Result> SerializeRelation( const acero::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { @@ -125,20 +141,14 @@ DeclarationFactory MakeWriteDeclarationFactory( }; } -constexpr uint32_t kMinimumMajorVersion = 0; -constexpr uint32_t kMinimumMinorVersion = 20; - Result> DeserializePlans( const Buffer& buf, DeclarationFactory declaration_factory, const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out, const ConversionOptions& conversion_options) { ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer(buf)); - if (plan.version().major_number() < kMinimumMajorVersion && - plan.version().minor_number() < kMinimumMinorVersion) { - return Status::Invalid("Can only parse plans with a version >= ", - kMinimumMajorVersion, ".", kMinimumMinorVersion); - } + ARROW_RETURN_NOT_OK( + CheckVersion(plan.version().major_number(), plan.version().minor_number())); ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan, conversion_options, registry)); @@ -196,12 +206,8 @@ ARROW_ENGINE_EXPORT Result DeserializePlan( const Buffer& buf, const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out, const ConversionOptions& conversion_options) { ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer(buf)); - - if (plan.version().major_number() < kMinimumMajorVersion && - plan.version().minor_number() < kMinimumMinorVersion) { - return Status::Invalid("Can only parse plans with a version >= ", - kMinimumMajorVersion, ".", kMinimumMinorVersion); - } + ARROW_RETURN_NOT_OK( + CheckVersion(plan.version().major_number(), plan.version().minor_number())); ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan, conversion_options, registry)); @@ -233,6 +239,14 @@ ARROW_ENGINE_EXPORT Result DeserializePlan( return PlanInfo{std::move(decl_info), 
std::move(names)}; } +Result DeserializeExpressions( + const Buffer& buf, const ExtensionIdRegistry* registry, + const ConversionOptions& conversion_options, ExtensionSet* ext_set_out) { + ARROW_ASSIGN_OR_RAISE(auto extended_expression, + ParseFromBuffer(buf)); + return FromProto(extended_expression, ext_set_out, conversion_options, registry); +} + namespace { Result> MakeSingleDeclarationPlan( diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index ebbefb176e2..ab749f4a64b 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -52,6 +52,19 @@ Result> SerializePlan( const acero::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options = {}); +/// \brief Serialize expressions to a Substrait message +/// +/// \param[in] bound_expressions the expressions to serialize. +/// \param[in] conversion_options options to control how the conversion is done +/// \param[in,out] ext_set the extension mapping to use, optional, only needed +/// if you want to control the value of function anchors +/// to mirror a previous serialization / deserialization. +/// Will be updated if new functions are encountered +ARROW_ENGINE_EXPORT +Result> SerializeExpressions( + const BoundExpressions& bound_expressions, + const ConversionOptions& conversion_options = {}, ExtensionSet* ext_set = NULLPTR); + /// Factory function type for generating the node that consumes the batches produced by /// each toplevel Substrait relation when deserializing a Substrait Plan. 
using ConsumerFactory = std::function()>; @@ -155,6 +168,21 @@ ARROW_ENGINE_EXPORT Result DeserializePlan( ExtensionSet* ext_set_out = NULLPTR, const ConversionOptions& conversion_options = {}); +/// \brief Deserialize a Substrait ExtendedExpression message to the corresponding Arrow +/// type +/// +/// \param[in] buf a buffer containing the protobuf serialization of a collection of bound +/// expressions +/// \param[in] registry an extension-id-registry to use, or null for the default one +/// \param[in] conversion_options options to control how the conversion is done +/// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait +/// message is returned here. +/// \return A collection of expressions and a common input schema they are bound to +ARROW_ENGINE_EXPORT Result DeserializeExpressions( + const Buffer& buf, const ExtensionIdRegistry* registry = NULLPTR, + const ConversionOptions& conversion_options = {}, + ExtensionSet* ext_set_out = NULLPTR); + /// \brief Deserializes a Substrait Type message to the corresponding Arrow type /// /// \param[in] buf a buffer containing the protobuf serialization of a Substrait Type diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index efe1f702b48..2e72ae70edd 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -6070,5 +6070,100 @@ TEST(Substrait, PlanWithSegmentedAggregateExtension) { CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options); } +void CheckExpressionRoundTrip(const Schema& schema, + const compute::Expression& expression) { + ASSERT_OK_AND_ASSIGN(compute::Expression bound_expression, expression.Bind(schema)); + BoundExpressions bound_expressions; + bound_expressions.schema = std::make_shared(schema); + bound_expressions.named_expressions = {{std::move(bound_expression), "some_name"}}; + + ASSERT_OK_AND_ASSIGN(std::shared_ptr buf, + 
SerializeExpressions(bound_expressions)); + + ASSERT_OK_AND_ASSIGN(BoundExpressions round_tripped, DeserializeExpressions(*buf)); + + AssertSchemaEqual(schema, *round_tripped.schema); + ASSERT_EQ(1, round_tripped.named_expressions.size()); + ASSERT_EQ("some_name", round_tripped.named_expressions[0].name); + ASSERT_EQ(bound_expressions.named_expressions[0].expression, + round_tripped.named_expressions[0].expression); +} + +TEST(Substrait, ExtendedExpressionSerialization) { + std::shared_ptr test_schema = + schema({field("a", int32()), field("b", int32()), field("c", float32()), + field("nested", struct_({field("x", float32()), field("y", float32())}))}); + // Basic a + b + CheckExpressionRoundTrip( + *test_schema, compute::call("add", {compute::field_ref(0), compute::field_ref(1)})); + // Nested struct reference + CheckExpressionRoundTrip(*test_schema, compute::field_ref(FieldPath{3, 0})); + // Struct return type + CheckExpressionRoundTrip(*test_schema, compute::field_ref(3)); + // c + nested.y + CheckExpressionRoundTrip( + *test_schema, + compute::call("add", {compute::field_ref(2), compute::field_ref(FieldPath{3, 1})})); +} + +TEST(Substrait, ExtendedExpressionInvalidPlans) { + // The schema defines the type as {"x", "y"} but output_names has {"a", "y"} + constexpr std::string_view kBadOuptutNames = R"( + { + "referredExpr":[ + { + "expression":{ + "selection":{ + "directReference":{ + "structField":{ + "field":3 + } + }, + "rootReference":{} + } + }, + "outputNames":["a", "y", "some_name"] + } + ], + "baseSchema":{ + "names":["a","b","c","nested","x","y"], + "struct":{ + "types":[ + { + "i32":{"nullability":"NULLABILITY_NULLABLE"} + }, + { + "i32":{"nullability":"NULLABILITY_NULLABLE"} + }, + { + "fp32":{"nullability":"NULLABILITY_NULLABLE"} + }, + { + "struct":{ + "types":[ + { + "fp32":{"nullability":"NULLABILITY_NULLABLE"} + }, + { + "fp32":{"nullability":"NULLABILITY_NULLABLE"} + } + ], + "nullability":"NULLABILITY_NULLABLE" + } + } + ] + } + }, + 
"version":{"majorNumber":9999} + } + )"; + + ASSERT_OK_AND_ASSIGN( + auto buf, internal::SubstraitFromJSON("ExtendedExpression", kBadOuptutNames)); + + ASSERT_THAT(DeserializeExpressions(*buf), + Raises(StatusCode::Invalid, testing::HasSubstr("Ambiguous plan"))); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index b74e333fd97..d842d0ef9d7 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -70,6 +70,16 @@ const std::string& default_extension_types_uri() { return uri; } +Status CheckVersion(uint32_t major_version, uint32_t minor_version) { + if (major_version < kSubstraitMinimumMajorVersion && + minor_version < kSubstraitMinimumMinorVersion) { + return Status::Invalid("Can only parse Substrait messages with a version >= ", + kSubstraitMinimumMajorVersion, ".", + kSubstraitMinimumMinorVersion); + } + return Status::OK(); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 9f8bd804889..5128ec44bff 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -68,6 +68,16 @@ ARROW_ENGINE_EXPORT std::shared_ptr MakeExtensionIdRegistry ARROW_ENGINE_EXPORT const std::string& default_extension_types_uri(); +// TODO(ARROW-18145) Populate these from cmake files +constexpr uint32_t kSubstraitMajorVersion = 0; +constexpr uint32_t kSubstraitMinorVersion = 27; +constexpr uint32_t kSubstraitPatchVersion = 0; + +constexpr uint32_t kSubstraitMinimumMajorVersion = 0; +constexpr uint32_t kSubstraitMinimumMinorVersion = 20; + +Status CheckVersion(uint32_t major_version, uint32_t minor_version); + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util_internal.cc b/cpp/src/arrow/engine/substrait/util_internal.cc index 4e6cacf4f67..89034784ab5 100644 --- 
a/cpp/src/arrow/engine/substrait/util_internal.cc +++ b/cpp/src/arrow/engine/substrait/util_internal.cc @@ -17,6 +17,9 @@ #include "arrow/engine/substrait/util_internal.h" +#include "arrow/config.h" +#include "arrow/engine/substrait/util.h" + namespace arrow { namespace engine { @@ -30,6 +33,15 @@ std::string EnumToString(int value, const google::protobuf::EnumDescriptor& desc return value_desc->name(); } +std::unique_ptr CreateVersion() { + auto version = std::make_unique(); + version->set_major_number(kSubstraitMajorVersion); + version->set_minor_number(kSubstraitMinorVersion); + version->set_patch_number(kSubstraitPatchVersion); + version->set_producer("Acero " + GetBuildInfo().version_string); + return version; +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util_internal.h b/cpp/src/arrow/engine/substrait/util_internal.h index efc3145543d..627ad1126df 100644 --- a/cpp/src/arrow/engine/substrait/util_internal.h +++ b/cpp/src/arrow/engine/substrait/util_internal.h @@ -17,8 +17,17 @@ #pragma once +#include + +#include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/substrait/options.h" #include "arrow/engine/substrait/visibility.h" +#include "arrow/result.h" +#include "arrow/util/hashing.h" +#include "arrow/util/unreachable.h" + #include "substrait/algebra.pb.h" // IWYU pragma: export +#include "substrait/plan.pb.h" // IWYU pragma: export namespace arrow { namespace engine { @@ -26,5 +35,110 @@ namespace engine { ARROW_ENGINE_EXPORT std::string EnumToString( int value, const google::protobuf::EnumDescriptor& descriptor); +// Extension sets can be present in both substrait::Plan and substrait::ExtendedExpression +// and so this utility is templated to support both. 
+template +Result GetExtensionSetFromMessage( + const MessageType& message, const ConversionOptions& conversion_options, + const ExtensionIdRegistry* registry) { + if (registry == NULLPTR) { + registry = default_extension_id_registry(); + } + std::unordered_map uris; + uris.reserve(message.extension_uris_size()); + for (const auto& uri : message.extension_uris()) { + uris[uri.extension_uri_anchor()] = uri.uri(); + } + + // NOTE: it's acceptable to use views to memory owned by message; ExtensionSet::Make + // will only store views to memory owned by registry. + + std::unordered_map type_ids, function_ids; + for (const auto& ext : message.extensions()) { + switch (ext.mapping_type_case()) { + case substrait::extensions::SimpleExtensionDeclaration::kExtensionTypeVariation: { + return Status::NotImplemented("Type Variations are not yet implemented"); + } + + case substrait::extensions::SimpleExtensionDeclaration::kExtensionType: { + const auto& type = ext.extension_type(); + std::string_view uri = uris[type.extension_uri_reference()]; + type_ids[type.type_anchor()] = Id{uri, type.name()}; + break; + } + + case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { + const auto& fn = ext.extension_function(); + std::string_view uri = uris[fn.extension_uri_reference()]; + function_ids[fn.function_anchor()] = Id{uri, fn.name()}; + break; + } + + default: + Unreachable(); + } + } + + return ExtensionSet::Make(std::move(uris), std::move(type_ids), std::move(function_ids), + conversion_options, registry); +} + +template +Status AddExtensionSetToMessage(const ExtensionSet& ext_set, Message* message) { + message->clear_extension_uris(); + + std::unordered_map map; + + auto uris = message->mutable_extension_uris(); + uris->Reserve(static_cast(ext_set.uris().size())); + for (uint32_t anchor = 0; anchor < ext_set.uris().size(); ++anchor) { + auto uri = ext_set.uris().at(anchor); + if (uri.empty()) continue; + + auto ext_uri = std::make_unique(); + 
ext_uri->set_uri(std::string(uri)); + ext_uri->set_extension_uri_anchor(anchor); + uris->AddAllocated(ext_uri.release()); + + map[uri] = anchor; + } + + auto extensions = message->mutable_extensions(); + extensions->Reserve(static_cast(ext_set.num_types() + ext_set.num_functions())); + + using ExtDecl = substrait::extensions::SimpleExtensionDeclaration; + + for (uint32_t anchor = 0; anchor < ext_set.num_types(); ++anchor) { + ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(anchor)); + if (type_record.id.empty()) continue; + + auto ext_decl = std::make_unique(); + + auto type = std::make_unique(); + type->set_extension_uri_reference(map[type_record.id.uri]); + type->set_type_anchor(anchor); + type->set_name(std::string(type_record.id.name)); + ext_decl->set_allocated_extension_type(type.release()); + extensions->AddAllocated(ext_decl.release()); + } + + for (uint32_t anchor = 0; anchor < ext_set.num_functions(); ++anchor) { + ARROW_ASSIGN_OR_RAISE(Id function_id, ext_set.DecodeFunction(anchor)); + + auto fn = std::make_unique(); + fn->set_extension_uri_reference(map[function_id.uri]); + fn->set_function_anchor(anchor); + fn->set_name(std::string(function_id.name)); + + auto ext_decl = std::make_unique(); + ext_decl->set_allocated_extension_function(fn.release()); + extensions->AddAllocated(ext_decl.release()); + } + + return Status::OK(); +} + +std::unique_ptr CreateVersion(); + } // namespace engine } // namespace arrow diff --git a/cpp/thirdparty/versions.txt b/cpp/thirdparty/versions.txt index c05ff422846..8edaa422b3d 100644 --- a/cpp/thirdparty/versions.txt +++ b/cpp/thirdparty/versions.txt @@ -101,8 +101,8 @@ ARROW_RE2_BUILD_SHA256_CHECKSUM=f89c61410a072e5cbcf8c27e3a778da7d6fd2f2b5b1445cd # 1.1.9 is patched to implement https://github.com/google/snappy/pull/148 if this is bumped, remove the patch ARROW_SNAPPY_BUILD_VERSION=1.1.9 ARROW_SNAPPY_BUILD_SHA256_CHECKSUM=75c1fbb3d618dd3a0483bff0e26d0a92b495bbe5059c8b4f1c962b478b6e06e7 
-ARROW_SUBSTRAIT_BUILD_VERSION=v0.20.0 -ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM=5ceaa559ccef29a7825b5e5d4b5e7eed384830294f08bec913feecdd903a94cf +ARROW_SUBSTRAIT_BUILD_VERSION=v0.27.0 +ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM=4ed375f69d972a57fdc5ec406c17003a111831d8640d3f1733eccd4b3ff45628 ARROW_S2N_TLS_BUILD_VERSION=v1.3.35 ARROW_S2N_TLS_BUILD_SHA256_CHECKSUM=9d32b26e6bfcc058d98248bf8fc231537e347395dd89cf62bb432b55c5da990d ARROW_THRIFT_BUILD_VERSION=0.16.0 diff --git a/docs/source/python/api/substrait.rst b/docs/source/python/api/substrait.rst index 207b2d9cdbc..66e88fcd279 100644 --- a/docs/source/python/api/substrait.rst +++ b/docs/source/python/api/substrait.rst @@ -31,6 +31,19 @@ Query Execution run_query +Expression Serialization +------------------------ + +These functions allow for serialization and deserialization of pyarrow +compute expressions. + +.. autosummary:: + :toctree: ../generated/ + + BoundExpressions + deserialize_expressions + serialize_expressions + Utility ------- diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 453f487c4de..4e96650f152 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -36,6 +36,23 @@ import inspect import numpy as np +__pas = None +_substrait_msg = ( + "The pyarrow installation is not built with support for Substrait." +) + + +def _pas(): + global __pas + if __pas is None: + try: + import pyarrow.substrait as pas + __pas = pas + except ImportError: + raise ImportError(_substrait_msg) + return __pas + + def _forbid_instantiation(klass, subclasses_instead=True): msg = '{} is an abstract class thus cannot be initialized.'.format( klass.__name__ @@ -2364,6 +2381,58 @@ cdef class Expression(_Weakrefable): self.__class__.__name__, str(self) ) + @staticmethod + def from_substrait(object buffer not None): + """ + Deserialize an expression from Substrait + + The serialized message must be an ExtendedExpression message that has + only a single expression. 
The name of the expression and the schema + the expression was bound to will be ignored. Use + pyarrow.substrait.deserialize_expressions if this information is needed + or if the message might contain multiple expressions. + + Parameters + ---------- + buffer : bytes or Buffer + The Substrait message to deserialize + + Returns + ------- + Expression + The deserialized expression + """ + expressions = _pas().deserialize_expressions(buffer).expressions + if len(expressions) == 0: + raise ValueError("Substrait message did not contain any expressions") + if len(expressions) > 1: + raise ValueError( + "Substrait message contained multiple expressions. Use pyarrow.substrait.deserialize_expressions instead") + return next(iter(expressions.values())) + + def to_substrait(self, Schema schema not None, c_bool allow_arrow_extensions=False): + """ + Serialize the expression using Substrait + + The expression will be serialized as an ExtendedExpression message that has a + single expression named "expression" + + Parameters + ---------- + schema : Schema + The input schema the expression will be bound to + allow_arrow_extensions : bool, default False + If False then only functions that are part of the core Substrait function + definitions will be allowed. Set this to True to allow pyarrow-specific functions + but the result may not be accepted by other compute libraries. + + Returns + ------- + Buffer + A buffer containing the serialized Protobuf plan. 
+ """ + return _pas().serialize_expressions([self], ["expression"], schema, allow_arrow_extensions=allow_arrow_extensions) + @staticmethod def _deserialize(Buffer buffer not None): return Expression.wrap(GetResultValue(CDeserializeExpression( diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx index 1be2e6330ab..4efad2c4d1b 100644 --- a/python/pyarrow/_substrait.pyx +++ b/python/pyarrow/_substrait.pyx @@ -20,6 +20,7 @@ from cython.operator cimport dereference as deref from libcpp.vector cimport vector as std_vector from pyarrow import Buffer, py_buffer +from pyarrow._compute cimport Expression from pyarrow.lib import frombytes, tobytes from pyarrow.lib cimport * from pyarrow.includes.libarrow cimport * @@ -164,7 +165,7 @@ def _parse_json_plan(plan): Parameters ---------- - plan: bytes + plan : bytes Substrait plan in JSON. Returns @@ -185,6 +186,144 @@ def _parse_json_plan(plan): return pyarrow_wrap_buffer(c_buf_plan) +def serialize_expressions(exprs, names, schema, *, allow_arrow_extensions=False): + """ + Serialize a collection of expressions into Substrait + + Substrait expressions must be bound to a schema. For example, + the Substrait expression ``a:i32 + b:i32`` is different from the + Substrait expression ``a:i64 + b:i64``. Pyarrow expressions are + typically unbound. For example, both of the above expressions + would be represented as ``a + b`` in pyarrow. + + This means a schema must be provided when serializing an expression. + It also means that the serialization may fail if a matching function + call cannot be found for the expression. + + Parameters + ---------- + exprs : list of Expression + The expressions to serialize + names : list of str + Names for the expressions + schema : Schema + The schema the expressions will be bound to + allow_arrow_extensions : bool, default False + If False then only functions that are part of the core Substrait function + definitions will be allowed. 
Set this to True to allow pyarrow-specific functions + and user defined functions but the result may not be accepted by other + compute libraries. + + Returns + ------- + Buffer + An ExtendedExpression message containing the serialized expressions + """ + cdef: + CResult[shared_ptr[CBuffer]] c_res_buffer + shared_ptr[CBuffer] c_buffer + CNamedExpression c_named_expr + CBoundExpressions c_bound_exprs + CConversionOptions c_conversion_options + + if len(exprs) != len(names): + raise ValueError("exprs and names need to have the same length") + for expr, name in zip(exprs, names): + if not isinstance(expr, Expression): + raise TypeError(f"Expected Expression, got '{type(expr)}' in exprs") + if not isinstance(name, str): + raise TypeError(f"Expected str, got '{type(name)}' in names") + c_named_expr.expression = ( expr).unwrap() + c_named_expr.name = tobytes( name) + c_bound_exprs.named_expressions.push_back(c_named_expr) + + c_bound_exprs.schema = ( schema).sp_schema + + c_conversion_options.allow_arrow_extensions = allow_arrow_extensions + + with nogil: + c_res_buffer = SerializeExpressions(c_bound_exprs, c_conversion_options) + c_buffer = GetResultValue(c_res_buffer) + return pyarrow_wrap_buffer(c_buffer) + + +cdef class BoundExpressions(_Weakrefable): + """ + A collection of named expressions and the schema they are bound to + + This is equivalent to the Substrait ExtendedExpression message + """ + + cdef: + CBoundExpressions c_bound_exprs + + def __init__(self): + msg = 'BoundExpressions is an abstract class thus cannot be initialized.' 
+ raise TypeError(msg) + + cdef void init(self, CBoundExpressions bound_expressions): + self.c_bound_exprs = bound_expressions + + @property + def schema(self): + """ + The common schema that all expressions are bound to + """ + return pyarrow_wrap_schema(self.c_bound_exprs.schema) + + @property + def expressions(self): + """ + A dict from expression name to expression + """ + expr_dict = {} + for named_expr in self.c_bound_exprs.named_expressions: + name = frombytes(named_expr.name) + expr = Expression.wrap(named_expr.expression) + expr_dict[name] = expr + return expr_dict + + @staticmethod + cdef wrap(const CBoundExpressions& bound_expressions): + cdef BoundExpressions self = BoundExpressions.__new__(BoundExpressions) + self.init(bound_expressions) + return self + + +def deserialize_expressions(buf): + """ + Deserialize an ExtendedExpression Substrait message into a BoundExpressions object + + Parameters + ---------- + buf : Buffer or bytes + The message to deserialize + + Returns + ------- + BoundExpressions + The deserialized expressions, their names, and the bound schema + """ + cdef: + shared_ptr[CBuffer] c_buffer + CResult[CBoundExpressions] c_res_bound_exprs + CBoundExpressions c_bound_exprs + + if isinstance(buf, bytes): + c_buffer = pyarrow_unwrap_buffer(py_buffer(buf)) + elif isinstance(buf, Buffer): + c_buffer = pyarrow_unwrap_buffer(buf) + else: + raise TypeError( + f"Expected 'pyarrow.Buffer' or bytes, got '{type(buf)}'") + + with nogil: + c_res_bound_exprs = DeserializeExpressions(deref(c_buffer)) + c_bound_exprs = GetResultValue(c_res_bound_exprs) + + return BoundExpressions.wrap(c_bound_exprs) + + def get_supported_functions(): """ Get a list of Substrait functions that the underlying diff --git a/python/pyarrow/includes/libarrow_substrait.pxd b/python/pyarrow/includes/libarrow_substrait.pxd index eabccb2b4a3..c41f4c05d3a 100644 --- a/python/pyarrow/includes/libarrow_substrait.pxd +++ b/python/pyarrow/includes/libarrow_substrait.pxd @@ -40,6 +40,7 
@@ cdef extern from "arrow/engine/substrait/options.h" namespace "arrow::engine" no CConversionOptions() ConversionStrictness strictness function[CNamedTableProvider] named_table_provider + c_bool allow_arrow_extensions cdef extern from "arrow/engine/substrait/extension_set.h" \ namespace "arrow::engine" nogil: @@ -49,6 +50,23 @@ cdef extern from "arrow/engine/substrait/extension_set.h" \ ExtensionIdRegistry* default_extension_id_registry() +cdef extern from "arrow/engine/substrait/relation.h" namespace "arrow::engine" nogil: + + cdef cppclass CNamedExpression "arrow::engine::NamedExpression": + CExpression expression + c_string name + + cdef cppclass CBoundExpressions "arrow::engine::BoundExpressions": + std_vector[CNamedExpression] named_expressions + shared_ptr[CSchema] schema + +cdef extern from "arrow/engine/substrait/serde.h" namespace "arrow::engine" nogil: + + CResult[shared_ptr[CBuffer]] SerializeExpressions( + const CBoundExpressions& bound_expressions, const CConversionOptions& conversion_options) + + CResult[CBoundExpressions] DeserializeExpressions( + const CBuffer& serialized_expressions) cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine" nogil: CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan( diff --git a/python/pyarrow/substrait.py b/python/pyarrow/substrait.py index ea7e19142cd..a2b217f4936 100644 --- a/python/pyarrow/substrait.py +++ b/python/pyarrow/substrait.py @@ -17,8 +17,11 @@ try: from pyarrow._substrait import ( # noqa + BoundExpressions, get_supported_functions, run_query, + deserialize_expressions, + serialize_expressions ) except ImportError as exc: raise ImportError( diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 30a3cdba36d..e41bffbed3c 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -41,6 +41,10 @@ from pyarrow.lib import ArrowNotImplementedError from pyarrow.tests import util +try: + import 
pyarrow.substrait as pas +except ImportError: + pas = None all_array_types = [ ('bool', [True, False, False, True, True]), @@ -3280,7 +3284,14 @@ def test_rank_options(): tiebreaker="NonExisting") -def test_expression_serialization(): +def create_sample_expressions(): + # We need a schema for substrait conversion + schema = pa.schema([pa.field("i64", pa.int64()), pa.field( + "foo", pa.struct([pa.field("bar", pa.string())]))]) + + # Creates a bunch of sample expressions for testing + # serialization and deserialization. The expressions are categorized + # to reflect certain nuances in Substrait conversion. a = pc.scalar(1) b = pc.scalar(1.1) c = pc.scalar(True) @@ -3289,20 +3300,133 @@ def test_expression_serialization(): f = pc.scalar({'a': 1}) g = pc.scalar(pa.scalar(1)) h = pc.scalar(np.int64(2)) + j = pc.scalar(False) + + # These expression consist entirely of literals + literal_exprs = [a, b, c, d, e, g, h, j] + + # These expressions include at least one function call + exprs_with_call = [a == b, a != b, a > b, c & j, c | j, ~c, d.is_valid(), + a + b, a - b, a * b, a / b, pc.negate(a), + pc.add(a, b), pc.subtract(a, b), pc.divide(a, b), + pc.multiply(a, b), pc.power(a, a), pc.sqrt(a), + pc.exp(b), pc.cos(b), pc.sin(b), pc.tan(b), + pc.acos(b), pc.atan(b), pc.asin(b), pc.atan2(b, b), + pc.abs(b), pc.sign(a), pc.bit_wise_not(a), + pc.bit_wise_and(a, a), pc.bit_wise_or(a, a), + pc.bit_wise_xor(a, a), pc.is_nan(b), pc.is_finite(b), + pc.coalesce(a, b), + a.cast(pa.int32(), safe=False)] + + # These expressions test out various reference styles and may include function + # calls. Named references are used here. 
+ exprs_with_ref = [pc.field('i64') > 5, pc.field('i64') == 5, + pc.field('i64') == 7, + pc.field(('foo', 'bar')) == 'value', + pc.field('foo', 'bar') == 'value'] + + # Similar to above but these use numeric references instead of string refs + exprs_with_numeric_refs = [pc.field(0) > 5, pc.field(0) == 5, + pc.field(0) == 7, + pc.field((1, 0)) == 'value', + pc.field(1, 0) == 'value'] + + # Expressions that behave uniquely when converting to/from substrait + special_cases = [ + f, # Struct literals lose their field names + a.isin([1, 2, 3]), # isin converts to an or list + pc.field('i64').is_null() # pyarrow always specifies a FunctionOptions + # for is_null which, being the default, is + # dropped on serialization + ] + + all_exprs = literal_exprs.copy() + all_exprs += exprs_with_call + all_exprs += exprs_with_ref + all_exprs += special_cases + + return { + "all": all_exprs, + "literals": literal_exprs, + "calls": exprs_with_call, + "refs": exprs_with_ref, + "numeric_refs": exprs_with_numeric_refs, + "special": special_cases, + "schema": schema + } - all_exprs = [a, b, c, d, e, f, g, h, a == b, a > b, a & b, a | b, ~c, - d.is_valid(), a.cast(pa.int32(), safe=False), - a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]), - pc.field('i64') > 5, pc.field('i64') == 5, - pc.field('i64') == 7, pc.field('i64').is_null(), - pc.field(('foo', 'bar')) == 'value', - pc.field('foo', 'bar') == 'value'] - for expr in all_exprs: +# Tests the Arrow-specific serialization mechanism + + +def test_expression_serialization_arrow(): + for expr in create_sample_expressions()["all"]: assert isinstance(expr, pc.Expression) restored = pickle.loads(pickle.dumps(expr)) assert expr.equals(restored) +@pytest.mark.substrait +def test_expression_serialization_substrait(): + + exprs = create_sample_expressions() + schema = exprs["schema"] + + # Basic literals don't change on binding and so they will round + # trip without any change + for expr in exprs["literals"]: + serialized = 
expr.to_substrait(schema) + deserialized = pc.Expression.from_substrait(serialized) + assert expr.equals(deserialized) + + # Expressions are bound when they get serialized. Since bound + # expressions are not equal to their unbound variants we cannot + # compare the round tripped with the original + for expr in exprs["calls"]: + serialized = expr.to_substrait(schema) + deserialized = pc.Expression.from_substrait(serialized) + # We can't compare the expressions themselves because of the bound + # unbound difference. But we can compare the string representation + assert str(deserialized) == str(expr) + serialized_again = deserialized.to_substrait(schema) + deserialized_again = pc.Expression.from_substrait(serialized_again) + assert deserialized.equals(deserialized_again) + + for expr, expr_norm in zip(exprs["refs"], exprs["numeric_refs"]): + serialized = expr.to_substrait(schema) + deserialized = pc.Expression.from_substrait(serialized) + assert str(deserialized) == str(expr_norm) + serialized_again = deserialized.to_substrait(schema) + deserialized_again = pc.Expression.from_substrait(serialized_again) + assert deserialized.equals(deserialized_again) + + # For the special cases we get various wrinkles in serialization but we + # should always get the same thing from round tripping twice + for expr in exprs["special"]: + serialized = expr.to_substrait(schema) + deserialized = pc.Expression.from_substrait(serialized) + serialized_again = deserialized.to_substrait(schema) + deserialized_again = pc.Expression.from_substrait(serialized_again) + assert deserialized.equals(deserialized_again) + + # Special case, we lose the field names of struct literals + f = exprs["special"][0] + serialized = f.to_substrait(schema) + deserialized = pc.Expression.from_substrait(serialized) + assert deserialized.equals(pc.scalar({'': 1})) + + # Special case, is_in converts to a == opt[0] || a == opt[1] ... 
+ a = pc.scalar(1) + expr = a.isin([1, 2, 3]) + target = (a == 1) | (a == 2) | (a == 3) + serialized = expr.to_substrait(schema) + deserialized = pc.Expression.from_substrait(serialized) + # Compare str's here to bypass the bound/unbound difference + assert str(target) == str(deserialized) + serialized_again = deserialized.to_substrait(schema) + deserialized_again = pc.Expression.from_substrait(serialized_again) + assert deserialized.equals(deserialized_again) + + def test_expression_construction(): zero = pc.scalar(0) one = pc.scalar(1) diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index 93ecae7bfa1..be35a21a024 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -21,8 +21,10 @@ import pytest import pyarrow as pa +import pyarrow.compute as pc +import pyarrow.dataset as ds from pyarrow.lib import tobytes -from pyarrow.lib import ArrowInvalid +from pyarrow.lib import ArrowInvalid, ArrowNotImplementedError try: import pyarrow.substrait as substrait @@ -923,3 +925,106 @@ def table_provider(names, _): # Ordering of k is deterministic because this is running with serial execution assert res_tb == expected_tb + + +@pytest.mark.parametrize("expr", [ + pc.equal(ds.field("x"), 7), + pc.equal(ds.field("x"), ds.field("y")), + ds.field("x") > 50 +]) +def test_serializing_expressions(expr): + schema = pa.schema([ + pa.field("x", pa.int32()), + pa.field("y", pa.int32()) + ]) + + buf = pa.substrait.serialize_expressions([expr], ["test_expr"], schema) + returned = pa.substrait.deserialize_expressions(buf) + assert schema == returned.schema + assert len(returned.expressions) == 1 + assert "test_expr" in returned.expressions + + +def test_invalid_expression_ser_des(): + schema = pa.schema([ + pa.field("x", pa.int32()), + pa.field("y", pa.int32()) + ]) + expr = pc.equal(ds.field("x"), 7) + bad_expr = pc.equal(ds.field("z"), 7) + # Invalid number of names + with pytest.raises(ValueError) as 
excinfo: + pa.substrait.serialize_expressions([expr], [], schema) + assert 'need to have the same length' in str(excinfo.value) + with pytest.raises(ValueError) as excinfo: + pa.substrait.serialize_expressions([expr], ["foo", "bar"], schema) + assert 'need to have the same length' in str(excinfo.value) + # Expression doesn't match schema + with pytest.raises(ValueError) as excinfo: + pa.substrait.serialize_expressions([bad_expr], ["expr"], schema) + assert 'No match for FieldRef' in str(excinfo.value) + + +def test_serializing_multiple_expressions(): + schema = pa.schema([ + pa.field("x", pa.int32()), + pa.field("y", pa.int32()) + ]) + exprs = [pc.equal(ds.field("x"), 7), pc.equal(ds.field("x"), ds.field("y"))] + buf = pa.substrait.serialize_expressions(exprs, ["first", "second"], schema) + returned = pa.substrait.deserialize_expressions(buf) + assert schema == returned.schema + assert len(returned.expressions) == 2 + + norm_exprs = [pc.equal(ds.field(0), 7), pc.equal(ds.field(0), ds.field(1))] + assert str(returned.expressions["first"]) == str(norm_exprs[0]) + assert str(returned.expressions["second"]) == str(norm_exprs[1]) + + +def test_serializing_with_compute(): + schema = pa.schema([ + pa.field("x", pa.int32()), + pa.field("y", pa.int32()) + ]) + expr = pc.equal(ds.field("x"), 7) + expr_norm = pc.equal(ds.field(0), 7) + buf = expr.to_substrait(schema) + returned = pa.substrait.deserialize_expressions(buf) + + assert schema == returned.schema + assert len(returned.expressions) == 1 + + assert str(returned.expressions["expression"]) == str(expr_norm) + + # Compute can't deserialize messages with multiple expressions + buf = pa.substrait.serialize_expressions([expr, expr], ["first", "second"], schema) + with pytest.raises(ValueError) as excinfo: + pc.Expression.from_substrait(buf) + assert 'contained multiple expressions' in str(excinfo.value) + + # Deserialization should be possible regardless of the expression name + buf = 
pa.substrait.serialize_expressions([expr], ["weirdname"], schema) + expr2 = pc.Expression.from_substrait(buf) + assert str(expr2) == str(expr_norm) + + +def test_serializing_udfs(): + # Note, UDF in this context means a function that is not + # recognized by Substrait. It might still be a builtin pyarrow + # function. + schema = pa.schema([ + pa.field("x", pa.uint32()) + ]) + a = pc.scalar(10) + b = pc.scalar(4) + exprs = [pc.shift_left(a, b)] + + with pytest.raises(ArrowNotImplementedError): + pa.substrait.serialize_expressions(exprs, ["expr"], schema) + + buf = pa.substrait.serialize_expressions( + exprs, ["expr"], schema, allow_arrow_extensions=True) + returned = pa.substrait.deserialize_expressions(buf) + assert schema == returned.schema + assert len(returned.expressions) == 1 + assert str(returned.expressions["expr"]) == str(exprs[0])