From 30e493cefb8081dff29271f811ccc6c1f2c5e64c Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Tue, 23 Mar 2021 13:07:25 -0600 Subject: [PATCH] Implement groupby collect_set (#7420) This partially addresses #2973. This PR implements groupby `collect_set` aggregation. The idea of this PR is to simply apply `drop_list_duplicates` (https://github.com/rapidsai/cudf/pull/7528) to the result generated by groupby `collect_list`, obtaining collect lists without duplicate entries. Examples: ``` keys = {1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; vals = {10, 11, 10, 10, 20, 21, 21, 20, 30, 33, 32, 31}; keys_output = {1, 2, 3}; vals_output = {{10, 11}, {20, 21}, {30, 31, 32, 33}}; ``` In this PR, a simple, incomplete Python binding for `collect_set` has been added, and no Java binding is implemented yet. Complete bindings for those Python/Java sides need to be implemented later in some other separate PRs. Authors: - Nghia Truong (@ttnghia) Approvers: - AJ Schmidt (@ajschmidt8) - Karthikeyan (@karthikeyann) - Keith Kraus (@kkraus14) - Jason Lowe (@jlowe) - Ashwin Srinath (@shwina) URL: https://github.com/rapidsai/cudf/pull/7420 --- conda/recipes/libcudf/meta.yaml | 1 + cpp/include/cudf/aggregation.hpp | 28 ++- .../cudf/detail/aggregation/aggregation.hpp | 48 ++++- .../lists/detail/drop_list_duplicates.hpp | 38 ++++ cpp/src/aggregation/aggregation.cpp | 12 +- cpp/src/groupby/sort/aggregate.cpp | 27 ++- cpp/src/lists/drop_list_duplicates.cu | 3 +- cpp/src/rolling/rolling_detail.cuh | 19 +- cpp/src/rolling/rolling_detail.hpp | 12 +- cpp/tests/CMakeLists.txt | 1 + cpp/tests/groupby/collect_set_test.cpp | 203 ++++++++++++++++++ cpp/tests/groupby/group_collect_test.cpp | 29 +-- cpp/tests/rolling/collect_list_test.cpp | 146 ++++++------- java/src/main/native/src/AggregationJni.cpp | 2 +- python/cudf/cudf/_lib/aggregation.pyx | 9 +- python/cudf/cudf/_lib/cpp/aggregation.pxd | 7 +- 16 files changed, 459 insertions(+), 126 deletions(-) create mode 100644 
cpp/include/cudf/lists/detail/drop_list_duplicates.hpp create mode 100644 cpp/tests/groupby/collect_set_test.cpp diff --git a/conda/recipes/libcudf/meta.yaml b/conda/recipes/libcudf/meta.yaml index 5657d21889f..ee73ea9e320 100644 --- a/conda/recipes/libcudf/meta.yaml +++ b/conda/recipes/libcudf/meta.yaml @@ -132,6 +132,7 @@ test: - test -f $PREFIX/include/cudf/join.hpp - test -f $PREFIX/include/cudf/lists/detail/concatenate.hpp - test -f $PREFIX/include/cudf/lists/detail/copying.hpp + - test -f $PREFIX/include/cudf/lists/detail/drop_list_duplicates.hpp - test -f $PREFIX/include/cudf/lists/detail/sorting.hpp - test -f $PREFIX/include/cudf/lists/count_elements.hpp - test -f $PREFIX/include/cudf/lists/explode.hpp diff --git a/cpp/include/cudf/aggregation.hpp b/cpp/include/cudf/aggregation.hpp index a81b6ebc8a1..3c454c85720 100644 --- a/cpp/include/cudf/aggregation.hpp +++ b/cpp/include/cudf/aggregation.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -74,7 +74,8 @@ class aggregation { NUNIQUE, ///< count number of unique elements NTH_ELEMENT, ///< get the nth element ROW_NUMBER, ///< get row-number of current index (relative to rolling window) - COLLECT, ///< collect values into a list + COLLECT_LIST, ///< collect values into a list + COLLECT_SET, ///< collect values into a list without duplicate entries LEAD, ///< window function, accesses row at specified offset following current row LAG, ///< window function, accesses row at specified offset preceding current row PTX, ///< PTX UDF based reduction @@ -205,18 +206,35 @@ std::unique_ptr make_nth_element_aggregation( std::unique_ptr make_row_number_aggregation(); /** - * @brief Factory to create a COLLECT aggregation + * @brief Factory to create a COLLECT_LIST aggregation * - * `COLLECT` returns a list column of all included elements in the group/series. + * `COLLECT_LIST` returns a list column of all included elements in the group/series. * * If `null_handling` is set to `EXCLUDE`, null elements are dropped from each * of the list rows. * * @param null_handling Indicates whether to include/exclude nulls in list elements. */ -std::unique_ptr make_collect_aggregation( +std::unique_ptr make_collect_list_aggregation( null_policy null_handling = null_policy::INCLUDE); +/** + * @brief Factory to create a COLLECT_SET aggregation + * + * `COLLECT_SET` returns a lists column of all included elements in the group/series. Within each + * list, the duplicated entries are dropped out such that each entry appears only once. + * + * If `null_handling` is set to `EXCLUDE`, null elements are dropped from each + * of the list rows. 
+ * + * @param null_handling Indicates whether to include/exclude nulls during collection + * @param null_equal Flag to specify whether null entries within each list should be considered + * equal + */ +std::unique_ptr make_collect_set_aggregation( + null_policy null_handling = null_policy::INCLUDE, + null_equality null_equal = null_equality::EQUAL); + /// Factory to create a LAG aggregation std::unique_ptr make_lag_aggregation(size_type offset); diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp index 1a4847dad12..18bef301e03 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.hpp +++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -320,11 +320,11 @@ struct udf_aggregation final : derived_aggregation { }; /** - * @brief Derived aggregation class for specifying COLLECT aggregation + * @brief Derived aggregation class for specifying COLLECT_LIST aggregation */ struct collect_list_aggregation final : derived_aggregation { explicit collect_list_aggregation(null_policy null_handling = null_policy::INCLUDE) - : derived_aggregation{COLLECT}, _null_handling{null_handling} + : derived_aggregation{COLLECT_LIST}, _null_handling{null_handling} { } null_policy _null_handling; ///< include or exclude nulls @@ -340,6 +340,32 @@ struct collect_list_aggregation final : derived_aggregation size_t hash_impl() const { return std::hash{}(static_cast(_null_handling)); } }; +/** + * @brief Derived aggregation class for specifying COLLECT_SET aggregation + */ +struct collect_set_aggregation final : derived_aggregation { + explicit collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE, + null_equality null_equal = null_equality::EQUAL) +
: derived_aggregation{COLLECT_SET}, _null_handling{null_handling}, _null_equal(null_equal) + { + } + null_policy _null_handling; ///< include or exclude nulls + null_equality _null_equal; ///< whether to consider nulls as equal values + + protected: + friend class derived_aggregation; + + bool operator==(collect_set_aggregation const& other) const + { + return _null_handling == other._null_handling && _null_equal == other._null_equal; + } + + size_t hash_impl() const + { + return std::hash{}(static_cast(_null_handling) ^ static_cast(_null_equal)); + } +}; + /** * @brief Sentinel value used for `ARGMAX` aggregation. * @@ -514,9 +540,15 @@ struct target_type_impl { using type = cudf::size_type; }; -// Always use list for COLLECT +// Always use list for COLLECT_LIST +template +struct target_type_impl { + using type = cudf::list_view; +}; + +// Always use list for COLLECT_SET template -struct target_type_impl { +struct target_type_impl { using type = cudf::list_view; }; @@ -617,8 +649,10 @@ CUDA_HOST_DEVICE_CALLABLE decltype(auto) aggregation_dispatcher(aggregation::Kin return f.template operator()(std::forward(args)...); case aggregation::ROW_NUMBER: return f.template operator()(std::forward(args)...); - case aggregation::COLLECT: - return f.template operator()(std::forward(args)...); + case aggregation::COLLECT_LIST: + return f.template operator()(std::forward(args)...); + case aggregation::COLLECT_SET: + return f.template operator()(std::forward(args)...); case aggregation::LEAD: return f.template operator()(std::forward(args)...); case aggregation::LAG: diff --git a/cpp/include/cudf/lists/detail/drop_list_duplicates.hpp b/cpp/include/cudf/lists/detail/drop_list_duplicates.hpp new file mode 100644 index 00000000000..ba3e1d17d7f --- /dev/null +++ b/cpp/include/cudf/lists/detail/drop_list_duplicates.hpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include + +namespace cudf { +namespace lists { +namespace detail { + +/** + * @copydoc cudf::lists::drop_list_duplicates + * + * @param stream CUDA stream used for device memory operations and kernel launches. + */ +std::unique_ptr drop_list_duplicates( + lists_column_view const& lists_column, + null_equality nulls_equal, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); +} // namespace detail +} // namespace lists +} // namespace cudf diff --git a/cpp/src/aggregation/aggregation.cpp b/cpp/src/aggregation/aggregation.cpp index 04dc8776d20..33c19617308 100644 --- a/cpp/src/aggregation/aggregation.cpp +++ b/cpp/src/aggregation/aggregation.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -125,11 +125,17 @@ std::unique_ptr make_row_number_aggregation() { return std::make_unique(aggregation::ROW_NUMBER); } -/// Factory to create a COLLECT aggregation -std::unique_ptr make_collect_aggregation(null_policy null_handling) +/// Factory to create a COLLECT_LIST aggregation +std::unique_ptr make_collect_list_aggregation(null_policy null_handling) { return std::make_unique(null_handling); } +/// Factory to create a COLLECT_SET aggregation +std::unique_ptr make_collect_set_aggregation(null_policy null_handling, + null_equality null_equal) +{ + return std::make_unique(null_handling, null_equal); +} /// Factory to create a LAG aggregation std::unique_ptr make_lag_aggregation(size_type offset) { diff --git a/cpp/src/groupby/sort/aggregate.cpp b/cpp/src/groupby/sort/aggregate.cpp index ace25820ac9..b171b19413b 100644 --- a/cpp/src/groupby/sort/aggregate.cpp +++ b/cpp/src/groupby/sort/aggregate.cpp @@ -20,16 +20,15 @@ #include #include -#include #include #include #include #include #include -#include #include #include #include +#include #include #include #include @@ -57,6 +56,7 @@ struct aggregrate_result_functor final : store_result_functor { template void operator()(aggregation const& agg) { + CUDF_FAIL("Unsupported aggregation."); } }; @@ -347,12 +347,12 @@ void aggregrate_result_functor::operator()(aggregation } template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregrate_result_functor::operator()(aggregation const& agg) { auto null_handling = static_cast(agg)._null_handling; CUDF_EXPECTS(null_handling == null_policy::INCLUDE, - "null exclusion is not supported on groupby COLLECT aggregation."); + "null exclusion is not supported on groupby COLLECT_LIST aggregation."); if (cache.has_result(col_idx, agg)) return; @@ -362,6 +362,25 @@ void aggregrate_result_functor::operator()(aggregation con cache.add_result(col_idx, agg, std::move(result)); }; +template <> +void aggregrate_result_functor::operator()(aggregation const& 
agg) +{ + auto const null_handling = + static_cast(agg)._null_handling; + CUDF_EXPECTS(null_handling == null_policy::INCLUDE, + "null exclusion is not supported on groupby COLLECT_SET aggregation."); + + if (cache.has_result(col_idx, agg)) { return; } + + auto const collect_result = detail::group_collect( + get_grouped_values(), helper.group_offsets(), helper.num_groups(), stream, mr); + auto const nulls_equal = + static_cast(agg)._null_equal; + cache.add_result(col_idx, + agg, + lists::detail::drop_list_duplicates( + lists_column_view(collect_result->view()), nulls_equal, stream, mr)); +}; } // namespace detail // Sort-based groupby diff --git a/cpp/src/lists/drop_list_duplicates.cu b/cpp/src/lists/drop_list_duplicates.cu index 1eb105d296d..529b7489c35 100644 --- a/cpp/src/lists/drop_list_duplicates.cu +++ b/cpp/src/lists/drop_list_duplicates.cu @@ -225,6 +225,8 @@ void generate_offsets(size_type num_entries, return offsets[i - prefix_sum_empty_lists[i]]; }); } +} // anonymous namespace + /** * @copydoc cudf::lists::drop_list_duplicates * @@ -276,7 +278,6 @@ std::unique_ptr drop_list_duplicates(lists_column_view const& lists_colu cudf::detail::copy_bitmask(lists_column.parent(), stream, mr)); } -} // anonymous namespace } // namespace detail /** diff --git a/cpp/src/rolling/rolling_detail.cuh b/cpp/src/rolling/rolling_detail.cuh index dcc48aafb39..42562507fa9 100644 --- a/cpp/src/rolling/rolling_detail.cuh +++ b/cpp/src/rolling/rolling_detail.cuh @@ -315,7 +315,7 @@ template ::value and !(op == aggregation::COUNT_VALID || op == aggregation::COUNT_ALL || op == aggregation::ROW_NUMBER || op == aggregation::LEAD || - op == aggregation::LAG || op == aggregation::COLLECT)>* = nullptr> + op == aggregation::LAG || op == aggregation::COLLECT_LIST)>* = nullptr> bool __device__ process_rolling_window(column_device_view input, column_device_view ignored_default_outputs, mutable_column_device_view output, @@ -814,7 +814,7 @@ struct rolling_window_launcher { typename 
PrecedingWindowIterator, typename FollowingWindowIterator> std::enable_if_t> operator()(column_view const& input, column_view const& default_outputs, @@ -897,11 +897,11 @@ struct rolling_window_launcher { } /** - * @brief Creates the offsets child of the result of the `COLLECT` window aggregation + * @brief Creates the offsets child of the result of the `COLLECT_LIST` window aggregation * * Given the input column, the preceding/following window bounds, and `min_periods`, * the sizes of each list row may be computed. These values can then be used to - * calculate the offsets for the result of `COLLECT`. + * calculate the offsets for the result of `COLLECT_LIST`. * * Note: If `min_periods` exceeds the number of observations for a window, the size * is set to `0` (since the result is `null`). @@ -945,7 +945,7 @@ struct rolling_window_launcher { } /** - * @brief Generate mapping of each row in the COLLECT result's child column + * @brief Generate mapping of each row in the COLLECT_LIST result's child column * to the index of the row it belongs to. * * If @@ -1030,7 +1030,7 @@ struct rolling_window_launcher { /** * @brief Create gather map to generate the child column of the result of - * the `COLLECT` window aggregation. + * the `COLLECT_LIST` window aggregation. */ template std::unique_ptr create_collect_gather_map(column_view const& child_offsets, @@ -1064,7 +1064,7 @@ struct rolling_window_launcher { } /** - * @brief Count null entries in result of COLLECT. + * @brief Count null entries in result of COLLECT_LIST. 
*/ size_type count_child_nulls(column_view const& input, std::unique_ptr const& gather_map, @@ -1139,7 +1139,7 @@ struct rolling_window_launcher { } template - std::enable_if_t<(op == aggregation::COLLECT), std::unique_ptr> operator()( + std::enable_if_t<(op == aggregation::COLLECT_LIST), std::unique_ptr> operator()( column_view const& input, column_view const& default_outputs, PrecedingIter preceding_begin_raw, @@ -1150,7 +1150,7 @@ struct rolling_window_launcher { rmm::mr::device_memory_resource* mr) { CUDF_EXPECTS(default_outputs.is_empty(), - "COLLECT window function does not support default values."); + "COLLECT_LIST window function does not support default values."); if (input.is_empty()) return empty_like(input); @@ -1370,6 +1370,7 @@ std::unique_ptr rolling_window(column_view const& input, auto input_col = cudf::is_dictionary(input.type()) ? dictionary_column_view(input).get_indices_annotated() : input; + auto output = cudf::type_dispatcher(input_col.type(), dispatch_rolling{}, input_col, diff --git a/cpp/src/rolling/rolling_detail.hpp b/cpp/src/rolling/rolling_detail.hpp index d7fa92f1978..18bd0ea2217 100644 --- a/cpp/src/rolling/rolling_detail.hpp +++ b/cpp/src/rolling/rolling_detail.hpp @@ -41,7 +41,7 @@ static constexpr bool is_rolling_supported() (op == aggregation::SUM) or (op == aggregation::MIN) or (op == aggregation::MAX) or (op == aggregation::COUNT_VALID) or (op == aggregation::COUNT_ALL) or (op == aggregation::MEAN) or (op == aggregation::ROW_NUMBER) or (op == aggregation::LEAD) or - (op == aggregation::LAG) or (op == aggregation::COLLECT); + (op == aggregation::LAG) or (op == aggregation::COLLECT_LIST); constexpr bool is_valid_numeric_agg = (cudf::is_numeric() or cudf::is_duration() or @@ -54,23 +54,23 @@ static constexpr bool is_rolling_supported() return (op == aggregation::MIN) or (op == aggregation::MAX) or (op == aggregation::COUNT_VALID) or (op == aggregation::COUNT_ALL) or (op == aggregation::ROW_NUMBER) or (op == aggregation::LEAD) or - 
(op == aggregation::LAG) or (op == aggregation::COLLECT); + (op == aggregation::LAG) or (op == aggregation::COLLECT_LIST); } else if (cudf::is_fixed_point()) { return (op == aggregation::SUM) or (op == aggregation::MIN) or (op == aggregation::MAX) or (op == aggregation::COUNT_VALID) or (op == aggregation::COUNT_ALL) or (op == aggregation::ROW_NUMBER) or (op == aggregation::LEAD) or - (op == aggregation::LAG) or (op == aggregation::COLLECT); + (op == aggregation::LAG) or (op == aggregation::COLLECT_LIST); } else if (std::is_same()) { return (op == aggregation::MIN) or (op == aggregation::MAX) or (op == aggregation::COUNT_VALID) or (op == aggregation::COUNT_ALL) or - (op == aggregation::ROW_NUMBER) or (op == aggregation::COLLECT); + (op == aggregation::ROW_NUMBER) or (op == aggregation::COLLECT_LIST); } else if (std::is_same()) { return (op == aggregation::COUNT_VALID) or (op == aggregation::COUNT_ALL) or - (op == aggregation::ROW_NUMBER) or (op == aggregation::COLLECT); + (op == aggregation::ROW_NUMBER) or (op == aggregation::COLLECT_LIST); } else if (std::is_same()) { // TODO: Add support for COUNT_VALID, COUNT_ALL, ROW_NUMBER. 
- return op == aggregation::COLLECT; + return op == aggregation::COLLECT_LIST; } else { return false; } diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 4b2d1e04ac5..8e92652d892 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -54,6 +54,7 @@ ConfigureTest(ERROR_TEST error/error_handling_test.cu) ################################################################################################### # - groupby tests --------------------------------------------------------------------------------- ConfigureTest(GROUPBY_TEST + groupby/collect_set_test.cpp groupby/groupby_groups_test.cpp groupby/group_argmin_test.cpp groupby/group_argmax_test.cpp diff --git a/cpp/tests/groupby/collect_set_test.cpp b/cpp/tests/groupby/collect_set_test.cpp new file mode 100644 index 00000000000..5303b8f4f61 --- /dev/null +++ b/cpp/tests/groupby/collect_set_test.cpp @@ -0,0 +1,203 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include +#include +#include + +#include + +namespace cudf { +namespace test { + +#define COL_K cudf::test::fixed_width_column_wrapper +#define COL_V cudf::test::fixed_width_column_wrapper +#define COL_S cudf::test::strings_column_wrapper +#define LCL_V cudf::test::lists_column_wrapper +#define LCL_S cudf::test::lists_column_wrapper +#define VALIDITY std::initializer_list +#define COLLECT_SET cudf::make_collect_set_aggregation() +#define COLLECT_SET_NULL_UNEQUAL \ + cudf::make_collect_set_aggregation(null_policy::INCLUDE, null_equality::UNEQUAL) + +struct CollectSetTest : public cudf::test::BaseFixture { +}; + +template +struct CollectSetTypedTest : public cudf::test::BaseFixture { +}; + +using FixedWidthTypesNotBool = cudf::test::Concat; +TYPED_TEST_CASE(CollectSetTypedTest, FixedWidthTypesNotBool); + +TYPED_TEST(CollectSetTypedTest, ExceptionTests) +{ + std::vector agg_requests(1); + agg_requests[0].values = COL_V{{1, 2, 3, 4, 5, 6}, {true, false, true, false, true, false}}; + agg_requests[0].aggregations.push_back(cudf::make_collect_set_aggregation(null_policy::EXCLUDE)); + + // groupby COLLECT_SET cannot exclude nulls, so the aggregation must throw + groupby::groupby gby{table_view{{COL_K{1, 1, 2, 2, 3, 3}}}}; + EXPECT_THROW(gby.aggregate(agg_requests), cudf::logic_error); +} + +TYPED_TEST(CollectSetTypedTest, TrivialInput) +{ + // Empty input + // TODO: Enable this test after issue#7611 has been fixed + // test_single_agg(COL_K{}, COL_V{}, COL_K{}, COL_V{}, COLLECT_SET); + + // Single key input + { + COL_K keys{1}; + COL_V vals{10}; + COL_K keys_expected{1}; + LCL_V vals_expected{LCL_V{10}}; + test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET); + } + + // Non-repeated keys + { + COL_K keys{2, 1}; + COL_V vals{20, 10}; + COL_K keys_expected{1, 2}; + LCL_V vals_expected{LCL_V{10}, LCL_V{20}}; + test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET); + } +} + +TYPED_TEST(CollectSetTypedTest, TypicalInput) +{ + // Pre-sorted keys + { + COL_K keys{1,
1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; + COL_V vals{10, 11, 10, 10, 20, 21, 21, 20, 30, 33, 32, 31}; + COL_K keys_expected{1, 2, 3}; + LCL_V vals_expected{{10, 11}, {20, 21}, {30, 31, 32, 33}}; + test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET); + } + + // Expect the result keys to be sorted by sort-based groupby + { + COL_K keys{4, 1, 2, 4, 3, 3, 2, 1}; + COL_V vals{40, 10, 20, 40, 30, 30, 20, 11}; + COL_K keys_expected{1, 2, 3, 4}; + LCL_V vals_expected{{10, 11}, {20}, {30}, {40}}; + test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET); + } +} + +// Keys and values columns are sliced columns +TYPED_TEST(CollectSetTypedTest, SlicedColumnsInput) +{ + COL_K keys_original{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; + COL_V vals_original{10, 11, 10, 10, 20, 21, 21, 20, 30, 33, 32, 31}; + { + auto const keys = cudf::slice(keys_original, {0, 4})[0]; // { 1, 1, 1, 1 } + auto const vals = cudf::slice(vals_original, {0, 4})[0]; // { 10, 11, 10, 10 } + auto const keys_expected = COL_K{1}; + auto const vals_expected = LCL_V{{10, 11}}; + test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET); + } + { + auto const keys = cudf::slice(keys_original, {2, 10})[0]; // { 1, 1, 2, 2, 2, 2, 3, 3 } + auto const vals = cudf::slice(vals_original, {2, 10})[0]; // { 10, 10, 20, 21, 21, 20, 30, 33 } + auto const keys_expected = COL_K{1, 2, 3}; + auto const vals_expected = LCL_V{{10}, {20, 21}, {30, 33}}; + test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET); + } +} + +TEST_F(CollectSetTest, StringInput) +{ + COL_K keys{1, 2, 3, 3, 2, 1, 2, 1, 2, 1, 1, 1, 1}; + COL_S vals{ + "String 1, first", + "String 2, first", + "String 3, first", + "String 3, second", + "String 2, second", + "String 1, second", + "String 2, second", // repeated + "String 1, second", // repeated + "String 2, second", // repeated + "String 1, second", // repeated + "String 1, second", // repeated + "String 1, second", // repeated + "String 1, second" // 
repeated + }; + COL_K keys_expected{1, 2, 3}; + LCL_S vals_expected{{"String 1, first", "String 1, second"}, + {"String 2, first", "String 2, second"}, + {"String 3, first", "String 3, second"}}; + test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET); +} + +TYPED_TEST(CollectSetTypedTest, CollectWithNulls) +{ + // Just use an arbitrary value to store null entries + // Using this alias variable will make the code look cleaner + constexpr int32_t null = 0; + + // Pre-sorted keys + { + COL_K keys{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; + COL_V vals{{10, 10, null, null, 20, null, null, null, 30, 31, 30, 31}, + {true, true, false, false, true, false, false, false, true, true, true, true}}; + COL_K keys_expected{1, 2, 3}; + + // By default, nulls are considered equal, thus only one null is kept per key + LCL_V vals_expected{{{10, null}, VALIDITY{true, false}}, + {{20, null}, VALIDITY{true, false}}, + {{30, 31}, VALIDITY{true, true}}}; + test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET); + + // All nulls per key are kept (nulls are put at the end of each list) + vals_expected = LCL_V{{{10, null, null}, VALIDITY{true, false, false}}, + {{20, null, null, null}, VALIDITY{true, false, false, false}}, + {{30, 31}, VALIDITY{true, true}}}; + test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET_NULL_UNEQUAL); + } + + // Expect the result keys to be sorted by sort-based groupby + { + COL_K keys{4, 1, 2, 4, 3, 3, 3, 3, 2, 1}; + COL_V vals{{40, 10, 20, 40, null, null, null, null, 21, null}, + {true, true, true, true, false, false, false, false, true, false}}; + COL_K keys_expected{1, 2, 3, 4}; + + // By default, nulls are considered equal, thus only one null is kept per key + LCL_V vals_expected{{{10, null}, VALIDITY{true, false}}, + {{20, 21}, VALIDITY{true, true}}, + {{null}, VALIDITY{false}}, + {{40}, VALIDITY{true}}}; + test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET); + + // All nulls per key are kept
(nulls are put at the end of each list) + vals_expected = LCL_V{{{10, null}, VALIDITY{true, false}}, + {{20, 21}, VALIDITY{true, true}}, + {{null, null, null, null}, VALIDITY{false, false, false, false}}, + {{40}, VALIDITY{true}}}; + test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET_NULL_UNEQUAL); + } +} + +} // namespace test +} // namespace cudf diff --git a/cpp/tests/groupby/group_collect_test.cpp b/cpp/tests/groupby/group_collect_test.cpp index 9edd0a6932a..8a578ea0c0f 100644 --- a/cpp/tests/groupby/group_collect_test.cpp +++ b/cpp/tests/groupby/group_collect_test.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,15 +26,15 @@ namespace cudf { namespace test { template -struct groupby_collect_test : public cudf::test::BaseFixture { +struct groupby_collect_list_test : public cudf::test::BaseFixture { }; using FixedWidthTypesNotBool = cudf::test::Concat; -TYPED_TEST_CASE(groupby_collect_test, FixedWidthTypesNotBool); +TYPED_TEST_CASE(groupby_collect_list_test, FixedWidthTypesNotBool); -TYPED_TEST(groupby_collect_test, CollectWithoutNulls) +TYPED_TEST(groupby_collect_list_test, CollectWithoutNulls) { using K = int32_t; using V = TypeParam; @@ -45,11 +45,11 @@ TYPED_TEST(groupby_collect_test, CollectWithoutNulls) fixed_width_column_wrapper expect_keys{1, 2}; lists_column_wrapper expect_vals{{1, 2, 3}, {4, 5, 6}}; - auto agg = cudf::make_collect_aggregation(); + auto agg = cudf::make_collect_list_aggregation(); test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg)); } -TYPED_TEST(groupby_collect_test, CollectWithNulls) +TYPED_TEST(groupby_collect_list_test, CollectWithNulls) { using K = int32_t; using V = TypeParam; @@ -64,11 +64,11 @@ TYPED_TEST(groupby_collect_test, CollectWithNulls) lists_column_wrapper expect_vals{ {{1, 2}, 
validity.begin()}, {{3, 4}, validity.begin()}, {{5, 6}, validity.begin()}}; - auto agg = cudf::make_collect_aggregation(); + auto agg = cudf::make_collect_list_aggregation(); test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg)); } -TYPED_TEST(groupby_collect_test, CollectLists) +TYPED_TEST(groupby_collect_list_test, CollectLists) { using K = int32_t; using V = TypeParam; @@ -83,11 +83,11 @@ TYPED_TEST(groupby_collect_test, CollectLists) lists_column_wrapper expect_vals{ {{1, 2}, {3, 4}}, {{5, 6, 7}, LCW{}}, {{9, 10}, {11}}}; - auto agg = cudf::make_collect_aggregation(); + auto agg = cudf::make_collect_list_aggregation(); test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg)); } -TYPED_TEST(groupby_collect_test, dictionary) +TYPED_TEST(groupby_collect_list_test, dictionary) { using K = int32_t; using V = TypeParam; @@ -105,10 +105,11 @@ TYPED_TEST(groupby_collect_test, dictionary) 0, rmm::device_buffer{0}); - test_single_agg(keys, vals, expect_keys, expect_vals->view(), cudf::make_collect_aggregation()); + test_single_agg( + keys, vals, expect_keys, expect_vals->view(), cudf::make_collect_list_aggregation()); } -TYPED_TEST(groupby_collect_test, CollectFailsWithNullExclusion) +TYPED_TEST(groupby_collect_list_test, CollectFailsWithNullExclusion) { using K = int32_t; using V = TypeParam; @@ -121,10 +122,10 @@ TYPED_TEST(groupby_collect_test, CollectFailsWithNullExclusion) std::vector agg_requests(1); agg_requests[0].values = values; - agg_requests[0].aggregations.push_back(cudf::make_collect_aggregation(null_policy::EXCLUDE)); + agg_requests[0].aggregations.push_back(cudf::make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_EXPECT_THROW_MESSAGE(gby.aggregate(agg_requests), - "null exclusion is not supported on groupby COLLECT aggregation."); + "null exclusion is not supported on groupby COLLECT_LIST aggregation."); } } // namespace test diff --git a/cpp/tests/rolling/collect_list_test.cpp 
b/cpp/tests/rolling/collect_list_test.cpp index 6a3a80601d0..de179223d68 100644 --- a/cpp/tests/rolling/collect_list_test.cpp +++ b/cpp/tests/rolling/collect_list_test.cpp @@ -64,7 +64,7 @@ TYPED_TEST(TypedCollectListTest, BasicRollingWindow) static_cast(foll_column).size()); auto const result_column_based_window = - rolling_window(input_column, prev_column, foll_column, 1, make_collect_aggregation()); + rolling_window(input_column, prev_column, foll_column, 1, make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ @@ -79,11 +79,11 @@ TYPED_TEST(TypedCollectListTest, BasicRollingWindow) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view()); auto const result_fixed_window = - rolling_window(input_column, 2, 1, 1, make_collect_aggregation()); + rolling_window(input_column, 2, 1, 1, make_collect_list_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_fixed_window->view()); auto const result_with_nulls_excluded = - rolling_window(input_column, 2, 1, 1, make_collect_aggregation(null_policy::EXCLUDE)); + rolling_window(input_column, 2, 1, 1, make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -104,7 +104,7 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithEmptyOutputLists) static_cast(foll_column).size()); auto const result_column_based_window = - rolling_window(input_column, prev_column, foll_column, 0, make_collect_aggregation()); + rolling_window(input_column, prev_column, foll_column, 0, make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ @@ -120,7 +120,7 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithEmptyOutputLists) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view()); auto const result_with_nulls_excluded = rolling_window( - input_column, prev_column, foll_column, 0, 
make_collect_aggregation(null_policy::EXCLUDE)); + input_column, prev_column, foll_column, 0, make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -138,7 +138,7 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithEmptyOutputListsAtEnds) auto foll_column = fixed_width_column_wrapper{0, 1, 1, 1, 1, 0}; auto const result = - rolling_window(input_column, prev_column, foll_column, 0, make_collect_aggregation()); + rolling_window(input_column, prev_column, foll_column, 0, make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{{}, {0, 1, 2}, {1, 2, 3}, {2, 3, 4}, {3, 4, 5}, {}}.release(); @@ -146,7 +146,7 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithEmptyOutputListsAtEnds) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); auto const result_with_nulls_excluded = rolling_window( - input_column, prev_column, foll_column, 0, make_collect_aggregation(null_policy::EXCLUDE)); + input_column, prev_column, foll_column, 0, make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -164,11 +164,11 @@ TYPED_TEST(TypedCollectListTest, RollingWindowHonoursMinPeriods) auto const input_column = fixed_width_column_wrapper{0, 1, 2, 3, 4, 5}; auto const num_elements = static_cast(input_column).size(); - auto preceding = 2; - auto following = 1; - auto min_periods = 3; - auto const result = - rolling_window(input_column, preceding, following, min_periods, make_collect_aggregation()); + auto preceding = 2; + auto following = 1; + auto min_periods = 3; + auto const result = rolling_window( + input_column, preceding, following, min_periods, make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ {{}, {0, 1, 2}, {1, 2, 3}, {2, 3, 4}, {3, 4, 5}, {}}, @@ -183,7 +183,7 @@ TYPED_TEST(TypedCollectListTest, 
RollingWindowHonoursMinPeriods) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); @@ -191,8 +191,8 @@ TYPED_TEST(TypedCollectListTest, RollingWindowHonoursMinPeriods) following = 2; min_periods = 4; - auto result_2 = - rolling_window(input_column, preceding, following, min_periods, make_collect_aggregation()); + auto result_2 = rolling_window( + input_column, preceding, following, min_periods, make_collect_list_aggregation()); auto expected_result_2 = lists_column_wrapper{ {{}, {0, 1, 2, 3}, {1, 2, 3, 4}, {2, 3, 4, 5}, {}, {}}, cudf::detail::make_counting_transform_iterator(0, [num_elements](auto i) { @@ -206,7 +206,7 @@ TYPED_TEST(TypedCollectListTest, RollingWindowHonoursMinPeriods) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2_with_nulls_excluded->view()); @@ -228,11 +228,11 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) { // One result row at each end should be null. 
- auto preceding = 2; - auto following = 1; - auto min_periods = 3; - auto const result = - rolling_window(input_column, preceding, following, min_periods, make_collect_aggregation()); + auto preceding = 2; + auto following = 1; + auto min_periods = 3; + auto const result = rolling_window( + input_column, preceding, following, min_periods, make_collect_list_aggregation()); auto expected_result_child_values = std::vector{0, 1, 2, 1, 2, 3, 2, 3, 4, 3, 4, 5}; auto expected_result_child_validity = std::vector{1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1}; @@ -265,7 +265,7 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); auto expected_result_child_values = std::vector{0, 2, 2, 3, 2, 3, 3, 5}; auto expected_result_child = fixed_width_column_wrapper( @@ -287,11 +287,11 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) { // First result row, and the last two result rows should be null. 
- auto preceding = 2; - auto following = 2; - auto min_periods = 4; - auto const result = - rolling_window(input_column, preceding, following, min_periods, make_collect_aggregation()); + auto preceding = 2; + auto following = 2; + auto min_periods = 4; + auto const result = rolling_window( + input_column, preceding, following, min_periods, make_collect_list_aggregation()); auto expected_result_child_values = std::vector{0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5}; auto expected_result_child_validity = std::vector{1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1}; @@ -325,7 +325,7 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); auto expected_result_child_values = std::vector{0, 2, 3, 2, 3, 2, 3, 5}; auto expected_result_child = fixed_width_column_wrapper( @@ -358,11 +358,11 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsOnStrings) auto const input_column = strings_column_wrapper{"0", "1", "2", "3", "4", "5"}; auto const num_elements = static_cast(input_column).size(); - auto preceding = 2; - auto following = 1; - auto min_periods = 3; - auto const result = - rolling_window(input_column, preceding, following, min_periods, make_collect_aggregation()); + auto preceding = 2; + auto following = 1; + auto min_periods = 3; + auto const result = rolling_window( + input_column, preceding, following, min_periods, make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ {{}, {"0", "1", "2"}, {"1", "2", "3"}, {"2", "3", "4"}, {"3", "4", "5"}, {}}, @@ -377,7 +377,7 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsOnStrings) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); @@ -385,8 +385,8 @@ 
TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsOnStrings) following = 2; min_periods = 4; - auto result_2 = - rolling_window(input_column, preceding, following, min_periods, make_collect_aggregation()); + auto result_2 = rolling_window( + input_column, preceding, following, min_periods, make_collect_list_aggregation()); auto expected_result_2 = lists_column_wrapper{ {{}, {"0", "1", "2", "3"}, {"1", "2", "3", "4"}, {"2", "3", "4", "5"}, {}, {}}, cudf::detail::make_counting_transform_iterator(0, [num_elements](auto i) { @@ -400,7 +400,7 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsOnStrings) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2_with_nulls_excluded->view()); @@ -421,11 +421,11 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsWithDecimal) { // One result row at each end should be null. - auto preceding = 2; - auto following = 1; - auto min_periods = 3; - auto const result = - rolling_window(input_column, preceding, following, min_periods, make_collect_aggregation()); + auto preceding = 2; + auto following = 1; + auto min_periods = 3; + auto const result = rolling_window( + input_column, preceding, following, min_periods, make_collect_list_aggregation()); auto expected_result_child_values = std::vector{0, 1, 2, 1, 2, 3, 2, 3, 4, 3, 4, 5}; auto expected_result_child = @@ -451,7 +451,7 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsWithDecimal) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); @@ -459,11 +459,11 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsWithDecimal) { // First result row, and the last two result rows should be null. 
- auto preceding = 2; - auto following = 2; - auto min_periods = 4; - auto const result = - rolling_window(input_column, preceding, following, min_periods, make_collect_aggregation()); + auto preceding = 2; + auto following = 2; + auto min_periods = 4; + auto const result = rolling_window( + input_column, preceding, following, min_periods, make_collect_list_aggregation()); auto expected_result_child_values = std::vector{0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5}; auto expected_result_child = @@ -489,7 +489,7 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsWithDecimal) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); @@ -515,7 +515,7 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedRollingWindow) preceding, following, min_periods, - make_collect_aggregation()); + make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ {10, 11}, @@ -536,7 +536,7 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedRollingWindow) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -563,7 +563,7 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedRollingWindowWithNulls) preceding, following, min_periods, - make_collect_aggregation()); + make_collect_list_aggregation()); auto expected_child = fixed_width_column_wrapper{ {10, 11, 10, 11, 12, 11, 12, 13, 12, 13, 14, 13, 14, 20, 21, 20, 21, 22, 21, 22, 23, 22, 23}, @@ -587,7 +587,7 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedRollingWindowWithNulls) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); auto expected_child = fixed_width_column_wrapper{ 
10, 10, 12, 12, 13, 12, 13, 14, 13, 14, 20, 20, 22, 22, 23, 22, 23}; @@ -627,7 +627,7 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindow) preceding, following, min_periods, - make_collect_aggregation()); + make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ {10, 11, 12, 13}, @@ -650,7 +650,7 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindow) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -678,7 +678,7 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithNulls) preceding, following, min_periods, - make_collect_aggregation()); + make_collect_list_aggregation()); auto null_at_0 = iterator_with_null_at(0); auto null_at_1 = iterator_with_null_at(1); @@ -705,7 +705,7 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithNulls) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); // After null exclusion, `11`, `21`, and `null` should not appear. 
auto const expected_result_with_nulls_excluded = lists_column_wrapper{ @@ -744,7 +744,7 @@ TEST_F(CollectListTest, BasicGroupedTimeRangeRollingWindowOnStrings) preceding, following, min_periods, - make_collect_aggregation()); + make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ {"10", "11", "12", "13"}, @@ -767,7 +767,7 @@ TEST_F(CollectListTest, BasicGroupedTimeRangeRollingWindowOnStrings) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -793,7 +793,7 @@ TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithNulls) preceding, following, min_periods, - make_collect_aggregation()); + make_collect_list_aggregation()); auto null_at_0 = iterator_with_null_at(0); auto null_at_1 = iterator_with_null_at(1); @@ -821,7 +821,7 @@ TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithNulls) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); // After null exclusion, `11`, `21`, and `null` should not appear. 
auto const expected_result_with_nulls_excluded = lists_column_wrapper{ @@ -868,7 +868,7 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindowOnStructs) preceding, following, min_periods, - make_collect_aggregation()); + make_collect_list_aggregation()); auto expected_numeric_column = fixed_width_column_wrapper{ 10, 11, 12, 13, 10, 11, 12, 13, 10, 11, 12, 13, 14, 10, 11, 12, @@ -898,7 +898,7 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindowOnStructs) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -928,7 +928,7 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithMinPeriods) preceding, following, min_periods, - make_collect_aggregation()); + make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ {{10, 11, 12, 13}, @@ -954,7 +954,7 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithMinPeriods) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -984,7 +984,7 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithNullsAndMinPer preceding, following, min_periods, - make_collect_aggregation()); + make_collect_list_aggregation()); auto null_at_1 = iterator_with_null_at(1); @@ -1013,7 +1013,7 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithNullsAndMinPer preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); // After null exclusion, `11`, `21`, and `null` should not appear. 
auto const expected_result_with_nulls_excluded = lists_column_wrapper{ @@ -1056,7 +1056,7 @@ TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithMinPeriods) preceding, following, min_periods, - make_collect_aggregation()); + make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ {{"10", "11", "12", "13"}, @@ -1082,7 +1082,7 @@ TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithMinPeriods) preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -1110,7 +1110,7 @@ TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithNullsAndMinPer preceding, following, min_periods, - make_collect_aggregation()); + make_collect_list_aggregation()); auto null_at_1 = iterator_with_null_at(1); @@ -1139,7 +1139,7 @@ TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithNullsAndMinPer preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); // After null exclusion, `11`, `21`, and `null` should not appear. 
auto const expected_result_with_nulls_excluded = lists_column_wrapper{ @@ -1190,7 +1190,7 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowOnStructsWithMinPe preceding, following, min_periods, - make_collect_aggregation()); + make_collect_list_aggregation()); auto expected_numeric_column = fixed_width_column_wrapper{ 10, 11, 12, 13, 10, 11, 12, 13, 10, 11, 12, 13, 14, 10, 11, 12, 13, 14, 10, 11, 12, 13, 14}; @@ -1226,7 +1226,7 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowOnStructsWithMinPe preceding, following, min_periods, - make_collect_aggregation(null_policy::EXCLUDE)); + make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } diff --git a/java/src/main/native/src/AggregationJni.cpp b/java/src/main/native/src/AggregationJni.cpp index aae7cb493a8..c5184111edf 100644 --- a/java/src/main/native/src/AggregationJni.cpp +++ b/java/src/main/native/src/AggregationJni.cpp @@ -206,7 +206,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCollectAgg(JNIEnv cudf::jni::auto_set_device(env); cudf::null_policy policy = include_nulls ? 
cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE; - std::unique_ptr ret = cudf::make_collect_aggregation(policy); + std::unique_ptr ret = cudf::make_collect_list_aggregation(policy); return reinterpret_cast(ret.release()); } CATCH_STD(env, 0); diff --git a/python/cudf/cudf/_lib/aggregation.pyx b/python/cudf/cudf/_lib/aggregation.pyx index 5c6801137ae..840f0c98987 100644 --- a/python/cudf/cudf/_lib/aggregation.pyx +++ b/python/cudf/cudf/_lib/aggregation.pyx @@ -50,6 +50,7 @@ class AggregationKind(Enum): NUNIQUE = libcudf_aggregation.aggregation.Kind.NUNIQUE NTH = libcudf_aggregation.aggregation.Kind.NTH_ELEMENT COLLECT = libcudf_aggregation.aggregation.Kind.COLLECT + COLLECT_SET = libcudf_aggregation.aggregation.Kind.COLLECT_SET PTX = libcudf_aggregation.aggregation.Kind.PTX CUDA = libcudf_aggregation.aggregation.Kind.CUDA @@ -241,7 +242,13 @@ cdef class _AggregationFactory: @classmethod def collect(cls): cdef Aggregation agg = Aggregation.__new__(Aggregation) - agg.c_obj = move(libcudf_aggregation.make_collect_aggregation()) + agg.c_obj = move(libcudf_aggregation.make_collect_list_aggregation()) + return agg + + @classmethod + def collect_set(cls): + cdef Aggregation agg = Aggregation.__new__(Aggregation) + agg.c_obj = move(libcudf_aggregation.make_collect_set_aggregation()) return agg @classmethod diff --git a/python/cudf/cudf/_lib/cpp/aggregation.pxd b/python/cudf/cudf/_lib/cpp/aggregation.pxd index 660db29f7a9..e9836c11361 100644 --- a/python/cudf/cudf/_lib/cpp/aggregation.pxd +++ b/python/cudf/cudf/_lib/cpp/aggregation.pxd @@ -34,7 +34,8 @@ cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil: ARGMIN 'cudf::aggregation::ARGMIN' NUNIQUE 'cudf::aggregation::NUNIQUE' NTH_ELEMENT 'cudf::aggregation::NTH_ELEMENT' - COLLECT 'cudf::aggregation::COLLECT' + COLLECT 'cudf::aggregation::COLLECT_LIST' + COLLECT_SET 'cudf::aggregation::COLLECT_SET' PTX 'cudf::aggregation::PTX' CUDA 'cudf::aggregation::CUDA' Kind kind @@ -83,7 +84,9 @@ cdef extern from 
"cudf/aggregation.hpp" namespace "cudf" nogil: size_type n ) except + - cdef unique_ptr[aggregation] make_collect_aggregation() except + + cdef unique_ptr[aggregation] make_collect_list_aggregation() except + + + cdef unique_ptr[aggregation] make_collect_set_aggregation() except + cdef unique_ptr[aggregation] make_udf_aggregation( udf_type type,