diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 3d59b7627fd1b..a31f855a3f891 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -764,7 +764,7 @@ Result> ReverseIndices( /// /// For indices[i] = x, output[x] = values[i]. And output[x] = null if x does not appear /// in the input indices. For indices[i] = x where x < 0 or x >= output_length, values[i] -/// is ignored. +/// is ignored. If multiple indices point to the same value, the last one is used. /// /// For example, with values = [a, b, c, d, e, f, g] and indices = [null, 0, /// 3, 2, 4, 1, 1], the permutation is diff --git a/cpp/src/arrow/compute/kernels/vector_placement.cc b/cpp/src/arrow/compute/kernels/vector_placement.cc index 199fa02c0b61d..4ca1288e98140 100644 --- a/cpp/src/arrow/compute/kernels/vector_placement.cc +++ b/cpp/src/arrow/compute/kernels/vector_placement.cc @@ -40,21 +40,23 @@ const ReverseIndicesOptions* GetDefaultReverseIndicesOptions() { using ReverseIndicesState = OptionsWrapper; +/// Resolve the output type of reverse_indices. The output type is specified in the +/// options, and if null, set it to the input type. The output type must be integer. Result ResolveReverseIndicesOutputType( KernelContext* ctx, const std::vector& input_types) { DCHECK_EQ(input_types.size(), 1); DCHECK_NE(input_types[0], nullptr); - const DataType* output_type = ReverseIndicesState::Get(ctx).output_type.get(); + std::shared_ptr output_type = ReverseIndicesState::Get(ctx).output_type; if (!output_type) { - output_type = input_types[0].type; + output_type = input_types[0].owned_type; } if (!is_integer(output_type->id())) { return Status::Invalid("Output type of reverse_indices must be integer, got " + output_type->ToString()); } - return TypeHolder(output_type->GetSharedPtr()); + return TypeHolder(std::move(output_type)); } template @@ -69,6 +71,7 @@ struct ReverseIndicesImpl { const std::shared_ptr& input_type) { const auto& options = ReverseIndicesState::Get(ctx); + // Apply default options semantics. int64_t output_length = options.output_length; if (output_length < 0) { output_length = input_length; @@ -113,7 +116,16 @@ struct ReverseIndicesImpl { RETURN_NOT_OK(CheckInput(output_type)); - bool likely_many_nulls = IsLikelyManyNulls(); + // Dispatch the execution based on wether there are likely many nulls in the output. + // - If many nulls (i.e. the output is "sparse"), preallocate an all-false validity + // buffer and an uninitialized data buffer. The subsequent processing will fill the + // valid values only. + // - Otherwise (i.e. the output is "dense"), the validity buffer is lazily allocated + // and filled all-true in the subsequent processing only when needed. The data buffer + // is preallocated and filled with an "impossible" value (input_length - note that the + // range of reverse_indices is [0, input_length)) for the subsequent processing to + // detect validity. + bool likely_many_nulls = LikelyManyNulls(); if (likely_many_nulls) { RETURN_NOT_OK(AllocateValidityBufAndFill(false)); RETURN_NOT_OK(AllocateDataBuf(output_type)); @@ -142,7 +154,7 @@ struct ReverseIndicesImpl { return Status::OK(); } - bool IsLikelyManyNulls() { return output_length > 2 * input_length; } + bool LikelyManyNulls() { return output_length > 2 * input_length; } Status AllocateValidityBufAndFill(bool valid) { DCHECK_EQ(validity_buf, nullptr); @@ -195,6 +207,7 @@ struct ReverseIndicesImpl { if (ARROW_PREDICT_TRUE(index >= 0 && static_cast(index) < output_length)) { data[index] = static_cast(reverse_index); + // If many nulls, set validity to true for valid values. if constexpr (likely_many_nulls) { bit_util::SetBitTo(validity, index, true); } @@ -207,9 +220,12 @@ struct ReverseIndicesImpl { return Status::OK(); })); + // If not many nulls, run another pass iterating over the data to set the validity to + // false if the value is "impossible". The validity buffer is allocated and filled + // all-true on-the-fly when the first "impossible" value is seen. if constexpr (!likely_many_nulls) { for (int64_t i = 0; i < output_length; ++i) { - if (data[i] == static_cast(input_length)) { + if (ARROW_PREDICT_FALSE(data[i] == static_cast(input_length))) { if (ARROW_PREDICT_FALSE(!validity_buf)) { RETURN_NOT_OK(AllocateValidityBufAndFill(true)); validity = validity_buf->mutable_data_as(); @@ -240,12 +256,12 @@ struct ReverseIndices { } static Status Exec(KernelContext* ctx, const ExecSpan& span, ExecResult* result) { + DCHECK_EQ(span.num_values(), 1); DCHECK(span[0].is_array()); const auto& indices = span[0].array; ARROW_ASSIGN_OR_RAISE( - auto output, ReverseIndicesImpl::Exec(ctx, indices, indices.length, - indices.type->GetSharedPtr())); - result->value = std::move(output); + result->value, ReverseIndicesImpl::Exec(ctx, indices, indices.length, + indices.type->GetSharedPtr())); return Status::OK(); } }; @@ -268,12 +284,11 @@ struct ReverseIndicesChunked { } static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* result) { + DCHECK_EQ(batch.num_values(), 1); DCHECK(batch[0].is_chunked_array()); const auto& indices = batch[0].chunked_array(); - ARROW_ASSIGN_OR_RAISE( - auto output, ReverseIndicesImpl::Exec(ctx, indices, indices->length(), - indices->type())); - *result = std::move(output); + ARROW_ASSIGN_OR_RAISE(*result, ReverseIndicesImpl::Exec( + ctx, indices, indices->length(), indices->type())); return Status::OK(); } }; @@ -295,8 +310,8 @@ void RegisterVectorReverseIndices(FunctionRegistry* registry) { kernel.output_chunked = false; DCHECK_OK(function->AddKernel(std::move(kernel))); }; - for (const auto& ty : IntTypes()) { - add_kernel(ty->id()); + for (const auto& t : IntTypes()) { + add_kernel(t->id()); } DCHECK_OK(registry->AddFunction(std::move(function))); @@ -324,6 +339,7 @@ class PermuteMetaFunction : public MetaFunction { Result ExecuteImpl(const std::vector& args, const FunctionOptions* options, ExecContext* ctx) const override { + DCHECK_EQ(args.size(), 2); const auto& values = args[0]; const auto& indices = args[1]; auto* permute_options = checked_cast(options); @@ -336,12 +352,25 @@ class PermuteMetaFunction : public MetaFunction { return Status::Invalid("Indices of permute must be of integer type, got ", indices.type()->ToString()); } + // Apply default options semantics. int64_t output_length = permute_options->output_length; if (output_length < 0) { output_length = values.length(); } + // Internally invoke take(values, reverse_indices(indices)) to implement permute. + // For example, with + // values = [a, b, c, d, e, f, g] + // indices = [null, 0, 3, 2, 4, 1, 1] + // the reverse_indices(indices) is + // [1, 6, 3] if output_length = 3, + // [1, 6, 3, 2, 4, null, null] if output_length = 7. + // and take(values, reverse_indices(indices)) is + // [b, g, d] if output_length = 3, + // [b, g, d, c, e, null, null] if output_length = 7. ReverseIndicesOptions reverse_indices_options{ - output_length, InferSmallestReverseIndicesType(values.length())}; + output_length, + // Use the smallest possible uint type to store reverse indices. + InferSmallestReverseIndicesType(values.length())}; ARROW_ASSIGN_OR_RAISE( auto reverse_indices, CallFunction("reverse_indices", {indices}, &reverse_indices_options, ctx)); @@ -351,14 +380,14 @@ class PermuteMetaFunction : public MetaFunction { private: static std::shared_ptr InferSmallestReverseIndicesType(int64_t input_length) { - if (input_length <= std::numeric_limits::max()) { - return int8(); - } else if (input_length <= std::numeric_limits::max()) { - return int16(); - } else if (input_length <= std::numeric_limits::max()) { - return int32(); + if (input_length <= std::numeric_limits::max()) { + return uint8(); + } else if (input_length <= std::numeric_limits::max()) { + return uint16(); + } else if (input_length <= std::numeric_limits::max()) { + return uint32(); } else { - return int64(); + return uint64(); } } }; diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 17700c9412813..ccd605e121a89 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1922,4 +1922,5 @@ indices. * \(3) For ``indices[i] = x``, ``output[x] = values[i]``. And ``output[x] = null`` if ``x`` does not appear in the input ``indices``. For ``indices[i] = x`` where -``x < 0`` or ``x >= output_length``, ``values[i]`` is ignored. +``x < 0`` or ``x >= output_length``, ``values[i]`` is ignored. If multiple indices +point to the same value, the last one is used.