Skip to content

Commit

Permalink
Add comments for the implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Oct 11, 2024
1 parent b951348 commit 520b952
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 25 deletions.
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ Result<std::shared_ptr<Array>> ReverseIndices(
///
/// For indices[i] = x, output[x] = values[i]. And output[x] = null if x does not appear
/// in the input indices. For indices[i] = x where x < 0 or x >= output_length, values[i]
/// is ignored.
/// is ignored. If multiple indices point to the same value, the last one is used.
///
/// For example, with values = [a, b, c, d, e, f, g] and indices = [null, 0,
/// 3, 2, 4, 1, 1], the permutation is
Expand Down
75 changes: 52 additions & 23 deletions cpp/src/arrow/compute/kernels/vector_placement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,23 @@ const ReverseIndicesOptions* GetDefaultReverseIndicesOptions() {

using ReverseIndicesState = OptionsWrapper<ReverseIndicesOptions>;

/// Resolve the output type of reverse_indices. The output type is specified in the
/// options, and if null, set it to the input type. The output type must be integer.
Result<TypeHolder> ResolveReverseIndicesOutputType(
KernelContext* ctx, const std::vector<TypeHolder>& input_types) {
DCHECK_EQ(input_types.size(), 1);
DCHECK_NE(input_types[0], nullptr);

const DataType* output_type = ReverseIndicesState::Get(ctx).output_type.get();
std::shared_ptr<DataType> output_type = ReverseIndicesState::Get(ctx).output_type;
if (!output_type) {
output_type = input_types[0].type;
output_type = input_types[0].owned_type;
}
if (!is_integer(output_type->id())) {
return Status::Invalid("Output type of reverse_indices must be integer, got " +
output_type->ToString());
}

return TypeHolder(output_type->GetSharedPtr());
return TypeHolder(std::move(output_type));
}

template <typename ExecType>
Expand All @@ -69,6 +71,7 @@ struct ReverseIndicesImpl {
const std::shared_ptr<DataType>& input_type) {
const auto& options = ReverseIndicesState::Get(ctx);

// Apply default options semantics.
int64_t output_length = options.output_length;
if (output_length < 0) {
output_length = input_length;
Expand Down Expand Up @@ -113,7 +116,16 @@ struct ReverseIndicesImpl {

RETURN_NOT_OK(CheckInput(output_type));

bool likely_many_nulls = IsLikelyManyNulls();
// Dispatch the execution based on whether there are likely many nulls in the output.
// - If many nulls (i.e. the output is "sparse"), preallocate an all-false validity
// buffer and an uninitialized data buffer. The subsequent processing will fill the
// valid values only.
// - Otherwise (i.e. the output is "dense"), the validity buffer is lazily allocated
// and filled all-true in the subsequent processing only when needed. The data buffer
// is preallocated and filled with an "impossible" value (input_length; note that the
// range of reverse_indices is [0, input_length)) for the subsequent processing to
// detect validity.
bool likely_many_nulls = LikelyManyNulls();
if (likely_many_nulls) {
RETURN_NOT_OK(AllocateValidityBufAndFill(false));
RETURN_NOT_OK(AllocateDataBuf(output_type));
Expand Down Expand Up @@ -142,7 +154,7 @@ struct ReverseIndicesImpl {
return Status::OK();
}

bool IsLikelyManyNulls() { return output_length > 2 * input_length; }
bool LikelyManyNulls() { return output_length > 2 * input_length; }

Status AllocateValidityBufAndFill(bool valid) {
DCHECK_EQ(validity_buf, nullptr);
Expand Down Expand Up @@ -195,6 +207,7 @@ struct ReverseIndicesImpl {
if (ARROW_PREDICT_TRUE(index >= 0 &&
static_cast<int64_t>(index) < output_length)) {
data[index] = static_cast<OutputCType>(reverse_index);
// If many nulls, set validity to true for valid values.
if constexpr (likely_many_nulls) {
bit_util::SetBitTo(validity, index, true);
}
Expand All @@ -207,9 +220,12 @@ struct ReverseIndicesImpl {
return Status::OK();
}));

// If not many nulls, run another pass iterating over the data to set the validity to
// false if the value is "impossible". The validity buffer is allocated and filled
// all-true on-the-fly when the first "impossible" value is seen.
if constexpr (!likely_many_nulls) {
for (int64_t i = 0; i < output_length; ++i) {
if (data[i] == static_cast<OutputCType>(input_length)) {
if (ARROW_PREDICT_FALSE(data[i] == static_cast<OutputCType>(input_length))) {
if (ARROW_PREDICT_FALSE(!validity_buf)) {
RETURN_NOT_OK(AllocateValidityBufAndFill(true));
validity = validity_buf->mutable_data_as<uint8_t>();
Expand Down Expand Up @@ -240,12 +256,12 @@ struct ReverseIndices {
}

static Status Exec(KernelContext* ctx, const ExecSpan& span, ExecResult* result) {
DCHECK_EQ(span.num_values(), 1);
DCHECK(span[0].is_array());
const auto& indices = span[0].array;
ARROW_ASSIGN_OR_RAISE(
auto output, ReverseIndicesImpl<ThisType>::Exec(ctx, indices, indices.length,
indices.type->GetSharedPtr()));
result->value = std::move(output);
result->value, ReverseIndicesImpl<ThisType>::Exec(ctx, indices, indices.length,
indices.type->GetSharedPtr()));
return Status::OK();
}
};
Expand All @@ -268,12 +284,11 @@ struct ReverseIndicesChunked {
}

static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* result) {
DCHECK_EQ(batch.num_values(), 1);
DCHECK(batch[0].is_chunked_array());
const auto& indices = batch[0].chunked_array();
ARROW_ASSIGN_OR_RAISE(
auto output, ReverseIndicesImpl<ThisType>::Exec(ctx, indices, indices->length(),
indices->type()));
*result = std::move(output);
ARROW_ASSIGN_OR_RAISE(*result, ReverseIndicesImpl<ThisType>::Exec(
ctx, indices, indices->length(), indices->type()));
return Status::OK();
}
};
Expand All @@ -295,8 +310,8 @@ void RegisterVectorReverseIndices(FunctionRegistry* registry) {
kernel.output_chunked = false;
DCHECK_OK(function->AddKernel(std::move(kernel)));
};
for (const auto& ty : IntTypes()) {
add_kernel(ty->id());
for (const auto& t : IntTypes()) {
add_kernel(t->id());
}

DCHECK_OK(registry->AddFunction(std::move(function)));
Expand Down Expand Up @@ -324,6 +339,7 @@ class PermuteMetaFunction : public MetaFunction {
Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
const FunctionOptions* options,
ExecContext* ctx) const override {
DCHECK_EQ(args.size(), 2);
const auto& values = args[0];
const auto& indices = args[1];
auto* permute_options = checked_cast<const PermuteOptions*>(options);
Expand All @@ -336,12 +352,25 @@ class PermuteMetaFunction : public MetaFunction {
return Status::Invalid("Indices of permute must be of integer type, got ",
indices.type()->ToString());
}
// Apply default options semantics.
int64_t output_length = permute_options->output_length;
if (output_length < 0) {
output_length = values.length();
}
// Internally invoke take(values, reverse_indices(indices)) to implement permute.
// For example, with
// values = [a, b, c, d, e, f, g]
// indices = [null, 0, 3, 2, 4, 1, 1]
// the reverse_indices(indices) is
// [1, 6, 3] if output_length = 3,
// [1, 6, 3, 2, 4, null, null] if output_length = 7.
// and take(values, reverse_indices(indices)) is
// [b, g, d] if output_length = 3,
// [b, g, d, c, e, null, null] if output_length = 7.
ReverseIndicesOptions reverse_indices_options{
output_length, InferSmallestReverseIndicesType(values.length())};
output_length,
// Use the smallest possible uint type to store reverse indices.
InferSmallestReverseIndicesType(values.length())};
ARROW_ASSIGN_OR_RAISE(
auto reverse_indices,
CallFunction("reverse_indices", {indices}, &reverse_indices_options, ctx));
Expand All @@ -351,14 +380,14 @@ class PermuteMetaFunction : public MetaFunction {

private:
static std::shared_ptr<DataType> InferSmallestReverseIndicesType(int64_t input_length) {
if (input_length <= std::numeric_limits<int8_t>::max()) {
return int8();
} else if (input_length <= std::numeric_limits<int16_t>::max()) {
return int16();
} else if (input_length <= std::numeric_limits<int32_t>::max()) {
return int32();
if (input_length <= std::numeric_limits<uint8_t>::max()) {
return uint8();
} else if (input_length <= std::numeric_limits<uint16_t>::max()) {
return uint16();
} else if (input_length <= std::numeric_limits<uint32_t>::max()) {
return uint32();
} else {
return int64();
return uint64();
}
}
};
Expand Down
3 changes: 2 additions & 1 deletion docs/source/cpp/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1922,4 +1922,5 @@ indices.

* \(3) For ``indices[i] = x``, ``output[x] = values[i]``. And ``output[x] = null``
if ``x`` does not appear in the input ``indices``. For ``indices[i] = x`` where
``x < 0`` or ``x >= output_length``, ``values[i]`` is ignored.
``x < 0`` or ``x >= output_length``, ``values[i]`` is ignored. If multiple indices
point to the same value, the last one is used.

0 comments on commit 520b952

Please sign in to comment.