Skip to content

Commit

Permalink
Implement groupby collect_set (#7420)
Browse files Browse the repository at this point in the history
This partially addresses #2973.

This PR implements the groupby `collect_set` aggregation. The idea is simply to apply `drop_list_duplicates` (#7528) to the result generated by groupby `collect_list`, yielding collected lists without duplicate entries.

Examples:
```
keys = {1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3};
vals = {10, 11, 10, 10, 20, 21, 21, 20, 30, 33, 32, 31};

keys_output = {1, 2, 3};
vals_output = {{10, 11}, {20, 21}, {30, 31, 32, 33}};
```

This PR adds only a simple, incomplete Python binding for `collect_set`; no Java binding is implemented yet. Complete Python and Java bindings need to be implemented later in separate PRs.

Authors:
  - Nghia Truong (@ttnghia)

Approvers:
  - AJ Schmidt (@ajschmidt8)
  - Karthikeyan (@karthikeyann)
  - Keith Kraus (@kkraus14)
  - Jason Lowe (@jlowe)
  - Ashwin Srinath (@shwina)

URL: #7420
  • Loading branch information
ttnghia authored Mar 23, 2021
1 parent 5cd90a0 commit 30e493c
Show file tree
Hide file tree
Showing 16 changed files with 459 additions and 126 deletions.
1 change: 1 addition & 0 deletions conda/recipes/libcudf/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ test:
- test -f $PREFIX/include/cudf/join.hpp
- test -f $PREFIX/include/cudf/lists/detail/concatenate.hpp
- test -f $PREFIX/include/cudf/lists/detail/copying.hpp
- test -f $PREFIX/include/cudf/lists/detail/drop_list_duplicates.hpp
- test -f $PREFIX/include/cudf/lists/detail/sorting.hpp
- test -f $PREFIX/include/cudf/lists/count_elements.hpp
- test -f $PREFIX/include/cudf/lists/explode.hpp
Expand Down
28 changes: 23 additions & 5 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -74,7 +74,8 @@ class aggregation {
NUNIQUE, ///< count number of unique elements
NTH_ELEMENT, ///< get the nth element
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
COLLECT, ///< collect values into a list
COLLECT_LIST, ///< collect values into a list
COLLECT_SET, ///< collect values into a list without duplicate entries
LEAD, ///< window function, accesses row at specified offset following current row
LAG, ///< window function, accesses row at specified offset preceding current row
PTX, ///< PTX UDF based reduction
Expand Down Expand Up @@ -205,18 +206,35 @@ std::unique_ptr<aggregation> make_nth_element_aggregation(
std::unique_ptr<aggregation> make_row_number_aggregation();

/**
* @brief Factory to create a COLLECT aggregation
* @brief Factory to create a COLLECT_LIST aggregation
*
* `COLLECT` returns a list column of all included elements in the group/series.
* `COLLECT_LIST` returns a list column of all included elements in the group/series.
*
* If `null_handling` is set to `EXCLUDE`, null elements are dropped from each
* of the list rows.
*
* @param null_handling Indicates whether to include/exclude nulls in list elements.
*/
std::unique_ptr<aggregation> make_collect_aggregation(
std::unique_ptr<aggregation> make_collect_list_aggregation(
null_policy null_handling = null_policy::INCLUDE);

/**
* @brief Factory to create a COLLECT_SET aggregation
*
* `COLLECT_SET` returns a lists column of all included elements in the group/series. Within each
* list, the duplicated entries are dropped out such that each entry appears only once.
*
* If `null_handling` is set to `EXCLUDE`, null elements are dropped from each
* of the list rows.
*
* @note The sort-based groupby implementation currently only accepts `null_policy::INCLUDE` for
* this aggregation and fails otherwise.
*
* @param null_handling Indicates whether to include/exclude nulls during collection
* @param null_equal Flag to specify whether null entries within each list should be considered
* equal
* @return The newly created COLLECT_SET aggregation
*/
std::unique_ptr<aggregation> make_collect_set_aggregation(
null_policy null_handling = null_policy::INCLUDE,
null_equality null_equal = null_equality::EQUAL);

/// Factory to create a LAG aggregation
std::unique_ptr<aggregation> make_lag_aggregation(size_type offset);

Expand Down
48 changes: 41 additions & 7 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -320,11 +320,11 @@ struct udf_aggregation final : derived_aggregation<udf_aggregation> {
};

/**
* @brief Derived aggregation class for specifying COLLECT aggregation
* @brief Derived aggregation class for specifying COLLECT_LIST aggregation
*/
struct collect_list_aggregation final : derived_aggregation<nunique_aggregation> {
explicit collect_list_aggregation(null_policy null_handling = null_policy::INCLUDE)
: derived_aggregation{COLLECT}, _null_handling{null_handling}
: derived_aggregation{COLLECT_LIST}, _null_handling{null_handling}
{
}
null_policy _null_handling; ///< include or exclude nulls
Expand All @@ -340,6 +340,32 @@ struct collect_list_aggregation final : derived_aggregation<nunique_aggregation>
size_t hash_impl() const { return std::hash<int>{}(static_cast<int>(_null_handling)); }
};

/**
 * @brief Derived aggregation class for specifying COLLECT_SET aggregation
 *
 * Stores the two options that parameterize the aggregation: how nulls are handled while
 * collecting, and whether null entries compare equal when duplicates are dropped. Two
 * instances are equal only when both options match, and the hash is derived from the same
 * two options.
 */
struct collect_set_aggregation final : derived_aggregation<collect_set_aggregation> {
  /**
   * @param null_handling Indicates whether to include/exclude nulls during collection
   * @param null_equal Whether null entries within each list are considered equal
   */
  explicit collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE,
                                   null_equality null_equal  = null_equality::EQUAL)
    : derived_aggregation{COLLECT_SET}, _null_handling{null_handling}, _null_equal{null_equal}
  {
  }
  null_policy _null_handling;  ///< include or exclude nulls
  null_equality _null_equal;   ///< whether to consider nulls as equal values

 protected:
  friend class derived_aggregation<collect_set_aggregation>;

  bool operator==(collect_set_aggregation const& other) const
  {
    return _null_handling == other._null_handling && _null_equal == other._null_equal;
  }

  size_t hash_impl() const
  {
    // Give each option its own bit instead of XOR-ing them together: with XOR, the pairs
    // (INCLUDE, EQUAL)/(EXCLUDE, UNEQUAL) and (EXCLUDE, EQUAL)/(INCLUDE, UNEQUAL) would share a
    // hash. Assumes both enums only take the values 0 and 1 -- TODO confirm against their
    // definitions; the hash stays valid (merely less well distributed) if they do not.
    auto const packed_options =
      (static_cast<int>(_null_handling) << 1) | static_cast<int>(_null_equal);
    return std::hash<int>{}(packed_options);
  }
};

/**
* @brief Sentinel value used for `ARGMAX` aggregation.
*
Expand Down Expand Up @@ -514,9 +540,15 @@ struct target_type_impl<Source, aggregation::ROW_NUMBER> {
using type = cudf::size_type;
};

// Always use list for COLLECT
// Always use list for COLLECT_LIST: whatever the Source type is, the target type of this
// aggregation is cudf::list_view.
template <typename Source>
struct target_type_impl<Source, aggregation::COLLECT_LIST> {
using type = cudf::list_view;
};

// Always use list for COLLECT_SET
template <typename Source>
struct target_type_impl<Source, aggregation::COLLECT> {
struct target_type_impl<Source, aggregation::COLLECT_SET> {
using type = cudf::list_view;
};

Expand Down Expand Up @@ -617,8 +649,10 @@ CUDA_HOST_DEVICE_CALLABLE decltype(auto) aggregation_dispatcher(aggregation::Kin
return f.template operator()<aggregation::NTH_ELEMENT>(std::forward<Ts>(args)...);
case aggregation::ROW_NUMBER:
return f.template operator()<aggregation::ROW_NUMBER>(std::forward<Ts>(args)...);
case aggregation::COLLECT:
return f.template operator()<aggregation::COLLECT>(std::forward<Ts>(args)...);
case aggregation::COLLECT_LIST:
return f.template operator()<aggregation::COLLECT_LIST>(std::forward<Ts>(args)...);
case aggregation::COLLECT_SET:
return f.template operator()<aggregation::COLLECT_SET>(std::forward<Ts>(args)...);
case aggregation::LEAD:
return f.template operator()<aggregation::LEAD>(std::forward<Ts>(args)...);
case aggregation::LAG:
Expand Down
38 changes: 38 additions & 0 deletions cpp/include/cudf/lists/detail/drop_list_duplicates.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cudf/lists/lists_column_view.hpp>

#include <rmm/cuda_stream_view.hpp>

namespace cudf {
namespace lists {
namespace detail {

/**
* @copydoc cudf::lists::drop_list_duplicates
*
* This is the detail-level declaration: unlike `mr`, which defaults to the current device
* memory resource, the CUDA stream has no default and must be supplied by the caller.
*
* @param stream CUDA stream used for device memory operations and kernel launches.
*/
std::unique_ptr<column> drop_list_duplicates(
lists_column_view const& lists_column,
null_equality nulls_equal,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
} // namespace detail
} // namespace lists
} // namespace cudf
12 changes: 9 additions & 3 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -125,11 +125,17 @@ std::unique_ptr<aggregation> make_row_number_aggregation()
{
return std::make_unique<aggregation>(aggregation::ROW_NUMBER);
}
/// Factory to create a COLLECT aggregation
std::unique_ptr<aggregation> make_collect_aggregation(null_policy null_handling)
/// Factory to create a COLLECT_LIST aggregation
///
/// Constructs a `detail::collect_list_aggregation` configured with `null_handling` and returns
/// it through a pointer to the `aggregation` base class.
std::unique_ptr<aggregation> make_collect_list_aggregation(null_policy null_handling)
{
  std::unique_ptr<aggregation> collect_list =
    std::make_unique<detail::collect_list_aggregation>(null_handling);
  return collect_list;
}
/// Factory to create a COLLECT_SET aggregation
///
/// Constructs a `detail::collect_set_aggregation` configured with `null_handling` and
/// `null_equal` and returns it through a pointer to the `aggregation` base class.
std::unique_ptr<aggregation> make_collect_set_aggregation(null_policy null_handling,
                                                          null_equality null_equal)
{
  std::unique_ptr<aggregation> collect_set =
    std::make_unique<detail::collect_set_aggregation>(null_handling, null_equal);
  return collect_set;
}
/// Factory to create a LAG aggregation
std::unique_ptr<aggregation> make_lag_aggregation(size_type offset)
{
Expand Down
27 changes: 23 additions & 4 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,15 @@

#include <cudf/aggregation.hpp>
#include <cudf/column/column.hpp>
#include <cudf/column/column_factories.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/aggregation/result_cache.hpp>
#include <cudf/detail/binaryop.hpp>
#include <cudf/detail/gather.hpp>
#include <cudf/detail/groupby.hpp>
#include <cudf/detail/groupby/sort_helper.hpp>
#include <cudf/detail/unary.hpp>
#include <cudf/groupby.hpp>
#include <cudf/lists/detail/drop_list_duplicates.hpp>
#include <cudf/table/table.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
Expand Down Expand Up @@ -57,6 +56,7 @@ struct aggregrate_result_functor final : store_result_functor {
template <aggregation::Kind k>
void operator()(aggregation const& agg)
{
CUDF_FAIL("Unsupported aggregation.");
}
};

Expand Down Expand Up @@ -347,12 +347,12 @@ void aggregrate_result_functor::operator()<aggregation::NTH_ELEMENT>(aggregation
}

template <>
void aggregrate_result_functor::operator()<aggregation::COLLECT>(aggregation const& agg)
void aggregrate_result_functor::operator()<aggregation::COLLECT_LIST>(aggregation const& agg)
{
auto null_handling =
static_cast<cudf::detail::collect_list_aggregation const&>(agg)._null_handling;
CUDF_EXPECTS(null_handling == null_policy::INCLUDE,
"null exclusion is not supported on groupby COLLECT aggregation.");
"null exclusion is not supported on groupby COLLECT_LIST aggregation.");

if (cache.has_result(col_idx, agg)) return;

Expand All @@ -362,6 +362,25 @@ void aggregrate_result_functor::operator()<aggregation::COLLECT>(aggregation con
cache.add_result(col_idx, agg, std::move(result));
};

/**
 * @brief Computes the groupby COLLECT_SET result and stores it in the result cache.
 *
 * The grouped values are first collected into one list per group; duplicate entries are then
 * dropped from each list, using the aggregation's null-equality setting. Only
 * `null_policy::INCLUDE` is accepted. Nothing is recomputed when the cache already holds a
 * result for this column and aggregation.
 *
 * @param agg The aggregation to evaluate; it is cast to `collect_set_aggregation` to read its
 * options.
 */
template <>
void aggregrate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation const& agg)
{
  auto const& set_agg = static_cast<cudf::detail::collect_set_aggregation const&>(agg);
  CUDF_EXPECTS(set_agg._null_handling == null_policy::INCLUDE,
               "null exclusion is not supported on groupby COLLECT_SET aggregation.");

  if (cache.has_result(col_idx, agg)) { return; }

  // Step 1: gather every group's values into a list (duplicates still present).
  // NOTE(review): this intermediate column is allocated with `mr`; confirm whether the
  // default resource is the intended choice for temporaries here.
  auto const collected_lists = detail::group_collect(
    get_grouped_values(), helper.group_offsets(), helper.num_groups(), stream, mr);

  // Step 2: remove the duplicate entries within each list.
  auto distinct_lists = lists::detail::drop_list_duplicates(
    lists_column_view(collected_lists->view()), set_agg._null_equal, stream, mr);

  cache.add_result(col_idx, agg, std::move(distinct_lists));
}
} // namespace detail

// Sort-based groupby
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/lists/drop_list_duplicates.cu
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ void generate_offsets(size_type num_entries,
return offsets[i - prefix_sum_empty_lists[i]];
});
}
} // anonymous namespace

/**
* @copydoc cudf::lists::drop_list_duplicates
*
Expand Down Expand Up @@ -276,7 +278,6 @@ std::unique_ptr<column> drop_list_duplicates(lists_column_view const& lists_colu
cudf::detail::copy_bitmask(lists_column.parent(), stream, mr));
}

} // anonymous namespace
} // namespace detail

/**
Expand Down
19 changes: 10 additions & 9 deletions cpp/src/rolling/rolling_detail.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ template <typename InputType,
std::enable_if_t<!std::is_same<InputType, cudf::string_view>::value and
!(op == aggregation::COUNT_VALID || op == aggregation::COUNT_ALL ||
op == aggregation::ROW_NUMBER || op == aggregation::LEAD ||
op == aggregation::LAG || op == aggregation::COLLECT)>* = nullptr>
op == aggregation::LAG || op == aggregation::COLLECT_LIST)>* = nullptr>
bool __device__ process_rolling_window(column_device_view input,
column_device_view ignored_default_outputs,
mutable_column_device_view output,
Expand Down Expand Up @@ -814,7 +814,7 @@ struct rolling_window_launcher {
typename PrecedingWindowIterator,
typename FollowingWindowIterator>
std::enable_if_t<!(op == aggregation::MEAN || op == aggregation::LEAD || op == aggregation::LAG ||
op == aggregation::COLLECT),
op == aggregation::COLLECT_LIST),
std::unique_ptr<column>>
operator()(column_view const& input,
column_view const& default_outputs,
Expand Down Expand Up @@ -897,11 +897,11 @@ struct rolling_window_launcher {
}

/**
* @brief Creates the offsets child of the result of the `COLLECT` window aggregation
* @brief Creates the offsets child of the result of the `COLLECT_LIST` window aggregation
*
* Given the input column, the preceding/following window bounds, and `min_periods`,
* the sizes of each list row may be computed. These values can then be used to
* calculate the offsets for the result of `COLLECT`.
* calculate the offsets for the result of `COLLECT_LIST`.
*
* Note: If `min_periods` exceeds the number of observations for a window, the size
* is set to `0` (since the result is `null`).
Expand Down Expand Up @@ -945,7 +945,7 @@ struct rolling_window_launcher {
}

/**
* @brief Generate mapping of each row in the COLLECT result's child column
* @brief Generate mapping of each row in the COLLECT_LIST result's child column
* to the index of the row it belongs to.
*
* If
Expand Down Expand Up @@ -1030,7 +1030,7 @@ struct rolling_window_launcher {

/**
* @brief Create gather map to generate the child column of the result of
* the `COLLECT` window aggregation.
* the `COLLECT_LIST` window aggregation.
*/
template <typename PrecedingIter>
std::unique_ptr<column> create_collect_gather_map(column_view const& child_offsets,
Expand Down Expand Up @@ -1064,7 +1064,7 @@ struct rolling_window_launcher {
}

/**
* @brief Count null entries in result of COLLECT.
* @brief Count null entries in result of COLLECT_LIST.
*/
size_type count_child_nulls(column_view const& input,
std::unique_ptr<column> const& gather_map,
Expand Down Expand Up @@ -1139,7 +1139,7 @@ struct rolling_window_launcher {
}

template <aggregation::Kind op, typename PrecedingIter, typename FollowingIter>
std::enable_if_t<(op == aggregation::COLLECT), std::unique_ptr<column>> operator()(
std::enable_if_t<(op == aggregation::COLLECT_LIST), std::unique_ptr<column>> operator()(
column_view const& input,
column_view const& default_outputs,
PrecedingIter preceding_begin_raw,
Expand All @@ -1150,7 +1150,7 @@ struct rolling_window_launcher {
rmm::mr::device_memory_resource* mr)
{
CUDF_EXPECTS(default_outputs.is_empty(),
"COLLECT window function does not support default values.");
"COLLECT_LIST window function does not support default values.");

if (input.is_empty()) return empty_like(input);

Expand Down Expand Up @@ -1370,6 +1370,7 @@ std::unique_ptr<column> rolling_window(column_view const& input,
auto input_col = cudf::is_dictionary(input.type())
? dictionary_column_view(input).get_indices_annotated()
: input;

auto output = cudf::type_dispatcher(input_col.type(),
dispatch_rolling{},
input_col,
Expand Down
Loading

0 comments on commit 30e493c

Please sign in to comment.