diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 715373fad8e..ac1b570534f 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -356,10 +356,10 @@ if(ARROW_COMPUTE) compute/kernels/scalar_set_lookup.cc compute/kernels/scalar_string.cc compute/kernels/scalar_validity.cc - compute/kernels/vector_filter.cc + compute/kernels/util_internal.cc compute/kernels/vector_hash.cc - compute/kernels/vector_sort.cc - compute/kernels/vector_take.cc) + compute/kernels/vector_selection.cc + compute/kernels/vector_sort.cc) endif() if(ARROW_FILESYSTEM) diff --git a/cpp/src/arrow/array/array_binary.h b/cpp/src/arrow/array/array_binary.h index b291de3ab72..c54e504b6e8 100644 --- a/cpp/src/arrow/array/array_binary.h +++ b/cpp/src/arrow/array/array_binary.h @@ -85,6 +85,8 @@ class BaseBinaryArray : public FlatArray { return raw_value_offsets_ + data_->offset; } + const uint8_t* raw_data() const { return raw_data_; } + /// \brief Return the data buffer absolute offset of the data for the value /// at the passed index. /// diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index d0d67e76fb2..dd9c43910f1 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -21,8 +21,8 @@ #include #include +#include "arrow/array/builder_primitive.h" #include "arrow/compute/exec.h" -#include "arrow/compute/kernels/vector_selection_internal.h" #include "arrow/datum.h" #include "arrow/record_batch.h" #include "arrow/result.h" @@ -65,6 +65,9 @@ Result> ValueCounts(const Datum& value, ExecContext* ctx) return result.make_array(); } +// ---------------------------------------------------------------------- +// Filter- and take-related selection functions + Result Filter(const Datum& values, const Datum& filter, const FilterOptions& options, ExecContext* ctx) { // Invoke metafunction which deals with Datum kinds other than just Array, diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 48f9102eb42..166bc1096d6 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -38,7 +38,10 @@ struct FilterOptions : public FunctionOptions { EMIT_NULL, }; - static FilterOptions Defaults() { return FilterOptions{}; } + explicit FilterOptions(NullSelectionBehavior null_selection = DROP) + : null_selection_behavior(null_selection) {} + + static FilterOptions Defaults() { return FilterOptions(); } NullSelectionBehavior null_selection_behavior = DROP; }; @@ -64,6 +67,24 @@ Result Filter(const Datum& values, const Datum& filter, const FilterOptions& options = FilterOptions::Defaults(), ExecContext* ctx = NULLPTR); +namespace internal { + +// These internal functions are implemented in kernels/vector_selection.cc + +/// \brief Return the number of selected indices in the boolean filter +ARROW_EXPORT +int64_t GetFilterOutputSize(const ArrayData& filter, + FilterOptions::NullSelectionBehavior null_selection); + +/// \brief Compute uint64 selection indices for use with Take given a boolean +/// filter +ARROW_EXPORT +Result> GetTakeIndices( + const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection, + MemoryPool* memory_pool = default_memory_pool()); + +} // namespace internal + struct ARROW_EXPORT TakeOptions : public FunctionOptions { explicit TakeOptions(bool boundscheck = true) : boundscheck(boundscheck) {} diff --git a/cpp/src/arrow/compute/benchmark_util.h b/cpp/src/arrow/compute/benchmark_util.h index 1259d1b5468..edd2007c2b2 100644 --- a/cpp/src/arrow/compute/benchmark_util.h +++ b/cpp/src/arrow/compute/benchmark_util.h @@ -24,9 +24,11 @@ #include "arrow/util/cpu_info.h" namespace arrow { -namespace compute { using internal::CpuInfo; + +namespace compute { + static CpuInfo* cpu_info = CpuInfo::GetInstance(); static const int64_t kL1Size = cpu_info->CacheSize(CpuInfo::L1_CACHE); diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 9ff0d0973fd..0082799b212 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -38,9 +38,8 @@ add_arrow_benchmark(scalar_string_benchmark PREFIX "arrow-compute") add_arrow_compute_test(vector_test SOURCES - vector_filter_test.cc vector_hash_test.cc - vector_take_test.cc + vector_selection_test.cc vector_sort_test.cc test_util.cc) diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index 2f2159ef642..4b64244bdce 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -97,7 +97,7 @@ std::string MakeArray(Elements... elements) { std::copy(elements_as_strings.begin(), elements_as_strings.end(), elements_as_views.begin()); - return "[" + internal::JoinStrings(elements_as_views, ",") + "]"; + return "[" + ::arrow::internal::JoinStrings(elements_as_views, ",") + "]"; } template diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index ace8759d837..4970c830f2b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -46,10 +46,11 @@ #include "arrow/compute/kernels/test_util.h" namespace arrow { -namespace compute { using internal::checked_cast; +namespace compute { + static constexpr const char* kInvalidUtf8 = "\xa0\xa1"; static std::vector> kNumericTypes = { diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 758e10b60d8..8bedb965686 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -36,6 +36,9 @@ #include "arrow/util/checked_cast.h" namespace arrow { + +using internal::BitmapReader; + namespace compute { using util::string_view; @@ -115,8 +118,8 @@ Datum SimpleScalarArrayCompare(CompareOptions options, const Datum& lhs, ArrayFromVector(bitmap, &result); } else { std::vector null_bitmap(array->length()); - auto reader = internal::BitmapReader(array->null_bitmap_data(), array->offset(), - array->length()); + auto reader = + BitmapReader(array->null_bitmap_data(), array->offset(), array->length()); for (int64_t i = 0; i < array->length(); i++, reader.Next()) { null_bitmap[i] = reader.IsSet(); } @@ -146,8 +149,8 @@ Datum SimpleScalarArrayCompare(CompareOptions options, const Datum& ArrayFromVector(bitmap, &result); } else { std::vector null_bitmap(array->length()); - auto reader = internal::BitmapReader(array->null_bitmap_data(), array->offset(), - array->length()); + auto reader = + BitmapReader(array->null_bitmap_data(), array->offset(), array->length()); for (int64_t i = 0; i < array->length(); i++, reader.Next()) { null_bitmap[i] = reader.IsSet(); } diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index c4e1f07075e..f01f5c7cc6b 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -39,6 +39,9 @@ // IWYU pragma: end_exports namespace arrow { + +using internal::checked_cast; + namespace compute { template @@ -65,8 +68,8 @@ struct DatumEqual> { static void EnsureEqual(const Datum& lhs, const Datum& rhs) { ASSERT_EQ(lhs.kind(), rhs.kind()); if (lhs.kind() == Datum::SCALAR) { - auto left = internal::checked_cast(lhs.scalar().get()); - auto right = internal::checked_cast(rhs.scalar().get()); + auto left = checked_cast(lhs.scalar().get()); + auto right = checked_cast(rhs.scalar().get()); ASSERT_EQ(left->is_valid, right->is_valid); ASSERT_EQ(left->type->id(), right->type->id()); ASSERT_NEAR(left->value, right->value, kArbitraryDoubleErrorBound); @@ -80,8 +83,8 @@ struct DatumEqual> { static void EnsureEqual(const Datum& lhs, const Datum& rhs) { ASSERT_EQ(lhs.kind(), rhs.kind()); if (lhs.kind() == Datum::SCALAR) { - auto left = internal::checked_cast(lhs.scalar().get()); - auto right = internal::checked_cast(rhs.scalar().get()); + auto left = checked_cast(lhs.scalar().get()); + auto right = checked_cast(rhs.scalar().get()); ASSERT_EQ(*left, *right); } } diff --git a/cpp/src/arrow/compute/kernels/util_internal.cc b/cpp/src/arrow/compute/kernels/util_internal.cc new file mode 100644 index 00000000000..32c6317a104 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/util_internal.cc @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/kernels/util_internal.h" + +#include + +#include "arrow/array/data.h" +#include "arrow/type.h" +#include "arrow/util/checked_cast.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute { +namespace internal { + +const uint8_t* GetValidityBitmap(const ArrayData& data) { + const uint8_t* bitmap = nullptr; + if (data.buffers[0]) { + bitmap = data.buffers[0]->data(); + } + return bitmap; +} + +int GetBitWidth(const DataType& type) { + return checked_cast(type).bit_width(); +} + +PrimitiveArg GetPrimitiveArg(const ArrayData& arr) { + PrimitiveArg arg; + arg.is_valid = GetValidityBitmap(arr); + arg.data = arr.buffers[1]->data(); + arg.bit_width = GetBitWidth(*arr.type); + arg.offset = arr.offset; + arg.length = arr.length; + if (arg.bit_width > 1) { + arg.data += arr.offset * arg.bit_width / 8; + } + // This may be kUnknownNullCount + arg.null_count = arr.null_count.load(); + return arg; +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/util_internal.h b/cpp/src/arrow/compute/kernels/util_internal.h new file mode 100644 index 00000000000..7ab59965752 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/util_internal.h @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/buffer.h" + +namespace arrow { +namespace compute { +namespace internal { + +// An internal data structure for unpacking a primitive argument to pass to a +// kernel implementation +struct PrimitiveArg { + const uint8_t* is_valid; + // If the bit_width is a multiple of 8 (i.e. not boolean), then "data" should + // be shifted by offset * (bit_width / 8). For bit-packed data, the offset + // must be used when indexing. + const uint8_t* data; + int bit_width; + int64_t length; + int64_t offset; + // This may be kUnknownNullCount if the null_count has not yet been computed, + // so use null_count != 0 to determine "may have nulls". + int64_t null_count; +}; + +// Get validity bitmap data or return nullptr if there is no validity buffer +const uint8_t* GetValidityBitmap(const ArrayData& data); + +int GetBitWidth(const DataType& type); + +// Reduce code size by dealing with the unboxing of the kernel inputs once +// rather than duplicating compiled code to do all these in each kernel. +PrimitiveArg GetPrimitiveArg(const ArrayData& arr); + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_filter.cc b/cpp/src/arrow/compute/kernels/vector_filter.cc deleted file mode 100644 index db21d402e35..00000000000 --- a/cpp/src/arrow/compute/kernels/vector_filter.cc +++ /dev/null @@ -1,248 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/array/array_base.h" -#include "arrow/array/array_primitive.h" -#include "arrow/compute/api_vector.h" -#include "arrow/compute/kernels/common.h" -#include "arrow/compute/kernels/vector_selection_internal.h" -#include "arrow/record_batch.h" -#include "arrow/result.h" -#include "arrow/visitor_inline.h" - -namespace arrow { -namespace compute { -namespace internal { - -// IndexSequence which yields the indices of positions in a BooleanArray -// which are either null or true -template -class FilterIndexSequence { - public: - // constexpr so we'll never instantiate bounds checking - constexpr bool never_out_of_bounds() const { return true; } - void set_never_out_of_bounds() {} - - constexpr FilterIndexSequence() = default; - - FilterIndexSequence(const BooleanArray& filter, int64_t out_length) - : filter_(&filter), out_length_(out_length) {} - - std::pair Next() { - if (NullSelectionBehavior == FilterOptions::DROP) { - // skip until an index is found at which the filter is true - while (filter_->IsNull(index_) || !filter_->Value(index_)) { - ++index_; - } - return std::make_pair(index_++, true); - } - - // skip until an index is found at which the filter is either null or true - while (filter_->IsValid(index_) && !filter_->Value(index_)) { - ++index_; - } - bool is_valid = filter_->IsValid(index_); - return std::make_pair(index_++, is_valid); - } - - int64_t length() const { return out_length_; } - - int64_t null_count() const { - if (NullSelectionBehavior == FilterOptions::DROP) { - return 0; - } - return filter_->null_count(); - } - - private: - const BooleanArray* filter_ = nullptr; - int64_t index_ = 0, out_length_ = -1; -}; - -int64_t FilterOutputSize(FilterOptions::NullSelectionBehavior null_selection, - const Array& arr) { - const auto& filter = checked_cast(arr); - // TODO(bkietz) this can be optimized. Use Bitmap::VisitWords - int64_t size = 0; - if (null_selection == FilterOptions::EMIT_NULL) { - for (auto i = 0; i < filter.length(); ++i) { - if (filter.IsNull(i) || filter.Value(i)) { - ++size; - } - } - } else { - for (auto i = 0; i < filter.length(); ++i) { - if (filter.IsValid(i) && filter.Value(i)) { - ++size; - } - } - } - return size; -} - -struct FilterState : public KernelState { - explicit FilterState(FilterOptions options) : options(std::move(options)) {} - FilterOptions options; -}; - -std::unique_ptr InitFilter(KernelContext*, const KernelInitArgs& args) { - FilterOptions options; - if (args.options == nullptr) { - options = FilterOptions::Defaults(); - } else { - options = *static_cast(args.options); - } - return std::unique_ptr(new FilterState(std::move(options))); -} - -template -struct FilterFunctor { - using ArrayType = typename TypeTraits::ArrayType; - - template - static void ExecImpl(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - using IS = FilterIndexSequence; - ArrayType values(batch[0].array()); - BooleanArray filter(batch[1].array()); - const int64_t output_size = FilterOutputSize(NullSelection, filter); - std::shared_ptr result; - KERNEL_RETURN_IF_ERROR(ctx, Select(ctx, values, IS(filter, output_size), &result)); - out->value = result->data(); - } - - static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const auto& state = checked_cast(*ctx->state()); - if (state.options.null_selection_behavior == FilterOptions::EMIT_NULL) { - ExecImpl(ctx, batch, out); - } else { - ExecImpl(ctx, batch, out); - } - } -}; - -struct FilterKernelVisitor { - template - Status Visit(const Type&) { - this->result = FilterFunctor::Exec; - return Status::OK(); - } - - Status Create(const DataType& type) { return VisitTypeInline(type, this); } - ArrayKernelExec result; -}; - -Status GetFilterKernel(const DataType& type, ArrayKernelExec* exec) { - FilterKernelVisitor visitor; - RETURN_NOT_OK(visitor.Create(type)); - *exec = visitor.result; - return Status::OK(); -} - -Result> FilterRecordBatch(const RecordBatch& batch, - const Datum& filter, - const FunctionOptions* options, - ExecContext* ctx) { - if (!filter.is_array()) { - return Status::Invalid("Cannot filter a RecordBatch with a filter of kind ", - filter.kind()); - } - - const auto& filter_opts = *static_cast(options); - // TODO: Rewrite this to convert to selection vector and use Take - std::vector> columns(batch.num_columns()); - for (int i = 0; i < batch.num_columns(); ++i) { - ARROW_ASSIGN_OR_RAISE(Datum out, - Filter(batch.column(i)->data(), filter, filter_opts, ctx)); - columns[i] = out.make_array(); - } - - int64_t out_length; - if (columns.size() == 0) { - out_length = - FilterOutputSize(filter_opts.null_selection_behavior, *filter.make_array()); - } else { - out_length = columns[0]->length(); - } - return RecordBatch::Make(batch.schema(), out_length, columns); -} - -Result> FilterTable(const Table& table, const Datum& filter, - const FunctionOptions* options, - ExecContext* ctx) { - auto new_columns = table.columns(); - for (auto& column : new_columns) { - ARROW_ASSIGN_OR_RAISE( - Datum out_column, - Filter(column, filter, *static_cast(options), ctx)); - column = out_column.chunked_array(); - } - return Table::Make(table.schema(), std::move(new_columns)); -} - -class FilterMetaFunction : public MetaFunction { - public: - FilterMetaFunction() : MetaFunction("filter", Arity::Binary()) {} - - Result ExecuteImpl(const std::vector& args, - const FunctionOptions* options, - ExecContext* ctx) const override { - if (args[0].kind() == Datum::RECORD_BATCH) { - auto values_batch = args[0].record_batch(); - ARROW_ASSIGN_OR_RAISE( - std::shared_ptr out_batch, - FilterRecordBatch(*args[0].record_batch(), args[1], options, ctx)); - return Datum(out_batch); - } else if (args[0].kind() == Datum::TABLE) { - ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_table, - FilterTable(*args[0].table(), args[1], options, ctx)); - return Datum(out_table); - } else { - return CallFunction("array_filter", args, options, ctx); - } - } -}; - -void RegisterVectorFilter(FunctionRegistry* registry) { - VectorKernel base; - base.init = InitFilter; - - auto filter = std::make_shared("array_filter", Arity::Binary()); - InputType filter_ty = InputType::Array(boolean()); - OutputType out_ty(FirstType); - - auto AddKernel = [&](InputType in_ty, const DataType& example_type) { - base.signature = KernelSignature::Make({in_ty, filter_ty}, out_ty); - DCHECK_OK(GetFilterKernel(example_type, &base.exec)); - DCHECK_OK(filter->AddKernel(base)); - }; - - for (const auto& value_ty : PrimitiveTypes()) { - AddKernel(InputType::Array(value_ty), *value_ty); - } - // Other types where we may only on the DataType::id - for (const auto& value_ty : ExampleParametricTypes()) { - AddKernel(InputType::Array(value_ty->id()), *value_ty); - } - DCHECK_OK(registry->AddFunction(std::move(filter))); - - // Add filter metafunction - DCHECK_OK(registry->AddFunction(std::make_shared())); -} - -} // namespace internal -} // namespace compute -} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_filter_test.cc b/cpp/src/arrow/compute/kernels/vector_filter_test.cc deleted file mode 100644 index 327789127db..00000000000 --- a/cpp/src/arrow/compute/kernels/vector_filter_test.cc +++ /dev/null @@ -1,721 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include -#include -#include - -#include "arrow/compute/api.h" -#include "arrow/compute/kernels/test_util.h" -#include "arrow/table.h" -#include "arrow/testing/gtest_common.h" -#include "arrow/testing/gtest_util.h" -#include "arrow/testing/random.h" -#include "arrow/testing/util.h" - -namespace arrow { -namespace compute { - -using internal::checked_pointer_cast; -using util::string_view; - -std::shared_ptr CoalesceNullToFalse(std::shared_ptr filter) { - if (filter->null_count() == 0) { - return filter; - } - const auto& data = *filter->data(); - auto is_true = std::make_shared(data.length, data.buffers[1]); - auto is_valid = std::make_shared(data.length, data.buffers[0]); - EXPECT_OK_AND_ASSIGN(Datum out_datum, arrow::compute::And(is_true, is_valid)); - return out_datum.make_array(); -} - -template -class TestFilterKernel : public ::testing::Test { - protected: - TestFilterKernel() { - emit_null_.null_selection_behavior = FilterOptions::EMIT_NULL; - drop_.null_selection_behavior = FilterOptions::DROP; - } - - void AssertFilter(std::shared_ptr values, std::shared_ptr filter, - std::shared_ptr expected) { - // test with EMIT_NULL - ASSERT_OK_AND_ASSIGN(Datum out_datum, - arrow::compute::Filter(values, filter, emit_null_)); - auto actual = out_datum.make_array(); - ASSERT_OK(actual->ValidateFull()); - AssertArraysEqual(*expected, *actual); - - // test with DROP using EMIT_NULL and a coalesced filter - auto coalesced_filter = CoalesceNullToFalse(filter); - ASSERT_OK_AND_ASSIGN(out_datum, - arrow::compute::Filter(values, coalesced_filter, emit_null_)); - expected = out_datum.make_array(); - ASSERT_OK_AND_ASSIGN(out_datum, arrow::compute::Filter(values, filter, drop_)); - actual = out_datum.make_array(); - AssertArraysEqual(*expected, *actual); - } - - void AssertFilter(std::shared_ptr type, const std::string& values, - const std::string& filter, const std::string& expected) { - AssertFilter(ArrayFromJSON(type, values), ArrayFromJSON(boolean(), filter), - ArrayFromJSON(type, expected)); - } - - void ValidateFilter(const std::shared_ptr& values, - const std::shared_ptr& filter_boxed) { - ASSERT_OK_AND_ASSIGN(Datum out_datum, - arrow::compute::Filter(values, filter_boxed, emit_null_)); - auto filtered_emit_null = out_datum.make_array(); - ASSERT_OK(filtered_emit_null->ValidateFull()); - - ASSERT_OK_AND_ASSIGN(out_datum, arrow::compute::Filter(values, filter_boxed, drop_)); - auto filtered_drop = out_datum.make_array(); - ASSERT_OK(filtered_drop->ValidateFull()); - - auto filter = checked_pointer_cast(filter_boxed); - int64_t values_i = 0, emit_null_i = 0, drop_i = 0; - for (; values_i < values->length(); ++values_i, ++emit_null_i, ++drop_i) { - if (filter->IsNull(values_i)) { - ASSERT_LT(emit_null_i, filtered_emit_null->length()); - ASSERT_TRUE(filtered_emit_null->IsNull(emit_null_i)); - // this element was (null) filtered out; don't examine filtered_drop - --drop_i; - continue; - } - if (!filter->Value(values_i)) { - // this element was filtered out; don't examine filtered_emit_null - --emit_null_i; - --drop_i; - continue; - } - ASSERT_LT(emit_null_i, filtered_emit_null->length()); - ASSERT_LT(drop_i, filtered_drop->length()); - ASSERT_TRUE( - values->RangeEquals(values_i, values_i + 1, emit_null_i, filtered_emit_null)); - ASSERT_TRUE(values->RangeEquals(values_i, values_i + 1, drop_i, filtered_drop)); - } - ASSERT_EQ(emit_null_i, filtered_emit_null->length()); - ASSERT_EQ(drop_i, filtered_drop->length()); - } - - FilterOptions emit_null_, drop_; -}; - -class TestFilterKernelWithNull : public TestFilterKernel { - protected: - void AssertFilter(const std::string& values, const std::string& filter, - const std::string& expected) { - TestFilterKernel::AssertFilter(ArrayFromJSON(null(), values), - ArrayFromJSON(boolean(), filter), - ArrayFromJSON(null(), expected)); - } -}; - -TEST_F(TestFilterKernelWithNull, FilterNull) { - this->AssertFilter("[]", "[]", "[]"); - - this->AssertFilter("[null, null, null]", "[0, 1, 0]", "[null]"); - this->AssertFilter("[null, null, null]", "[1, 1, 0]", "[null, null]"); -} - -class TestFilterKernelWithBoolean : public TestFilterKernel { - protected: - void AssertFilter(const std::string& values, const std::string& filter, - const std::string& expected) { - TestFilterKernel::AssertFilter(ArrayFromJSON(boolean(), values), - ArrayFromJSON(boolean(), filter), - ArrayFromJSON(boolean(), expected)); - } -}; - -TEST_F(TestFilterKernelWithBoolean, FilterBoolean) { - this->AssertFilter("[]", "[]", "[]"); - - this->AssertFilter("[true, false, true]", "[0, 1, 0]", "[false]"); - this->AssertFilter("[null, false, true]", "[0, 1, 0]", "[false]"); - this->AssertFilter("[true, false, true]", "[null, 1, 0]", "[null, false]"); -} - -template -class TestFilterKernelWithNumeric : public TestFilterKernel { - protected: - std::shared_ptr type_singleton() { - return TypeTraits::type_singleton(); - } -}; - -TYPED_TEST_SUITE(TestFilterKernelWithNumeric, NumericArrowTypes); -TYPED_TEST(TestFilterKernelWithNumeric, FilterNumeric) { - auto type = this->type_singleton(); - this->AssertFilter(type, "[]", "[]", "[]"); - - this->AssertFilter(type, "[9]", "[0]", "[]"); - this->AssertFilter(type, "[9]", "[1]", "[9]"); - this->AssertFilter(type, "[9]", "[null]", "[null]"); - this->AssertFilter(type, "[null]", "[0]", "[]"); - this->AssertFilter(type, "[null]", "[1]", "[null]"); - this->AssertFilter(type, "[null]", "[null]", "[null]"); - - this->AssertFilter(type, "[7, 8, 9]", "[0, 1, 0]", "[8]"); - this->AssertFilter(type, "[7, 8, 9]", "[1, 0, 1]", "[7, 9]"); - this->AssertFilter(type, "[null, 8, 9]", "[0, 1, 0]", "[8]"); - this->AssertFilter(type, "[7, 8, 9]", "[null, 1, 0]", "[null, 8]"); - this->AssertFilter(type, "[7, 8, 9]", "[1, null, 1]", "[7, null, 9]"); - - this->AssertFilter(ArrayFromJSON(type, "[7, 8, 9]"), - ArrayFromJSON(boolean(), "[0, 1, 1, 1, 0, 1]")->Slice(3, 3), - ArrayFromJSON(type, "[7, 9]")); - - ASSERT_RAISES(Invalid, - arrow::compute::Filter(ArrayFromJSON(type, "[7, 8, 9]"), - ArrayFromJSON(boolean(), "[]"), this->emit_null_)); - ASSERT_RAISES(Invalid, - arrow::compute::Filter(ArrayFromJSON(type, "[7, 8, 9]"), - ArrayFromJSON(boolean(), "[]"), this->drop_)); -} - -TYPED_TEST(TestFilterKernelWithNumeric, FilterRandomNumeric) { - auto rand = random::RandomArrayGenerator(kRandomSeed); - for (size_t i = 3; i < 10; i++) { - const int64_t length = static_cast(1ULL << i); - for (auto null_probability : {0.0, 0.01, 0.25, 1.0}) { - for (auto filter_probability : {0.0, 0.1, 0.5, 1.0}) { - auto values = rand.Numeric(length, 0, 127, null_probability); - auto filter = rand.Boolean(length, filter_probability, null_probability); - this->ValidateFilter(values, filter); - } - } - } -} - -template -using Comparator = bool(CType, CType); - -template -Comparator* GetComparator(CompareOperator op) { - static Comparator* cmp[] = { - // EQUAL - [](CType l, CType r) { return l == r; }, - // NOT_EQUAL - [](CType l, CType r) { return l != r; }, - // GREATER - [](CType l, CType r) { return l > r; }, - // GREATER_EQUAL - [](CType l, CType r) { return l >= r; }, - // LESS - [](CType l, CType r) { return l < r; }, - // LESS_EQUAL - [](CType l, CType r) { return l <= r; }, - }; - return cmp[op]; -} - -template ::CType> -std::shared_ptr CompareAndFilter(const CType* data, int64_t length, Fn&& fn) { - std::vector filtered; - filtered.reserve(length); - std::copy_if(data, data + length, std::back_inserter(filtered), std::forward(fn)); - std::shared_ptr filtered_array; - ArrayFromVector(filtered, &filtered_array); - return filtered_array; -} - -template ::CType> -std::shared_ptr CompareAndFilter(const CType* data, int64_t length, CType val, - CompareOperator op) { - auto cmp = GetComparator(op); - return CompareAndFilter(data, length, [&](CType e) { return cmp(e, val); }); -} - -template ::CType> -std::shared_ptr CompareAndFilter(const CType* data, int64_t length, - const CType* other, CompareOperator op) { - auto cmp = GetComparator(op); - return CompareAndFilter(data, length, [&](CType e) { return cmp(e, *other++); }); -} - -TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) { - using ScalarType = typename TypeTraits::ScalarType; - using ArrayType = typename TypeTraits::ArrayType; - using CType = typename TypeTraits::CType; - - auto rand = random::RandomArrayGenerator(kRandomSeed); - for (size_t i = 3; i < 10; i++) { - const int64_t length = static_cast(1ULL << i); - // TODO(bkietz) rewrite with some nulls - auto array = - checked_pointer_cast(rand.Numeric(length, 0, 100, 0)); - CType c_fifty = 50; - auto fifty = std::make_shared(c_fifty); - for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { - ASSERT_OK_AND_ASSIGN(Datum selection, arrow::compute::Compare(array, Datum(fifty), - CompareOptions(op))); - ASSERT_OK_AND_ASSIGN(Datum filtered, arrow::compute::Filter(array, selection, {})); - auto filtered_array = filtered.make_array(); - ASSERT_OK(filtered_array->ValidateFull()); - auto expected = - CompareAndFilter(array->raw_values(), array->length(), c_fifty, op); - ASSERT_ARRAYS_EQUAL(*filtered_array, *expected); - } - } -} - -TYPED_TEST(TestFilterKernelWithNumeric, CompareArrayAndFilterRandomNumeric) { - using ArrayType = typename TypeTraits::ArrayType; - - auto rand = random::RandomArrayGenerator(kRandomSeed); - for (size_t i = 3; i < 10; i++) { - const int64_t length = static_cast(1ULL << i); - auto lhs = checked_pointer_cast( - rand.Numeric(length, 0, 100, /*null_probability=*/0.0)); - auto rhs = checked_pointer_cast( - rand.Numeric(length, 0, 100, /*null_probability=*/0.0)); - for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { - ASSERT_OK_AND_ASSIGN(Datum selection, - arrow::compute::Compare(lhs, rhs, CompareOptions(op))); - ASSERT_OK_AND_ASSIGN(Datum filtered, arrow::compute::Filter(lhs, selection, {})); - auto filtered_array = filtered.make_array(); - ASSERT_OK(filtered_array->ValidateFull()); - auto expected = CompareAndFilter(lhs->raw_values(), lhs->length(), - rhs->raw_values(), op); - ASSERT_ARRAYS_EQUAL(*filtered_array, *expected); - } - } -} - -TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) { - using ScalarType = typename TypeTraits::ScalarType; - using ArrayType = typename TypeTraits::ArrayType; - using CType = typename TypeTraits::CType; - - auto rand = random::RandomArrayGenerator(kRandomSeed); - for (size_t i = 3; i < 10; i++) { - const int64_t length = static_cast(1ULL << i); - auto array = checked_pointer_cast( - rand.Numeric(length, 0, 100, /*null_probability=*/0.0)); - CType c_fifty = 50, c_hundred = 100; - auto fifty = std::make_shared(c_fifty); - auto hundred = std::make_shared(c_hundred); - ASSERT_OK_AND_ASSIGN( - Datum greater_than_fifty, - arrow::compute::Compare(array, Datum(fifty), CompareOptions(GREATER))); - ASSERT_OK_AND_ASSIGN( - Datum less_than_hundred, - arrow::compute::Compare(array, Datum(hundred), CompareOptions(LESS))); - ASSERT_OK_AND_ASSIGN(Datum selection, - arrow::compute::And(greater_than_fifty, less_than_hundred)); - ASSERT_OK_AND_ASSIGN(Datum filtered, arrow::compute::Filter(array, selection, {})); - auto filtered_array = filtered.make_array(); - ASSERT_OK(filtered_array->ValidateFull()); - auto expected = CompareAndFilter( - array->raw_values(), array->length(), - [&](CType e) { return (e > c_fifty) && (e < c_hundred); }); - ASSERT_ARRAYS_EQUAL(*filtered_array, *expected); - } -} - -using StringTypes = - ::testing::Types; - -template -class TestFilterKernelWithString : public TestFilterKernel { - protected: - std::shared_ptr value_type() { - return TypeTraits::type_singleton(); - } - - void AssertFilter(const std::string& values, const std::string& filter, - const std::string& expected) { - TestFilterKernel::AssertFilter(ArrayFromJSON(value_type(), values), - ArrayFromJSON(boolean(), filter), - ArrayFromJSON(value_type(), expected)); - } - - void AssertFilterDictionary(const std::string& dictionary_values, - const std::string& dictionary_filter, - const std::string& filter, - const std::string& expected_filter) { - auto dict = ArrayFromJSON(value_type(), dictionary_values); - auto type = dictionary(int8(), value_type()); - ASSERT_OK_AND_ASSIGN(auto values, - DictionaryArray::FromArrays( - type, ArrayFromJSON(int8(), dictionary_filter), dict)); - ASSERT_OK_AND_ASSIGN( - auto expected, - DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_filter), dict)); - auto take_filter = ArrayFromJSON(boolean(), filter); - TestFilterKernel::AssertFilter(values, take_filter, expected); - } -}; - -TYPED_TEST_SUITE(TestFilterKernelWithString, StringTypes); - -TYPED_TEST(TestFilterKernelWithString, FilterString) { - this->AssertFilter(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["b"])"); - this->AssertFilter(R"([null, "b", "c"])", "[0, 1, 0]", R"(["b"])"); - this->AssertFilter(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b"])"); -} - -TYPED_TEST(TestFilterKernelWithString, FilterDictionary) { - auto dict = R"(["a", "b", "c", "d", "e"])"; - this->AssertFilterDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[4]"); - this->AssertFilterDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[4]"); - this->AssertFilterDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4]"); -} - -class TestFilterKernelWithList : public TestFilterKernel { - public: -}; - -TEST_F(TestFilterKernelWithList, FilterListInt32) { - std::string list_json = "[[], [1,2], null, [3]]"; - this->AssertFilter(list(int32()), list_json, "[0, 0, 0, 0]", "[]"); - this->AssertFilter(list(int32()), list_json, "[0, 1, 1, null]", "[[1,2], null, null]"); - this->AssertFilter(list(int32()), list_json, "[0, 0, 1, null]", "[null, null]"); - this->AssertFilter(list(int32()), list_json, "[1, 0, 0, 1]", "[[], [3]]"); - this->AssertFilter(list(int32()), list_json, "[1, 1, 1, 1]", list_json); - this->AssertFilter(list(int32()), list_json, "[0, 1, 0, 1]", "[[1,2], [3]]"); -} - -TEST_F(TestFilterKernelWithList, FilterListListInt32) { - std::string list_json = R"([ - [], - [[1], [2, null, 2], []], - null, - [[3, null], null] - ])"; - auto type = list(list(int32())); - this->AssertFilter(type, list_json, "[0, 0, 0, 0]", "[]"); - this->AssertFilter(type, list_json, "[0, 1, 1, null]", R"([ - [[1], [2, null, 2], []], - null, - null - ])"); - this->AssertFilter(type, list_json, "[0, 0, 1, null]", "[null, null]"); - this->AssertFilter(type, list_json, "[1, 0, 0, 1]", R"([ - [], - [[3, null], null] - ])"); - this->AssertFilter(type, list_json, "[1, 1, 1, 1]", list_json); - this->AssertFilter(type, list_json, "[0, 1, 0, 1]", R"([ - [[1], [2, null, 2], []], - [[3, null], null] - ])"); -} - -class TestFilterKernelWithLargeList : public TestFilterKernel {}; - -TEST_F(TestFilterKernelWithLargeList, FilterListInt32) { - std::string list_json = "[[], [1,2], null, [3]]"; - this->AssertFilter(large_list(int32()), list_json, "[0, 0, 0, 0]", "[]"); - this->AssertFilter(large_list(int32()), list_json, "[0, 1, 1, null]", - "[[1,2], null, null]"); -} - -class TestFilterKernelWithFixedSizeList : public TestFilterKernel {}; - -TEST_F(TestFilterKernelWithFixedSizeList, FilterFixedSizeListInt32) { - std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]"; - this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 0, 0, 0]", "[]"); - this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 1, 1, null]", - "[[1, null, 3], [4, 5, 6], null]"); - this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 0, 1, null]", - "[[4, 5, 6], null]"); - this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[1, 1, 1, 1]", list_json); - this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 1, 0, 1]", - "[[1, null, 3], [7, 8, null]]"); -} - -class TestFilterKernelWithMap : public TestFilterKernel {}; - -TEST_F(TestFilterKernelWithMap, FilterMapStringToInt32) { - std::string map_json = R"([ - [["joe", 0], ["mark", null]], - null, - [["cap", 8]], - [] - ])"; - this->AssertFilter(map(utf8(), int32()), map_json, "[0, 0, 0, 0]", "[]"); - this->AssertFilter(map(utf8(), int32()), map_json, "[0, 1, 1, null]", R"([ - null, - [["cap", 8]], - null - ])"); - this->AssertFilter(map(utf8(), int32()), map_json, "[1, 1, 1, 1]", map_json); - this->AssertFilter(map(utf8(), int32()), map_json, "[0, 1, 0, 1]", "[null, []]"); -} - -class TestFilterKernelWithStruct : public TestFilterKernel {}; - -TEST_F(TestFilterKernelWithStruct, FilterStruct) { - auto struct_type = struct_({field("a", int32()), field("b", utf8())}); - auto struct_json = R"([ - null, - {"a": 1, "b": ""}, - {"a": 2, "b": "hello"}, - {"a": 4, "b": "eh"} - ])"; - this->AssertFilter(struct_type, struct_json, "[0, 0, 0, 0]", "[]"); - this->AssertFilter(struct_type, struct_json, "[0, 1, 1, null]", R"([ - {"a": 1, "b": ""}, - {"a": 2, "b": "hello"}, - null - ])"); - this->AssertFilter(struct_type, struct_json, "[1, 1, 1, 1]", struct_json); - this->AssertFilter(struct_type, struct_json, "[1, 0, 1, 0]", R"([ - null, - {"a": 2, "b": "hello"} - ])"); -} - -class TestFilterKernelWithUnion : public TestFilterKernel {}; - -TEST_F(TestFilterKernelWithUnion, FilterUnion) { - for (auto union_ : UnionTypeFactories()) { - auto union_type = union_({field("a", int32()), field("b", utf8())}, {2, 5}); - auto union_json = R"([ - null, - [2, 222], - [5, "hello"], - [5, "eh"], - null, - [2, 111] - ])"; - this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0]", "[]"); - this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1]", R"([ - [2, 222], - [5, "hello"], - null, - [2, 111] - ])"); - this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0]", R"([ - null, - [5, "hello"], - null - ])"); - this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1]", union_json); - } -} - -class TestFilterKernelWithRecordBatch : public TestFilterKernel { - public: - void AssertFilter(const std::shared_ptr& schm, const std::string& batch_json, - const std::string& selection, FilterOptions options, - const std::string& expected_batch) { - std::shared_ptr actual; - - ASSERT_OK(this->Filter(schm, batch_json, selection, options, &actual)); - ASSERT_OK(actual->ValidateFull()); - ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual); - } - - Status Filter(const std::shared_ptr& schm, const std::string& batch_json, - const std::string& selection, FilterOptions options, - std::shared_ptr* out) { - auto batch = RecordBatchFromJSON(schm, batch_json); - ARROW_ASSIGN_OR_RAISE( - Datum out_datum, - arrow::compute::Filter(batch, ArrayFromJSON(boolean(), selection), options)); - *out = out_datum.record_batch(); - return Status::OK(); - } -}; - -TEST_F(TestFilterKernelWithRecordBatch, FilterRecordBatch) { - std::vector> fields = {field("a", int32()), field("b", utf8())}; - auto schm = schema(fields); - - auto batch_json = R"([ - {"a": null, "b": "yo"}, - {"a": 1, "b": ""}, - {"a": 2, "b": "hello"}, - {"a": 4, "b": "eh"} - ])"; - for (auto options : {this->emit_null_, this->drop_}) { - this->AssertFilter(schm, batch_json, "[0, 0, 0, 0]", options, "[]"); - this->AssertFilter(schm, batch_json, "[1, 1, 1, 1]", options, batch_json); - this->AssertFilter(schm, batch_json, "[1, 0, 1, 0]", options, R"([ - {"a": null, "b": "yo"}, - {"a": 2, "b": "hello"} - ])"); - } - - this->AssertFilter(schm, batch_json, "[0, 1, 1, null]", this->drop_, R"([ - {"a": 1, "b": ""}, - {"a": 2, "b": "hello"} - ])"); - - this->AssertFilter(schm, batch_json, "[0, 1, 1, null]", this->emit_null_, R"([ - {"a": 1, "b": ""}, - {"a": 2, "b": "hello"}, - {"a": null, "b": null} - ])"); -} - -class TestFilterKernelWithChunkedArray : public TestFilterKernel { - public: - void AssertFilter(const std::shared_ptr& type, - const std::vector& values, const std::string& filter, - const std::vector& expected) { - std::shared_ptr actual; - ASSERT_OK(this->FilterWithArray(type, values, filter, &actual)); - ASSERT_OK(actual->ValidateFull()); - AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual); - } - - void AssertChunkedFilter(const std::shared_ptr& type, - const std::vector& values, - const std::vector& filter, - const std::vector& expected) { - std::shared_ptr actual; - ASSERT_OK(this->FilterWithChunkedArray(type, values, filter, &actual)); - ASSERT_OK(actual->ValidateFull()); - AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual); - } - - Status FilterWithArray(const std::shared_ptr& type, - const std::vector& values, - const std::string& filter, std::shared_ptr* out) { - ARROW_ASSIGN_OR_RAISE(Datum out_datum, - arrow::compute::Filter(ChunkedArrayFromJSON(type, values), - ArrayFromJSON(boolean(), filter), {})); - *out = out_datum.chunked_array(); - return Status::OK(); - } - - Status FilterWithChunkedArray(const std::shared_ptr& type, - const std::vector& values, - const std::vector& filter, - std::shared_ptr* out) { - ARROW_ASSIGN_OR_RAISE( - Datum out_datum, - arrow::compute::Filter(ChunkedArrayFromJSON(type, values), - ChunkedArrayFromJSON(boolean(), filter), {})); - *out = out_datum.chunked_array(); - return Status::OK(); - } -}; - -TEST_F(TestFilterKernelWithChunkedArray, FilterChunkedArray) { - this->AssertFilter(int8(), {"[]"}, "[]", {}); - this->AssertChunkedFilter(int8(), {"[]"}, {"[]"}, {}); - - this->AssertFilter(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0]", {"[8]"}); - this->AssertChunkedFilter(int8(), {"[7]", "[8, 9]"}, {"[0]", "[1, 0]"}, {"[8]"}); - this->AssertChunkedFilter(int8(), {"[7]", "[8, 9]"}, {"[0, 1]", "[0]"}, {"[8]"}); - - std::shared_ptr arr; - ASSERT_RAISES( - Invalid, this->FilterWithArray(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0, 1, 1]", &arr)); - ASSERT_RAISES(Invalid, this->FilterWithChunkedArray(int8(), {"[7]", "[8, 9]"}, - {"[0, 1, 0]", "[1, 1]"}, &arr)); -} - -class TestFilterKernelWithTable : public TestFilterKernel
{ - public: - void AssertFilter(const std::shared_ptr& schm, - const std::vector& table_json, const std::string& filter, - FilterOptions options, - const std::vector& expected_table) { - std::shared_ptr
actual; - - ASSERT_OK(this->FilterWithArray(schm, table_json, filter, options, &actual)); - ASSERT_OK(actual->ValidateFull()); - ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual); - } - - void AssertChunkedFilter(const std::shared_ptr& schm, - const std::vector& table_json, - const std::vector& filter, FilterOptions options, - const std::vector& expected_table) { - std::shared_ptr
actual; - - ASSERT_OK(this->FilterWithChunkedArray(schm, table_json, filter, options, &actual)); - ASSERT_OK(actual->ValidateFull()); - AssertTablesEqual(*TableFromJSON(schm, expected_table), *actual, - /*same_chunk_layout=*/false); - } - - Status FilterWithArray(const std::shared_ptr& schm, - const std::vector& values, - const std::string& filter, FilterOptions options, - std::shared_ptr
* out) { - ARROW_ASSIGN_OR_RAISE( - Datum out_datum, - arrow::compute::Filter(TableFromJSON(schm, values), - ArrayFromJSON(boolean(), filter), options)); - *out = out_datum.table(); - return Status::OK(); - } - - Status FilterWithChunkedArray(const std::shared_ptr& schm, - const std::vector& values, - const std::vector& filter, - FilterOptions options, std::shared_ptr
* out) { - ARROW_ASSIGN_OR_RAISE( - Datum out_datum, - arrow::compute::Filter(TableFromJSON(schm, values), - ChunkedArrayFromJSON(boolean(), filter), options)); - *out = out_datum.table(); - return Status::OK(); - } -}; - -TEST_F(TestFilterKernelWithTable, FilterTable) { - std::vector> fields = {field("a", int32()), field("b", utf8())}; - auto schm = schema(fields); - - std::vector table_json = {R"([ - {"a": null, "b": "yo"}, - {"a": 1, "b": ""} - ])", - R"([ - {"a": 2, "b": "hello"}, - {"a": 4, "b": "eh"} - ])"}; - for (auto options : {this->emit_null_, this->drop_}) { - this->AssertFilter(schm, table_json, "[0, 0, 0, 0]", options, {}); - this->AssertChunkedFilter(schm, table_json, {"[0]", "[0, 0, 0]"}, options, {}); - this->AssertFilter(schm, table_json, "[1, 1, 1, 1]", options, table_json); - this->AssertChunkedFilter(schm, table_json, {"[1]", "[1, 1, 1]"}, options, - table_json); - } - - std::vector expected_emit_null = {R"([ - {"a": 1, "b": ""} - ])", - R"([ - {"a": 2, "b": "hello"}, - {"a": null, "b": null} - ])"}; - this->AssertFilter(schm, table_json, "[0, 1, 1, null]", this->emit_null_, - expected_emit_null); - this->AssertChunkedFilter(schm, table_json, {"[0, 1, 1]", "[null]"}, this->emit_null_, - expected_emit_null); - - std::vector expected_drop = {R"([{"a": 1, "b": ""}])", - R"([{"a": 2, "b": "hello"}])"}; - this->AssertFilter(schm, table_json, "[0, 1, 1, null]", this->drop_, expected_drop); - this->AssertChunkedFilter(schm, table_json, {"[0, 1, 1]", "[null]"}, this->drop_, - expected_drop); -} - -} // namespace compute -} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_selection.cc b/cpp/src/arrow/compute/kernels/vector_selection.cc new file mode 100644 index 00000000000..77ec0289ee5 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_selection.cc @@ -0,0 +1,1826 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include + +#include "arrow/array/array_base.h" +#include "arrow/array/array_binary.h" +#include "arrow/array/array_dict.h" +#include "arrow/array/array_nested.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/array/concatenate.h" +#include "arrow/buffer_builder.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/util_internal.h" +#include "arrow/extension_type.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/bitmap_reader.h" +#include "arrow/util/int_util.h" + +namespace arrow { + +using internal::BinaryBitBlockCounter; +using internal::BitBlockCount; +using internal::BitBlockCounter; +using internal::BitmapReader; +using internal::CopyBitmap; +using internal::CountSetBits; +using internal::GetArrayView; +using internal::IndexBoundsCheck; +using internal::OptionalBitBlockCounter; +using internal::OptionalBitIndexer; + +namespace compute { +namespace internal { + +int64_t GetFilterOutputSize(const ArrayData& filter, + FilterOptions::NullSelectionBehavior null_selection) { + int64_t output_size = 0; + if (filter.null_count.load() != 0) { + const uint8_t* filter_is_valid = filter.buffers[0]->data(); + BinaryBitBlockCounter bit_counter(filter.buffers[1]->data(), filter.offset, + filter_is_valid, filter.offset, filter.length); + int64_t position = 0; + if (null_selection == FilterOptions::EMIT_NULL) { + while (position < filter.length) { + BitBlockCount block = bit_counter.NextOrNotWord(); + output_size += block.popcount; + position += block.length; + } + } else { + while (position < filter.length) { + BitBlockCount block = bit_counter.NextAndWord(); + output_size += block.popcount; + position += block.length; + } + } + } else { + // The filter has no nulls, so we can use CountSetBits + output_size = CountSetBits(filter.buffers[1]->data(), filter.offset, filter.length); + } + return output_size; +} + +template +Result> GetTakeIndicesImpl( + const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection, + MemoryPool* memory_pool) { + using T = typename IndexType::c_type; + typename TypeTraits::BuilderType builder(memory_pool); + + const uint8_t* filter_data = filter.buffers[1]->data(); + BitBlockCounter data_counter(filter_data, filter.offset, filter.length); + + // The position relative to the start of the filter + T position = 0; + + // The current position taking the filter offset into account + int64_t position_with_offset = filter.offset; + if (filter.null_count != 0) { + // The filter may have nulls, so we scan the validity bitmap and the filter + // data bitmap together, branching on the null selection type. + const uint8_t* filter_is_valid = filter.buffers[0]->data(); + + // To count blocks whether filter_data[i] || !filter_is_valid[i] + BinaryBitBlockCounter filter_counter(filter_data, filter.offset, filter_is_valid, + filter.offset, filter.length); + if (null_selection == FilterOptions::DROP) { + while (position < filter.length) { + BitBlockCount and_block = filter_counter.NextAndWord(); + RETURN_NOT_OK(builder.Reserve(and_block.popcount)); + if (and_block.AllSet()) { + // All the values are selected and non-null + for (int64_t i = 0; i < and_block.length; ++i) { + builder.UnsafeAppend(position++); + } + position_with_offset += and_block.length; + } else if (!and_block.NoneSet()) { + // Some of the values are false or null + for (int64_t i = 0; i < and_block.length; ++i) { + if (BitUtil::GetBit(filter_is_valid, position_with_offset) && + BitUtil::GetBit(filter_data, position_with_offset)) { + builder.UnsafeAppend(position); + } + ++position; + ++position_with_offset; + } + } else { + position += and_block.length; + position_with_offset += and_block.length; + } + } + } else { + BitBlockCounter is_valid_counter(filter_is_valid, filter.offset, filter.length); + while (position < filter.length) { + // true OR NOT valid + BitBlockCount or_not_block = filter_counter.NextOrNotWord(); + RETURN_NOT_OK(builder.Reserve(or_not_block.popcount)); + + // If the values are all valid and the or_not_block is full, then we + // can infer that all the values are true and skip the bit checking + BitBlockCount is_valid_block = is_valid_counter.NextWord(); + + if (or_not_block.AllSet() && is_valid_block.AllSet()) { + // All the values are selected and non-null + for (int64_t i = 0; i < or_not_block.length; ++i) { + builder.UnsafeAppend(position++); + } + position_with_offset += or_not_block.length; + } else { + // Some of the values are false or null + for (int64_t i = 0; i < or_not_block.length; ++i) { + if (BitUtil::GetBit(filter_is_valid, position_with_offset)) { + if (BitUtil::GetBit(filter_data, position_with_offset)) { + builder.UnsafeAppend(position); + } + } else { + // Null slot, so append a null + builder.UnsafeAppendNull(); + } + ++position; + ++position_with_offset; + } + } + } + } + } else { + // The filter has no nulls, so we need only look for true values + BitBlockCount current_block = data_counter.NextWord(); + while (position < filter.length) { + if (current_block.AllSet()) { + int64_t run_length = 0; + + // If we've found a all-true block, then we scan forward until we find + // a block that has some false values (or we reach the end) + while (current_block.length > 0 && current_block.AllSet()) { + run_length += current_block.length; + current_block = data_counter.NextWord(); + } + + // Append the consecutive run of indices + RETURN_NOT_OK(builder.Reserve(run_length)); + for (int64_t i = 0; i < run_length; ++i) { + builder.UnsafeAppend(position++); + } + position_with_offset += run_length; + // The current_block already computed, so advance to next loop + // iteration. + continue; + } else if (!current_block.NoneSet()) { + // Must do bitchecking on the current block + RETURN_NOT_OK(builder.Reserve(current_block.popcount)); + for (int64_t i = 0; i < current_block.length; ++i) { + if (BitUtil::GetBit(filter_data, position_with_offset)) { + builder.UnsafeAppend(position); + } + ++position; + ++position_with_offset; + } + } else { + position += current_block.length; + position_with_offset += current_block.length; + } + current_block = data_counter.NextWord(); + } + } + std::shared_ptr result; + RETURN_NOT_OK(builder.FinishInternal(&result)); + return result; +} + +Result> GetTakeIndices( + const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection, + MemoryPool* memory_pool) { + DCHECK_EQ(filter.type->id(), Type::BOOL); + if (filter.length <= std::numeric_limits::max()) { + return GetTakeIndicesImpl(filter, null_selection, memory_pool); + } else if (filter.length <= std::numeric_limits::max()) { + return GetTakeIndicesImpl(filter, null_selection, memory_pool); + } else { + // Arrays over 4 billion elements, not especially likely. + return Status::NotImplemented( + "Filter length exceeds UINT32_MAX, " + "consider a different strategy for selecting elements"); + } +} + +namespace { + +using FilterState = OptionsWrapper; +using TakeState = OptionsWrapper; + +Status PreallocateData(KernelContext* ctx, int64_t length, int bit_width, + bool allocate_validity, Datum* out) { + // Preallocate memory + ArrayData* out_arr = out->mutable_array(); + out_arr->length = length; + out_arr->buffers.resize(2); + + if (allocate_validity) { + ARROW_ASSIGN_OR_RAISE(out_arr->buffers[0], ctx->AllocateBitmap(length)); + } + if (bit_width == 1) { + ARROW_ASSIGN_OR_RAISE(out_arr->buffers[1], ctx->AllocateBitmap(length)); + } else { + ARROW_ASSIGN_OR_RAISE(out_arr->buffers[1], ctx->Allocate(length * bit_width / 8)); + } + return Status::OK(); +} + +// ---------------------------------------------------------------------- +// Implement optimized take for primitive types from boolean to 1/2/4/8-byte +// C-type based types. Use common implementation for every byte width and only +// generate code for unsigned integer indices, since after boundschecking to +// check for negative numbers in the indices we can safely reinterpret_cast +// signed integers as unsigned. + +/// \brief The Take implementation for primitive (fixed-width) types does not +/// use the logical Arrow type but rather the physical C type. This way we +/// only generate one take function for each byte width. +/// +/// This function assumes that the indices have been boundschecked. +template +struct PrimitiveTakeImpl { + static void Exec(const PrimitiveArg& values, const PrimitiveArg& indices, + Datum* out_datum) { + auto values_data = reinterpret_cast(values.data); + auto values_is_valid = values.is_valid; + auto values_offset = values.offset; + + auto indices_data = reinterpret_cast(indices.data); + auto indices_is_valid = indices.is_valid; + auto indices_offset = indices.offset; + + ArrayData* out_arr = out_datum->mutable_array(); + auto out = out_arr->GetMutableValues(1); + auto out_is_valid = out_arr->buffers[0]->mutable_data(); + auto out_offset = out_arr->offset; + + // If either the values or indices have nulls, we preemptively zero out the + // out validity bitmap so that we don't have to use ClearBit in each + // iteration for nulls. + if (values.null_count != 0 || indices.null_count != 0) { + BitUtil::SetBitsTo(out_is_valid, out_offset, indices.length, false); + } + + OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset, + indices.length); + int64_t position = 0; + int64_t valid_count = 0; + while (position < indices.length) { + BitBlockCount block = indices_bit_counter.NextBlock(); + if (values.null_count == 0) { + // Values are never null, so things are easier + valid_count += block.popcount; + if (block.popcount == block.length) { + // Fastest path: neither values nor index nulls + BitUtil::SetBitsTo(out_is_valid, out_offset + position, block.length, true); + for (int64_t i = 0; i < block.length; ++i) { + out[position] = values_data[indices_data[position]]; + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some indices but not all are null + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) { + // index is not null + BitUtil::SetBit(out_is_valid, out_offset + position); + out[position] = values_data[indices_data[position]]; + } else { + out[position] = ValueCType{}; + } + ++position; + } + } else { + memset(out + position, 0, sizeof(ValueCType) * block.length); + position += block.length; + } + } else { + // Values have nulls, so we must do random access into the values bitmap + if (block.popcount == block.length) { + // Faster path: indices are not null but values may be + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(values_is_valid, + values_offset + indices_data[position])) { + // value is not null + out[position] = values_data[indices_data[position]]; + BitUtil::SetBit(out_is_valid, out_offset + position); + ++valid_count; + } else { + out[position] = ValueCType{}; + } + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some but not all indices are null. Since we are doing + // random access in general we have to check the value nullness one by + // one. + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(indices_is_valid, indices_offset + position) && + BitUtil::GetBit(values_is_valid, + values_offset + indices_data[position])) { + // index is not null && value is not null + out[position] = values_data[indices_data[position]]; + BitUtil::SetBit(out_is_valid, out_offset + position); + ++valid_count; + } else { + out[position] = ValueCType{}; + } + ++position; + } + } else { + memset(out + position, 0, sizeof(ValueCType) * block.length); + position += block.length; + } + } + } + out_arr->null_count = out_arr->length - valid_count; + } +}; + +template +struct BooleanTakeImpl { + static void Exec(const PrimitiveArg& values, const PrimitiveArg& indices, + Datum* out_datum) { + const uint8_t* values_data = values.data; + auto values_is_valid = values.is_valid; + auto values_offset = values.offset; + + auto indices_data = reinterpret_cast(indices.data); + auto indices_is_valid = indices.is_valid; + auto indices_offset = indices.offset; + + ArrayData* out_arr = out_datum->mutable_array(); + auto out = out_arr->buffers[1]->mutable_data(); + auto out_is_valid = out_arr->buffers[0]->mutable_data(); + auto out_offset = out_arr->offset; + + // If either the values or indices have nulls, we preemptively zero out the + // out validity bitmap so that we don't have to use ClearBit in each + // iteration for nulls. + if (values.null_count != 0 || indices.null_count != 0) { + BitUtil::SetBitsTo(out_is_valid, out_offset, indices.length, false); + } + // Avoid uninitialized data in values array + BitUtil::SetBitsTo(out, out_offset, indices.length, false); + + auto PlaceDataBit = [&](int64_t loc, IndexCType index) { + BitUtil::SetBitTo(out, out_offset + loc, + BitUtil::GetBit(values_data, values_offset + index)); + }; + + OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset, + indices.length); + int64_t position = 0; + int64_t valid_count = 0; + while (position < indices.length) { + BitBlockCount block = indices_bit_counter.NextBlock(); + if (values.null_count == 0) { + // Values are never null, so things are easier + valid_count += block.popcount; + if (block.popcount == block.length) { + // Fastest path: neither values nor index nulls + BitUtil::SetBitsTo(out_is_valid, out_offset + position, block.length, true); + for (int64_t i = 0; i < block.length; ++i) { + PlaceDataBit(position, indices_data[position]); + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some but not all indices are null + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) { + // index is not null + BitUtil::SetBit(out_is_valid, out_offset + position); + PlaceDataBit(position, indices_data[position]); + } + ++position; + } + } else { + position += block.length; + } + } else { + // Values have nulls, so we must do random access into the values bitmap + if (block.popcount == block.length) { + // Faster path: indices are not null but values may be + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(values_is_valid, + values_offset + indices_data[position])) { + // value is not null + BitUtil::SetBit(out_is_valid, out_offset + position); + PlaceDataBit(position, indices_data[position]); + ++valid_count; + } + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some but not all indices are null. Since we are doing + // random access in general we have to check the value nullness one by + // one. + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) { + // index is not null + if (BitUtil::GetBit(values_is_valid, + values_offset + indices_data[position])) { + // value is not null + PlaceDataBit(position, indices_data[position]); + BitUtil::SetBit(out_is_valid, out_offset + position); + ++valid_count; + } + } + ++position; + } + } else { + position += block.length; + } + } + } + out_arr->null_count = out_arr->length - valid_count; + } +}; + +template