Skip to content

Commit

Permalink
Use cooperative-groups instead of cub warp-reduce for strings contains (
Browse files Browse the repository at this point in the history
#17540)

Replaces the `cub::WarpReduce` usage in `cudf::strings::contains` with cooperative-groups `any()`.
The change is only for the `contains_warp_parallel` kernel which is used for wider strings.
Using cooperative-groups generates more efficient code for the same results and gives an additional 11-14% performance improvement.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Yunsong Wang (https://github.com/PointKernel)
  - Nghia Truong (https://github.com/ttnghia)
  - Shruti Shivakumar (https://github.com/shrshi)

URL: #17540
  • Loading branch information
davidwendt authored Dec 9, 2024
1 parent 80fc629 commit a0fc6a8
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions cpp/src/strings/search/find.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>

#include <cooperative_groups.h>
#include <cuda/atomic>
#include <thrust/binary_search.h>
#include <thrust/fill.h>
Expand Down Expand Up @@ -347,13 +348,15 @@ CUDF_KERNEL void contains_warp_parallel_fn(column_device_view const d_strings,
string_view const d_target,
bool* d_results)
{
auto const idx = cudf::detail::grid_1d::global_thread_id();
using warp_reduce = cub::WarpReduce<bool>;
__shared__ typename warp_reduce::TempStorage temp_storage;
auto const idx = cudf::detail::grid_1d::global_thread_id();

auto const str_idx = idx / cudf::detail::warp_size;
if (str_idx >= d_strings.size()) { return; }
auto const lane_idx = idx % cudf::detail::warp_size;

namespace cg = cooperative_groups;
auto const warp = cg::tiled_partition<cudf::detail::warp_size>(cg::this_thread_block());
auto const lane_idx = warp.thread_rank();

if (d_strings.is_null(str_idx)) { return; }
// get the string for this warp
auto const d_str = d_strings.element<string_view>(str_idx);
Expand All @@ -373,7 +376,7 @@ CUDF_KERNEL void contains_warp_parallel_fn(column_device_view const d_strings,
}
}

auto const result = warp_reduce(temp_storage).Reduce(found, cub::Max());
auto const result = warp.any(found);
if (lane_idx == 0) { d_results[str_idx] = result; }
}

Expand Down

0 comments on commit a0fc6a8

Please sign in to comment.