Skip to content

Commit

Permalink
fix Cast function bind failed after add a alias name through AddAlias
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhangHuiGui committed Mar 1, 2024
1 parent 7c4f4c2 commit ea9a4c1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
7 changes: 5 additions & 2 deletions cpp/src/arrow/compute/expression_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,11 @@ struct FlattenedAssociativeChain {

inline Result<std::shared_ptr<compute::Function>> GetFunction(
const Expression::Call& call, compute::ExecContext* exec_context) {
if (call.function_name != "cast") {
return exec_context->func_registry()->GetFunction(call.function_name);
auto input_call_function =
exec_context->func_registry()->GetFunction(call.function_name);
auto cast_function = exec_context->func_registry()->GetFunction("cast");
if (input_call_function != cast_function) {
return input_call_function;
}
// XXX this special case is strange; why not make "cast" a ScalarFunction?
const TypeHolder& to_type =
Expand Down
11 changes: 11 additions & 0 deletions cpp/src/arrow/compute/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,17 @@ TEST(Expression, BindCall) {
add(cast(field_ref("i32"), float32()), literal(3.5F)));
}

TEST(Expression, BindWithAliasCasts) {
auto fm = GetFunctionRegistry();
EXPECT_OK(fm->AddAlias("alias_cast", "cast"));

auto expr = call("alias_cast", {field_ref("f1")}, CastOptions::Unsafe(arrow::int32()));
EXPECT_FALSE(expr.IsBound());

auto schema = arrow::schema({field("f1", decimal128(30, 3))});
ExpectBindsTo(expr, no_change, &expr, *schema);
}

TEST(Expression, BindWithDecimalArithmeticOps) {
for (std::string arith_op : {"add", "subtract", "multiply", "divide"}) {
auto expr = call(arith_op, {field_ref("d1"), field_ref("d2")});
Expand Down

0 comments on commit ea9a4c1

Please sign in to comment.