diff --git a/cpp/src/arrow/compute/expression_internal.h b/cpp/src/arrow/compute/expression_internal.h index 083756dc5fd39..a34fbf62afbc1 100644 --- a/cpp/src/arrow/compute/expression_internal.h +++ b/cpp/src/arrow/compute/expression_internal.h @@ -278,8 +278,11 @@ struct FlattenedAssociativeChain { inline Result> GetFunction( const Expression::Call& call, compute::ExecContext* exec_context) { - if (call.function_name != "cast") { - return exec_context->func_registry()->GetFunction(call.function_name); + auto input_call_function = + exec_context->func_registry()->GetFunction(call.function_name); + auto cast_function = exec_context->func_registry()->GetFunction("cast"); + if (input_call_function != cast_function) { + return input_call_function; } // XXX this special case is strange; why not make "cast" a ScalarFunction? const TypeHolder& to_type = diff --git a/cpp/src/arrow/compute/expression_test.cc b/cpp/src/arrow/compute/expression_test.cc index 38f8183dabcba..0c3403d3d5fe0 100644 --- a/cpp/src/arrow/compute/expression_test.cc +++ b/cpp/src/arrow/compute/expression_test.cc @@ -604,6 +604,17 @@ TEST(Expression, BindCall) { add(cast(field_ref("i32"), float32()), literal(3.5F))); } +TEST(Expression, BindWithAliasCasts) { + auto fm = GetFunctionRegistry(); + EXPECT_OK(fm->AddAlias("alias_cast", "cast")); + + auto expr = call("alias_cast", {field_ref("f1")}, CastOptions::Unsafe(arrow::int32())); + EXPECT_FALSE(expr.IsBound()); + + auto schema = arrow::schema({field("f1", decimal128(30, 3))}); + ExpectBindsTo(expr, no_change, &expr, *schema); +} + TEST(Expression, BindWithDecimalArithmeticOps) { for (std::string arith_op : {"add", "subtract", "multiply", "divide"}) { auto expr = call(arith_op, {field_ref("d1"), field_ref("d2")});