Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion cpp/cmake_modules/ThirdpartyToolchain.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1785,7 +1785,12 @@ macro(build_substrait)

# Note: not all protos in Substrait actually matter to plan
# consumption. No need to build the ones we don't need.
set(SUBSTRAIT_PROTOS algebra extensions/extensions plan type)
set(SUBSTRAIT_PROTOS
algebra
extended_expression
extensions/extensions
plan
type)
set(ARROW_SUBSTRAIT_PROTOS extension_rels)
set(ARROW_SUBSTRAIT_PROTOS_DIR "${CMAKE_SOURCE_DIR}/proto")

Expand Down
10 changes: 10 additions & 0 deletions cpp/src/arrow/compute/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,16 @@ CastOptions::CastOptions(bool safe)
allow_float_truncate(!safe),
allow_invalid_utf8(!safe) {}

// Reports whether every safety-related flag is disabled, i.e. none of the
// overflow, truncation, or invalid-UTF8 allowances is turned on.
bool CastOptions::is_safe() const {
  // A single enabled allowance is enough to make the options not safe.
  const bool any_allowance_enabled = allow_int_overflow || allow_time_truncate ||
                                     allow_time_overflow || allow_decimal_truncate ||
                                     allow_float_truncate || allow_invalid_utf8;
  return !any_allowance_enabled;
}

// Reports whether every safety-related flag is enabled, i.e. all of the
// overflow, truncation, and invalid-UTF8 allowances are turned on.
bool CastOptions::is_unsafe() const {
  // Any disabled integer/temporal allowance rules out "fully unsafe".
  if (!allow_int_overflow || !allow_time_truncate || !allow_time_overflow) {
    return false;
  }
  // Likewise for the decimal and floating point truncation allowances.
  if (!allow_decimal_truncate || !allow_float_truncate) {
    return false;
  }
  // Everything else is enabled, so the last flag decides the answer.
  return allow_invalid_utf8;
}

constexpr char CastOptions::kTypeName[];

Result<Datum> Cast(const Datum& value, const CastOptions& options, ExecContext* ctx) {
Expand Down
9 changes: 9 additions & 0 deletions cpp/src/arrow/compute/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ class ARROW_EXPORT CastOptions : public FunctionOptions {
// Indicate if conversions from Binary/FixedSizeBinary to string must
// validate the utf8 payload.
bool allow_invalid_utf8;

/// true if the safety options all match CastOptions::Safe
///
/// The implementation returns true only when allow_int_overflow,
/// allow_time_truncate, allow_time_overflow, allow_decimal_truncate,
/// allow_float_truncate and allow_invalid_utf8 are all false.
///
/// Note, if this returns false it does not mean is_unsafe will return true:
/// options with only some of these flags set are neither safe nor unsafe.
bool is_safe() const;
/// true if the safety options all match CastOptions::Unsafe
///
/// The implementation returns true only when allow_int_overflow,
/// allow_time_truncate, allow_time_overflow, allow_decimal_truncate,
/// allow_float_truncate and allow_invalid_utf8 are all true.
///
/// Note, if this returns false it does not mean is_safe will return true:
/// options with only some of these flags set are neither safe nor unsafe.
bool is_unsafe() const;
};

/// @}
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/engine/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ arrow_install_all_headers("arrow/engine")

set(ARROW_SUBSTRAIT_SRCS
substrait/expression_internal.cc
substrait/extended_expression_internal.cc
substrait/extension_set.cc
substrait/extension_types.cc
substrait/options.cc
Expand Down
162 changes: 156 additions & 6 deletions cpp/src/arrow/engine/substrait/expression_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "arrow/buffer.h"
#include "arrow/builder.h"
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/cast.h"
#include "arrow/compute/expression.h"
#include "arrow/compute/expression_internal.h"
#include "arrow/engine/substrait/extension_set.h"
Expand Down Expand Up @@ -338,6 +339,22 @@ Result<compute::Expression> FromProto(const substrait::Expression& expr,
return compute::call("case_when", std::move(args));
}

case substrait::Expression::kSingularOrList: {
const auto& or_list = expr.singular_or_list();

ARROW_ASSIGN_OR_RAISE(compute::Expression value,
FromProto(or_list.value(), ext_set, conversion_options));

std::vector<compute::Expression> option_eqs;
for (const auto& option : or_list.options()) {
ARROW_ASSIGN_OR_RAISE(compute::Expression arrow_option,
FromProto(option, ext_set, conversion_options));
option_eqs.push_back(compute::call("equal", {value, arrow_option}));
}

return compute::or_(option_eqs);
}

case substrait::Expression::kScalarFunction: {
const auto& scalar_fn = expr.scalar_function();

Expand Down Expand Up @@ -1055,9 +1072,68 @@ Result<std::unique_ptr<substrait::Expression::ScalarFunction>> EncodeSubstraitCa
" arguments but no argument could be found at index ", i);
}
}

for (const auto& option : call.options()) {
substrait::FunctionOption* fn_option = scalar_fn->add_options();
fn_option->set_name(option.first);
for (const auto& opt_val : option.second) {
std::string* pref = fn_option->add_preference();
*pref = opt_val;
}
}

return std::move(scalar_fn);
}

// Flattens a Datum into a list of Substrait literal expressions.
//
// A scalar datum yields one literal; an array or chunked array yields one
// literal per element, in element order (chunk by chunk for chunked arrays).
// Record batches, tables and empty datums are rejected with Status::Invalid.
// Any error from scalar extraction or literal conversion is propagated.
Result<std::vector<std::unique_ptr<substrait::Expression>>> DatumToLiterals(
    const Datum& datum, ExtensionSet* ext_set,
    const ConversionOptions& conversion_options) {
  std::vector<std::unique_ptr<substrait::Expression>> result;

  // Converts one scalar to a Substrait literal, wraps it in an Expression
  // message and appends it to `result`.
  auto AppendLiteral = [&](const std::shared_ptr<Scalar>& value) -> Status {
    ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Expression::Literal> proto_literal,
                          ToProto(value, ext_set, conversion_options));
    auto wrapped = std::make_unique<substrait::Expression>();
    // The expression message takes ownership of the released literal.
    wrapped->set_allocated_literal(proto_literal.release());
    result.push_back(std::move(wrapped));
    return Status::OK();
  };

  // Appends one literal per element of `array`, stopping at the first error.
  auto AppendArray = [&](const Array& array) -> Status {
    for (int64_t row = 0; row < array.length(); ++row) {
      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> element, array.GetScalar(row));
      ARROW_RETURN_NOT_OK(AppendLiteral(element));
    }
    return Status::OK();
  };

  switch (datum.kind()) {
    case Datum::Kind::SCALAR:
      ARROW_RETURN_NOT_OK(AppendLiteral(datum.scalar()));
      break;
    case Datum::Kind::ARRAY: {
      const std::shared_ptr<Array> array = datum.make_array();
      ARROW_RETURN_NOT_OK(AppendArray(*array));
      break;
    }
    case Datum::Kind::CHUNKED_ARRAY: {
      const std::shared_ptr<ChunkedArray> chunked = datum.chunked_array();
      for (const auto& chunk : chunked->chunks()) {
        ARROW_RETURN_NOT_OK(AppendArray(*chunk));
      }
      break;
    }
    case Datum::Kind::RECORD_BATCH:
    case Datum::Kind::TABLE:
    case Datum::Kind::NONE:
      return Status::Invalid("Expected a literal or an array of literals, got ",
                             datum.ToString());
  }
  return result;
}

Result<std::unique_ptr<substrait::Expression>> ToProto(
const compute::Expression& expr, ExtensionSet* ext_set,
const ConversionOptions& conversion_options) {
Expand Down Expand Up @@ -1164,15 +1240,89 @@ Result<std::unique_ptr<substrait::Expression>> ToProto(

out->set_allocated_if_then(if_then.release());
return std::move(out);
} else if (call->function_name == "cast") {
auto cast = std::make_unique<substrait::Expression::Cast>();

// Arrow's cast function does not have a "return null" option, so throwing an
// exception is the only failure behavior we can support.
cast->set_failure_behavior(
substrait::Expression::Cast::FAILURE_BEHAVIOR_THROW_EXCEPTION);

std::shared_ptr<compute::CastOptions> cast_options =
internal::checked_pointer_cast<compute::CastOptions>(call->options);
if (!cast_options->is_unsafe()) {
return Status::Invalid("Substrait is only capable of representing unsafe casts");
}

if (arguments.size() != 1) {
return Status::Invalid(
"A call to the cast function must have exactly one argument");
}

cast->set_allocated_input(arguments[0].release());

ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Type> to_type,
ToProto(*cast_options->to_type.type, /*nullable=*/true, ext_set,
conversion_options));

cast->set_allocated_type(to_type.release());

out->set_allocated_cast(cast.release());
return std::move(out);
} else if (call->function_name == "is_in") {
auto or_list = std::make_unique<substrait::Expression::SingularOrList>();

if (arguments.size() != 1) {
return Status::Invalid(
"A call to the is_in function must have exactly one argument");
}

or_list->set_allocated_value(arguments[0].release());
std::shared_ptr<compute::SetLookupOptions> is_in_options =
internal::checked_pointer_cast<compute::SetLookupOptions>(call->options);

// TODO(GH-36420) Acero does not currently handle nulls correctly
ARROW_ASSIGN_OR_RAISE(
std::vector<std::unique_ptr<substrait::Expression>> options,
DatumToLiterals(is_in_options->value_set, ext_set, conversion_options));
for (auto& option : options) {
or_list->mutable_options()->AddAllocated(option.release());
}
out->set_allocated_singular_or_list(or_list.release());
return std::move(out);
}

// other expression types dive into extensions immediately
ARROW_ASSIGN_OR_RAISE(
ExtensionIdRegistry::ArrowToSubstraitCall converter,
ext_set->registry()->GetArrowToSubstraitCall(call->function_name));
ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call));
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Expression::ScalarFunction> scalar_fn,
EncodeSubstraitCall(substrait_call, ext_set, conversion_options));
Result<ExtensionIdRegistry::ArrowToSubstraitCall> maybe_converter =
ext_set->registry()->GetArrowToSubstraitCall(call->function_name);

ExtensionIdRegistry::ArrowToSubstraitCall converter;
std::unique_ptr<substrait::Expression::ScalarFunction> scalar_fn;
if (maybe_converter.ok()) {
converter = *maybe_converter;
ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call));
ARROW_ASSIGN_OR_RAISE(
scalar_fn, EncodeSubstraitCall(substrait_call, ext_set, conversion_options));
} else if (maybe_converter.status().IsNotImplemented() &&
conversion_options.allow_arrow_extensions) {
if (call->options) {
return Status::NotImplemented(
"The function ", call->function_name,
" has no Substrait mapping. Arrow extensions are enabled but the call "
"contains function options and there is no current mechanism to encode those.");
}
Id persistent_id = ext_set->RegisterPlanSpecificId(
{kArrowSimpleExtensionFunctionsUri, call->function_name});
SubstraitCall substrait_call(persistent_id, call->type.GetSharedPtr(),
/*nullable=*/true);
for (int i = 0; i < static_cast<int>(call->arguments.size()); i++) {
substrait_call.SetValueArg(i, call->arguments[i]);
}
ARROW_ASSIGN_OR_RAISE(
scalar_fn, EncodeSubstraitCall(substrait_call, ext_set, conversion_options));
} else {
return maybe_converter.status();
}
out->set_allocated_scalar_function(scalar_fn.release());
return std::move(out);
}
Expand Down
Loading