Skip to content

Commit

Permalink
add is_impure property for the function
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhangHuiGui committed Mar 12, 2024
1 parent 1cc218a commit 57d1ebc
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 30 deletions.
10 changes: 4 additions & 6 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -845,8 +845,8 @@ Result<Expression> FoldConstants(Expression expr) {
std::move(expr), [](Expression expr) { return expr; },
[](Expression expr, ...) -> Result<Expression> {
auto call = CallNotNull(expr);
if (!call->arguments.empty() &&
std::all_of(call->arguments.begin(), call->arguments.end(),
if (!call->function->is_impure()) return expr;
if (std::all_of(call->arguments.begin(), call->arguments.end(),
[](const Expression& argument) { return argument.literal(); })) {
// all arguments are literal; we can evaluate this subexpression *now*
static const ExecBatch ignored_input = ExecBatch({}, 1);
Expand All @@ -862,10 +862,6 @@ Result<Expression> FoldConstants(Expression expr) {
if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) {
// kernels which always produce intersected validity can be resolved
// to null *now* if any of their inputs is a null literal
if (!call->type.type) {
return Status::Invalid("Cannot fold constants for unbound expression ",
expr.ToString());
}
for (const Expression& argument : call->arguments) {
if (argument.IsNullLiteral()) {
if (argument.type()->Equals(*call->type.type)) {
Expand Down Expand Up @@ -1088,6 +1084,7 @@ Result<Expression> Canonicalize(Expression expr, compute::ExecContext* exec_cont
[&AlreadyCanonicalized, exec_context](Expression expr) -> Result<Expression> {
auto call = expr.call();
if (!call) return expr;
if (!call->function->is_impure()) return expr;

if (AlreadyCanonicalized(expr)) return expr;

Expand Down Expand Up @@ -1334,6 +1331,7 @@ Result<Expression> SimplifyIsValidGuarantee(Expression expr,
[&](Expression expr, ...) -> Result<Expression> {
auto call = expr.call();
if (!call) return expr;
if (!call->function->is_impure()) return expr;

if (call->arguments[0] != guarantee.arguments[0]) return expr;

Expand Down
50 changes: 31 additions & 19 deletions cpp/src/arrow/compute/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,14 +229,21 @@ class ARROW_EXPORT Function {

virtual Status Validate() const;

/// \brief Return whether this function is impure.
///
/// For impure functions like 'random', any simplification of the call
/// itself should be skipped; only its arguments may be simplified.
bool is_impure() const { return is_impure_; }

protected:
Function(std::string name, Function::Kind kind, const Arity& arity, FunctionDoc doc,
const FunctionOptions* default_options)
const FunctionOptions* default_options, bool is_impure)
: name_(std::move(name)),
kind_(kind),
arity_(arity),
doc_(std::move(doc)),
default_options_(default_options) {}
default_options_(default_options),
is_impure_(is_impure) {}

Status CheckArity(size_t num_args) const;

Expand All @@ -245,6 +252,8 @@ class ARROW_EXPORT Function {
Arity arity_;
const FunctionDoc doc_;
const FunctionOptions* default_options_ = NULLPTR;

bool is_impure_ = false;
};

namespace detail {
Expand All @@ -265,8 +274,9 @@ class FunctionImpl : public Function {

protected:
FunctionImpl(std::string name, Function::Kind kind, const Arity& arity, FunctionDoc doc,
const FunctionOptions* default_options)
: Function(std::move(name), kind, arity, std::move(doc), default_options) {}
const FunctionOptions* default_options, bool is_impure)
: Function(std::move(name), kind, arity, std::move(doc), default_options,
is_impure) {}

std::vector<KernelType> kernels_;
};
Expand All @@ -291,9 +301,9 @@ class ARROW_EXPORT ScalarFunction : public detail::FunctionImpl<ScalarKernel> {
using KernelType = ScalarKernel;

ScalarFunction(std::string name, const Arity& arity, FunctionDoc doc,
const FunctionOptions* default_options = NULLPTR)
const FunctionOptions* default_options = NULLPTR, bool is_impure = false)
: detail::FunctionImpl<ScalarKernel>(std::move(name), Function::SCALAR, arity,
std::move(doc), default_options) {}
std::move(doc), default_options, is_impure) {}

/// \brief Add a kernel with given input/output types, no required state
/// initialization, preallocation for fixed-width types, and default null
Expand All @@ -315,9 +325,9 @@ class ARROW_EXPORT VectorFunction : public detail::FunctionImpl<VectorKernel> {
using KernelType = VectorKernel;

VectorFunction(std::string name, const Arity& arity, FunctionDoc doc,
const FunctionOptions* default_options = NULLPTR)
const FunctionOptions* default_options = NULLPTR, bool is_impure = false)
: detail::FunctionImpl<VectorKernel>(std::move(name), Function::VECTOR, arity,
std::move(doc), default_options) {}
std::move(doc), default_options, is_impure) {}

/// \brief Add a simple kernel with given input/output types, no required
/// state initialization, no data preallocation, and no preallocation of the
Expand All @@ -336,10 +346,11 @@ class ARROW_EXPORT ScalarAggregateFunction
using KernelType = ScalarAggregateKernel;

ScalarAggregateFunction(std::string name, const Arity& arity, FunctionDoc doc,
const FunctionOptions* default_options = NULLPTR)
: detail::FunctionImpl<ScalarAggregateKernel>(std::move(name),
Function::SCALAR_AGGREGATE, arity,
std::move(doc), default_options) {}
const FunctionOptions* default_options = NULLPTR,
bool is_impure = false)
: detail::FunctionImpl<ScalarAggregateKernel>(
std::move(name), Function::SCALAR_AGGREGATE, arity, std::move(doc),
default_options, is_impure) {}

/// \brief Add a kernel (function implementation). Returns error if the
/// kernel's signature does not match the function's arity.
Expand All @@ -352,10 +363,11 @@ class ARROW_EXPORT HashAggregateFunction
using KernelType = HashAggregateKernel;

HashAggregateFunction(std::string name, const Arity& arity, FunctionDoc doc,
const FunctionOptions* default_options = NULLPTR)
: detail::FunctionImpl<HashAggregateKernel>(std::move(name),
Function::HASH_AGGREGATE, arity,
std::move(doc), default_options) {}
const FunctionOptions* default_options = NULLPTR,
bool is_impure = false)
: detail::FunctionImpl<HashAggregateKernel>(
std::move(name), Function::HASH_AGGREGATE, arity, std::move(doc),
default_options, is_impure) {}

/// \brief Add a kernel (function implementation). Returns error if the
/// kernel's signature does not match the function's arity.
Expand Down Expand Up @@ -383,9 +395,9 @@ class ARROW_EXPORT MetaFunction : public Function {
ExecContext* ctx) const = 0;

MetaFunction(std::string name, const Arity& arity, FunctionDoc doc,
const FunctionOptions* default_options = NULLPTR)
: Function(std::move(name), Function::META, arity, std::move(doc),
default_options) {}
const FunctionOptions* default_options = NULLPTR, bool is_impure = false)
: Function(std::move(name), Function::META, arity, std::move(doc), default_options,
is_impure) {}
};

/// @}
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1019,8 +1019,8 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults();
static auto default_count_options = CountOptions::Defaults();

auto func = std::make_shared<ScalarAggregateFunction>("count_all", Arity::Nullary(),
count_all_doc, NULLPTR);
auto func = std::make_shared<ScalarAggregateFunction>(
"count_all", Arity::Nullary(), count_all_doc, NULLPTR, /*is_impure=*/true);

// Takes no input (counts all rows), outputs int64 scalar
AddAggKernel(KernelSignature::Make({}, int64()), CountAllInit, func.get());
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/compute/kernels/scalar_random.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ const FunctionDoc random_doc{

void RegisterScalarRandom(FunctionRegistry* registry) {
static auto random_options = RandomOptions::Defaults();
auto random_func = std::make_shared<ScalarFunction>("random", Arity::Nullary(),
random_doc, &random_options);
auto random_func = std::make_shared<ScalarFunction>(
"random", Arity::Nullary(), random_doc, &random_options, /*is_impure=*/true);
ScalarKernel kernel{{}, float64(), ExecRandom, RandomState::Init};
kernel.null_handling = NullHandling::OUTPUT_NOT_NULL;
DCHECK_OK(random_func->AddKernel(kernel));
Expand Down
2 changes: 1 addition & 1 deletion docs/source/cpp/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ equivalents above and reflects how they are implemented internally.
+-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
| hash_count | Unary | Any | Int64 | :struct:`CountOptions` | \(2) |
+-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
| hash_count_all | Nullary | | Int64 | | |
| hash_count_all | Unary | | Int64 | | |
+-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
| hash_count_distinct | Unary | Any | Int64 | :struct:`CountOptions` | \(2) |
+-------------------------+---------+------------------------------------+------------------------+----------------------------------+-----------+
Expand Down

0 comments on commit 57d1ebc

Please sign in to comment.