Skip to content

Commit

Permalink
1. Expose recursive flatten for logical lists on list_flatten kernel …
Browse files Browse the repository at this point in the history
…function

2. Support [Large]ListView for some kernel functions: list_flatten,list_value_length, list_element
3. Support recursive flatten for pyarrow bindinds and simplify [Large]ListView's pyarrow bindings
4. Refactor vector_nested_test.cc for better support [Large]ListView types.
  • Loading branch information
ZhangHuiGui committed Apr 18, 2024
1 parent 117460b commit c400b46
Show file tree
Hide file tree
Showing 13 changed files with 381 additions and 207 deletions.
7 changes: 7 additions & 0 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ static auto kRankOptionsType = GetFunctionOptionsType<RankOptions>(
DataMember("tiebreaker", &RankOptions::tiebreaker));
static auto kPairwiseOptionsType = GetFunctionOptionsType<PairwiseOptions>(
DataMember("periods", &PairwiseOptions::periods));
static auto kListFlattenOptionsType = GetFunctionOptionsType<ListFlattenOptions>(
DataMember("recursively", &ListFlattenOptions::recursively));
} // namespace
} // namespace internal

Expand Down Expand Up @@ -224,6 +226,10 @@ PairwiseOptions::PairwiseOptions(int64_t periods)
: FunctionOptions(internal::kPairwiseOptionsType), periods(periods) {}
constexpr char PairwiseOptions::kTypeName[];

ListFlattenOptions::ListFlattenOptions(bool recursively)
: FunctionOptions(internal::kListFlattenOptionsType), recursively(recursively) {}
constexpr char ListFlattenOptions::kTypeName[];

namespace internal {
void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType));
Expand All @@ -237,6 +243,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kCumulativeOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType));
}
} // namespace internal

Expand Down
12 changes: 12 additions & 0 deletions cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,18 @@ class ARROW_EXPORT PairwiseOptions : public FunctionOptions {
int64_t periods = 1;
};

/// \brief Options for list_flatten function
class ARROW_EXPORT ListFlattenOptions : public FunctionOptions {
public:
explicit ListFlattenOptions(bool recursively = false);
static constexpr char const kTypeName[] = "ListFlattenOptions";
static ListFlattenOptions Defaults() { return ListFlattenOptions(); }

/// Control the version of 'Flatten' that keeps recursively flattening
/// until an array of non-list values is reached.
bool recursively = false;
};

/// @}

/// \brief Filter with a boolean selection filter
Expand Down
10 changes: 8 additions & 2 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,14 @@ Result<TypeHolder> LastType(KernelContext*, const std::vector<TypeHolder>& types
}

Result<TypeHolder> ListValuesType(KernelContext*, const std::vector<TypeHolder>& args) {
const auto& list_type = checked_cast<const BaseListType&>(*args[0].type);
return list_type.value_type().get();
auto list_type = checked_cast<const BaseListType*>(args[0].type);
auto value_type = list_type->value_type().get();
for (auto value_kind = value_type->id();
is_list(value_kind) || is_list_view(value_kind); value_kind = value_type->id()) {
list_type = checked_cast<const BaseListType*>(list_type->value_type().get());
value_type = list_type->value_type().get();
}
return value_type;
}

void EnsureDictionaryDecoded(std::vector<TypeHolder>* types) {
Expand Down
49 changes: 39 additions & 10 deletions cpp/src/arrow/compute/kernels/scalar_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,26 @@ namespace {
template <typename Type, typename offset_type = typename Type::offset_type>
Status ListValueLength(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
const ArraySpan& arr = batch[0].array;
const auto kind = arr.type->id();
ArraySpan* out_arr = out->array_span_mutable();
auto out_values = out_arr->GetValues<offset_type>(1);
const offset_type* offsets = arr.GetValues<offset_type>(1);
// Offsets are always well-defined and monotonic, even for null values
for (int64_t i = 0; i < arr.length; ++i) {
*out_values++ = offsets[i + 1] - offsets[i];
if (is_list_view(kind)) {
// [Large]ListView's buffer layout:
// buffer1 : valid bitmap
// buffer2 : elements' start offset in current array
// buffer3 : elements' size
//
// It's unnecessary to calculate according offsets.
const auto* sizes = arr.GetValues<offset_type>(2);
for (int64_t i = 0; i < arr.length; i++) {
*out_values++ = sizes[i];
}
} else {
const offset_type* offsets = arr.GetValues<offset_type>(1);
// Offsets are always well-defined and monotonic, even for null values
for (int64_t i = 0; i < arr.length; ++i) {
*out_values++ = offsets[i + 1] - offsets[i];
}
}
return Status::OK();
}
Expand All @@ -59,6 +73,24 @@ Status FixedSizeListValueLength(KernelContext* ctx, const ExecSpan& batch,
return Status::OK();
}

template <typename InListType>
void AddListValueLengthKernel(ScalarFunction* func,
const std::shared_ptr<DataType>& out_type) {
auto in_type = {InputType(InListType::type_id)};
ScalarKernel kernel(in_type, out_type, ListValueLength<InListType>);
DCHECK_OK(func->AddKernel(std::move(kernel)));
}

void AddListValueLengthKernels(ScalarFunction* func) {
AddListValueLengthKernel<ListType>(func, int32());
AddListValueLengthKernel<LargeListType>(func, int64());
AddListValueLengthKernel<ListViewType>(func, int32());
AddListValueLengthKernel<LargeListViewType>(func, int64());

DCHECK_OK(func->AddKernel({InputType(Type::FIXED_SIZE_LIST)}, int32(),
FixedSizeListValueLength));
}

const FunctionDoc list_value_length_doc{
"Compute list lengths",
("`lists` must have a list-like type.\n"
Expand Down Expand Up @@ -399,6 +431,8 @@ void AddListElementKernels(ScalarFunction* func) {
void AddListElementKernels(ScalarFunction* func) {
AddListElementKernels<ListType, ListElement>(func);
AddListElementKernels<LargeListType, ListElement>(func);
AddListElementKernels<ListViewType, ListElement>(func);
AddListElementKernels<LargeListViewType, ListElement>(func);
AddListElementKernels<FixedSizeListType, FixedSizeListElement>(func);
}

Expand Down Expand Up @@ -824,12 +858,7 @@ const FunctionDoc map_lookup_doc{
void RegisterScalarNested(FunctionRegistry* registry) {
auto list_value_length = std::make_shared<ScalarFunction>(
"list_value_length", Arity::Unary(), list_value_length_doc);
DCHECK_OK(list_value_length->AddKernel({InputType(Type::LIST)}, int32(),
ListValueLength<ListType>));
DCHECK_OK(list_value_length->AddKernel({InputType(Type::FIXED_SIZE_LIST)}, int32(),
FixedSizeListValueLength));
DCHECK_OK(list_value_length->AddKernel({InputType(Type::LARGE_LIST)}, int64(),
ListValueLength<LargeListType>));
AddListValueLengthKernels(list_value_length.get());
DCHECK_OK(registry->AddFunction(std::move(list_value_length)));

auto list_element =
Expand Down
17 changes: 14 additions & 3 deletions cpp/src/arrow/compute/kernels/scalar_nested_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,21 @@ namespace arrow {
namespace compute {

static std::shared_ptr<DataType> GetOffsetType(const DataType& type) {
return type.id() == Type::LIST ? int32() : int64();
switch (type.id()) {
case Type::LIST:
case Type::LIST_VIEW:
return int32();
case Type::LARGE_LIST:
case Type::LARGE_LIST_VIEW:
return int64();
default:
Unreachable("Unexpected type");
}
}

TEST(TestScalarNested, ListValueLength) {
for (auto ty : {list(int32()), large_list(int32())}) {
for (auto ty : {list(int32()), large_list(int32()), list_view(int32()),
large_list_view(int32())}) {
CheckScalarUnary("list_value_length", ty, "[[0, null, 1], null, [2, 3], []]",
GetOffsetType(*ty), "[3, null, 2, 0]");
}
Expand All @@ -47,7 +57,8 @@ TEST(TestScalarNested, ListValueLength) {
TEST(TestScalarNested, ListElementNonFixedListWithNulls) {
auto sample = "[[7, 5, 81], [6, null, 4, 7, 8], [3, 12, 2, 0], [1, 9], null]";
for (auto ty : NumericTypes()) {
for (auto list_type : {list(ty), large_list(ty)}) {
for (auto list_type :
{list(ty), large_list(ty), list_view(ty), large_list_view(ty)}) {
auto input = ArrayFromJSON(list_type, sample);
auto null_input = ArrayFromJSON(list_type, "[null]");
for (auto index_type : IntTypes()) {
Expand Down
50 changes: 40 additions & 10 deletions cpp/src/arrow/compute/kernels/vector_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
// Vector kernels involving nested types

#include "arrow/array/array_base.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/kernels/common_internal.h"
#include "arrow/result.h"
#include "arrow/visit_type_inline.h"
Expand All @@ -29,8 +30,16 @@ namespace {

template <typename Type>
Status ListFlatten(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
auto recursively = OptionsWrapper<ListFlattenOptions>::Get(ctx).recursively;
typename TypeTraits<Type>::ArrayType list_array(batch[0].array.ToArrayData());
ARROW_ASSIGN_OR_RAISE(auto result, list_array.Flatten(ctx->memory_pool()));

std::shared_ptr<Array> result;
if (!recursively) {
ARROW_ASSIGN_OR_RAISE(result, list_array.Flatten(ctx->memory_pool()));
} else {
ARROW_ASSIGN_OR_RAISE(result, list_array.FlattenRecursively(ctx->memory_pool()));
}

out->value = std::move(result->data());
return Status::OK();
}
Expand Down Expand Up @@ -70,6 +79,10 @@ struct ListParentIndicesArray {

Status Visit(const LargeListType& type) { return VisitList(type); }

Status Visit(const ListViewType& type) { return VisitList(type); }

Status Visit(const LargeListViewType& type) { return VisitList(type); }

Status Visit(const FixedSizeListType& type) {
using offset_type = typename FixedSizeListType::offset_type;
const offset_type slot_length = type.list_size();
Expand Down Expand Up @@ -110,7 +123,7 @@ const FunctionDoc list_flatten_doc(
("`lists` must have a list-like type.\n"
"Return an array with the top list level flattened.\n"
"Top-level null values in `lists` do not emit anything in the input."),
{"lists"});
{"lists"}, "ListFlattenOptions");

const FunctionDoc list_parent_indices_doc(
"Compute parent indices of nested list values",
Expand Down Expand Up @@ -153,17 +166,34 @@ class ListParentIndicesFunction : public MetaFunction {
}
};

const ListFlattenOptions* GetDefaultListFlattenOptions() {
static const auto kDefaultListFlattenOptions = ListFlattenOptions::Defaults();
return &kDefaultListFlattenOptions;
}

template <typename InListType>
void AddBaseListFlattenKernels(VectorFunction* func) {
auto in_type = {InputType(InListType::type_id)};
auto out_type = OutputType(ListValuesType);
VectorKernel kernel(in_type, out_type, ListFlatten<InListType>,
OptionsWrapper<ListFlattenOptions>::Init);
DCHECK_OK(func->AddKernel(std::move(kernel)));
}

void AddBaseListFlattenKernels(VectorFunction* func) {
AddBaseListFlattenKernels<ListType>(func);
AddBaseListFlattenKernels<LargeListType>(func);
AddBaseListFlattenKernels<FixedSizeListType>(func);
AddBaseListFlattenKernels<ListViewType>(func);
AddBaseListFlattenKernels<LargeListViewType>(func);
}

} // namespace

void RegisterVectorNested(FunctionRegistry* registry) {
auto flatten =
std::make_shared<VectorFunction>("list_flatten", Arity::Unary(), list_flatten_doc);
DCHECK_OK(flatten->AddKernel({Type::LIST}, OutputType(ListValuesType),
ListFlatten<ListType>));
DCHECK_OK(flatten->AddKernel({Type::FIXED_SIZE_LIST}, OutputType(ListValuesType),
ListFlatten<FixedSizeListType>));
DCHECK_OK(flatten->AddKernel({Type::LARGE_LIST}, OutputType(ListValuesType),
ListFlatten<LargeListType>));
auto flatten = std::make_shared<VectorFunction>(
"list_flatten", Arity::Unary(), list_flatten_doc, GetDefaultListFlattenOptions());
AddBaseListFlattenKernels(flatten.get());
DCHECK_OK(registry->AddFunction(std::move(flatten)));

DCHECK_OK(registry->AddFunction(std::make_shared<ListParentIndicesFunction>()));
Expand Down
Loading

0 comments on commit c400b46

Please sign in to comment.