Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add host-bulk for_each for static_map #565

Merged
merged 23 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions include/cuco/detail/open_addressing/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,44 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void erase(InputIt first,
}
}

/**
* @brief For each key in the range [first, first + n), applies the function object `callback_op` to
* the copy of all corresponding matches found in the container.
*
* @tparam CGSize Number of threads in each CG
* @tparam BlockSize Number of threads in each block
* @tparam InputIt Device accessible input iterator whose `value_type` is
* convertible to the `key_type` of the data structure
* @tparam CallbackOp Type of unary callback function object
* @tparam Ref Type of non-owning device ref allowing access to storage
*
* @param first Beginning of the sequence of input elements
* @param n Number of input elements
* @param callback_op Function to call on every element found in the container
PointKernel marked this conversation as resolved.
Show resolved Hide resolved
* @param ref Non-owning container device ref used to access the slot storage
*/
template <int32_t CGSize, int32_t BlockSize, typename InputIt, typename CallbackOp, typename Ref>
CUCO_KERNEL __launch_bounds__(BlockSize) void for_each_n(InputIt first,
cuco::detail::index_type n,
CallbackOp callback_op,
Ref ref)
{
auto const loop_stride = cuco::detail::grid_stride() / CGSize;
auto idx = cuco::detail::global_thread_id() / CGSize;

while (idx < n) {
typename std::iterator_traits<InputIt>::value_type const& key{*(first + idx)};
if constexpr (CGSize == 1) {
ref.for_each(key, callback_op);
} else {
auto const tile =
cooperative_groups::tiled_partition<CGSize>(cooperative_groups::this_thread_block());
ref.for_each(tile, key, callback_op);
}
idx += loop_stride;
}
}

/**
* @brief Indicates whether the keys in the range `[first, first + n)` are contained in the data
* structure if `pred` of the corresponding stencil returns true.
Expand Down
58 changes: 58 additions & 0 deletions include/cuco/detail/open_addressing/open_addressing_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cuco/storage.cuh>
#include <cuco/utility/traits.hpp>

#include <cub/device/device_for.cuh>
#include <cub/device/device_select.cuh>
#include <cuda/atomic>
#include <thrust/iterator/constant_iterator.h>
Expand Down Expand Up @@ -681,6 +682,63 @@ class open_addressing_impl {
return output_begin + h_num_out;
}

/**
* @brief Asynchronously applies the given function object `callback_op` to the copy of every
* filled slot in the container
*
* @tparam CallbackOp Type of unary callback function object
*
* @param callback_op Function to call on every filled slot in the container
* @param stream CUDA stream used for this operation
*/
template <typename CallbackOp>
void for_each_async(CallbackOp&& callback_op, cuda::stream_ref stream) const
{
auto const is_filled = open_addressing_ns::detail::slot_is_filled<has_payload, key_type>{
this->empty_key_sentinel(), this->erased_key_sentinel()};

auto storage_ref = this->storage_ref();
auto const op = [callback_op, is_filled, storage_ref] __device__(auto const window_slots) {
for (auto const slot : window_slots) {
if (is_filled(slot)) { callback_op(slot); }
}
};

CUCO_CUDA_TRY(cub::DeviceFor::ForEachCopyN(
sleeepyjack marked this conversation as resolved.
Show resolved Hide resolved
storage_ref.data(), storage_ref.num_windows(), op, stream.get()));
}

/**
* @brief For each key in the range [first, last), asynchronously applies the function object
* `callback_op` to the copy of all corresponding matches found in the container.
*
* @tparam InputIt Device accessible random access input iterator
* @tparam CallbackOp Type of unary callback function object
* @tparam Ref Type of non-owning device container ref allowing access to storage
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param callback_op Function to call on every match found in the container
* @param container_ref Non-owning device container ref used to access the slot storage
* @param stream CUDA stream used for this operation
*/
template <typename InputIt, typename CallbackOp, typename Ref>
void for_each_async(InputIt first,
InputIt last,
CallbackOp&& callback_op,
Ref container_ref,
cuda::stream_ref stream) const noexcept
{
auto const num_keys = cuco::detail::distance(first, last);
if (num_keys == 0) { return; }

auto const grid_size = cuco::detail::grid_size(num_keys, cg_size);

detail::for_each_n<cg_size, cuco::detail::default_block_size()>
<<<grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
first, num_keys, std::forward<CallbackOp>(callback_op), container_ref);
}

/**
* @brief Gets the number of elements in the container
*
Expand Down
47 changes: 19 additions & 28 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -966,17 +966,14 @@ class open_addressing_ref_impl {
}

/**
* @brief Executes a callback on every element in the container with key equivalent to the probe
* key.
* @brief For a given key, applies the function object `callback_op` to the copy of all
* corresponding matches found in the container.
*
* @note Passes an un-incrementable input iterator to the element whose key is equivalent to
* `key` to the callback.
*
* @tparam ProbeKey Input type which is convertible to 'key_type'
* @tparam CallbackOp Unary callback functor or device lambda
* @tparam ProbeKey Probe key type
* @tparam CallbackOp Type of unary callback function object
*
* @param key The key to search for
* @param callback_op Function to call on every element found
* @param callback_op Function to apply to every match
*/
template <class ProbeKey, class CallbackOp>
__device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept
Expand All @@ -995,7 +992,7 @@ class open_addressing_ref_impl {
return;
}
case detail::equal_result::EQUAL: {
callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]});
callback_op(window_slots[i]);
continue;
}
default: continue;
Expand All @@ -1006,24 +1003,21 @@ class open_addressing_ref_impl {
}

/**
* @brief Executes a callback on every element in the container with key equivalent to the probe
* key.
*
* @note Passes an un-incrementable input iterator to the element whose key is equivalent to
* `key` to the callback.
* @brief For a given key, applies the function object `callback_op` to the copy of all
* corresponding matches found in the container.
*
* @note This function uses cooperative group semantics, meaning that any thread may call the
* callback if it finds a matching element. If multiple elements are found within the same group,
* each thread with a match will call the callback with its associated element.
*
* @note Synchronizing `group` within `callback_op` is undefined behavior.
*
* @tparam ProbeKey Input type which is convertible to 'key_type'
* @tparam CallbackOp Unary callback functor or device lambda
* @tparam ProbeKey Probe key type
* @tparam CallbackOp Type of unary callback function object
*
* @param group The Cooperative Group used to perform this operation
* @param key The key to search for
* @param callback_op Function to call on every element found
* @param callback_op Function to apply to every match
*/
template <class ProbeKey, class CallbackOp>
__device__ void for_each(cooperative_groups::thread_block_tile<cg_size> const& group,
Expand All @@ -1045,7 +1039,7 @@ class open_addressing_ref_impl {
continue;
}
case detail::equal_result::EQUAL: {
callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]});
callback_op(window_slots[i]);
continue;
}
default: {
Expand All @@ -1060,12 +1054,9 @@ class open_addressing_ref_impl {
}

/**
* @brief Executes a callback on every element in the container with key equivalent to the probe
* key and can additionally perform work that requires synchronizing the Cooperative Group
* performing this operation.
*
* @note Passes an un-incrementable input iterator to the element whose key is equivalent to
* `key` to the callback.
* @brief Applies the function object `callback_op` to the copy of every slot in the container
* with key equivalent to the probe key and can additionally perform work that requires
* synchronizing the Cooperative Group performing this operation.
*
* @note This function uses cooperative group semantics, meaning that any thread may call the
* callback if it finds a matching element. If multiple elements are found within the same group,
Expand All @@ -1078,13 +1069,13 @@ class open_addressing_ref_impl {
* synchronization points is capped by `window_size * cg_size`. The functor will be called right
* after the current probing window has been traversed.
*
* @tparam ProbeKey Input type which is convertible to 'key_type'
* @tparam CallbackOp Unary callback functor or device lambda
* @tparam ProbeKey Probe key type
* @tparam CallbackOp Type of unary callback function object
* @tparam SyncOp Functor or device lambda which accepts the current `group` object
*
* @param group The Cooperative Group used to perform this operation
* @param key The key to search for
* @param callback_op Function to call on every element found
* @param callback_op Function to apply to every match
* @param sync_op Function that is allowed to synchronize `group` inbetween probing windows
*/
template <class ProbeKey, class CallbackOp, class SyncOp>
Expand All @@ -1108,7 +1099,7 @@ class open_addressing_ref_impl {
continue;
}
case detail::equal_result::EQUAL: {
callback_op(const_iterator{&(*(this->storage_ref_.data() + *probing_iter))[i]});
callback_op(window_slots[i]);
continue;
}
default: {
Expand Down
64 changes: 64 additions & 0 deletions include/cuco/detail/static_map/static_map.inl
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,70 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
impl_->find_async(first, last, output_begin, ref(op::find), stream);
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename CallbackOp>
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
CallbackOp&& callback_op, cuda::stream_ref stream) const
{
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
stream.wait();
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename CallbackOp>
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each_async(
CallbackOp&& callback_op, cuda::stream_ref stream) const
{
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename CallbackOp>
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const
{
impl_->for_each_async(
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
stream.wait();
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename CallbackOp>
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each_async(
InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const noexcept
{
impl_->for_each_async(
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
}

template <class Key,
class T,
class Extent,
Expand Down
67 changes: 67 additions & 0 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -1249,5 +1249,72 @@ class operator_impl<
}
};

template <typename Key,
typename T,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
class operator_impl<
op::for_each_tag,
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>> {
using base_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef>;
using ref_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
using key_type = typename base_type::key_type;
using value_type = typename base_type::value_type;
using iterator = typename base_type::iterator;
using const_iterator = typename base_type::const_iterator;

static constexpr auto cg_size = base_type::cg_size;
static constexpr auto window_size = base_type::window_size;

public:
/**
* @brief For a given key, applies the function object `callback_op` to its match found in the
* container.
*
* @tparam ProbeKey Probe key type
* @tparam CallbackOp Type of unary callback function object
*
* @param key The key to search for
* @param callback_op Function to apply to the match
*/
template <class ProbeKey, class CallbackOp>
__device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept
{
// CRTP: cast `this` to the actual ref type
auto const& ref_ = static_cast<ref_type const&>(*this);
ref_.impl_.for_each(key, std::forward<CallbackOp>(callback_op));
}

/**
* @brief For a given key, applies the function object `callback_op` to its match found in the
* container.
*
* @note This function uses cooperative group semantics, meaning that any thread may call the
* callback if it finds a matching element. If multiple elements are found within the same group,
* each thread with a match will call the callback with its associated element.
*
* @note Synchronizing `group` within `callback_op` is undefined behavior.
*
* @tparam ProbeKey Probe key type
* @tparam CallbackOp Type of unary callback function object
*
* @param group The Cooperative Group used to perform this operation
* @param key The key to search for
* @param callback_op Function to apply to the match
*/
template <class ProbeKey, class CallbackOp>
__device__ void for_each(cooperative_groups::thread_block_tile<cg_size> const& group,
ProbeKey const& key,
CallbackOp&& callback_op) const noexcept
{
// CRTP: cast `this` to the actual ref type
auto const& ref_ = static_cast<ref_type const&>(*this);
ref_.impl_.for_each(group, key, std::forward<CallbackOp>(callback_op));
}
};

} // namespace detail
} // namespace cuco
Loading
Loading