Skip to content

Commit ecbe8f8

Browse files
elstehle authored and davebayer committed
Adds support for large number of segments to DeviceSegmentedReduce (NVIDIA#3764)
* add support for large num segments on device level
* adds support for large number of segments on dispatch
* refactors offset iterator
* add tests for large number of segments
* fixes style
* renames offset iterator to snake case
* rely on ctad instead of factory function
* adds tests for more device interfaces
* use offset_input_iterator where applicable
* [skip-ci] addresses review comments
* fixes msvc implicit conversion warning
* drops debug print utilities
* removes argmin/max wrappers
* fixes style
* fixes include order
* fixes nvrtc
* expects user iterators to be advanceable on the host
* drops redundant include
* adds workaround for c.parallel indirect_arg_t
* adds todo
* uses cuda::std traits
* adds missing exec space specifiers
1 parent 255ea1a commit ecbe8f8

File tree

5 files changed

+359
-34
lines changed

5 files changed

+359
-34
lines changed

cub/cub/device/device_segmented_reduce.cuh

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ private:
9191
size_t& temp_storage_bytes,
9292
InputIteratorT d_in,
9393
OutputIteratorT d_out,
94-
int num_segments,
94+
::cuda::std::int64_t num_segments,
9595
BeginOffsetIteratorT d_begin_offsets,
9696
EndOffsetIteratorT d_end_offsets,
9797
ReductionOpT reduction_op,
@@ -112,7 +112,7 @@ private:
112112
size_t& temp_storage_bytes,
113113
InputIteratorT d_in,
114114
OutputIteratorT d_out,
115-
int num_segments,
115+
::cuda::std::int64_t num_segments,
116116
BeginOffsetIteratorT d_begin_offsets,
117117
EndOffsetIteratorT d_end_offsets,
118118
ReductionOpT reduction_op,
@@ -243,7 +243,7 @@ public:
243243
size_t& temp_storage_bytes,
244244
InputIteratorT d_in,
245245
OutputIteratorT d_out,
246-
int num_segments,
246+
::cuda::std::int64_t num_segments,
247247
BeginOffsetIteratorT d_begin_offsets,
248248
EndOffsetIteratorT d_end_offsets,
249249
ReductionOpT reduction_op,
@@ -355,7 +355,7 @@ public:
355355
size_t& temp_storage_bytes,
356356
InputIteratorT d_in,
357357
OutputIteratorT d_out,
358-
int num_segments,
358+
::cuda::std::int64_t num_segments,
359359
BeginOffsetIteratorT d_begin_offsets,
360360
EndOffsetIteratorT d_end_offsets,
361361
cudaStream_t stream = 0)
@@ -478,7 +478,7 @@ public:
478478
size_t& temp_storage_bytes,
479479
InputIteratorT d_in,
480480
OutputIteratorT d_out,
481-
int num_segments,
481+
::cuda::std::int64_t num_segments,
482482
BeginOffsetIteratorT d_begin_offsets,
483483
EndOffsetIteratorT d_end_offsets,
484484
cudaStream_t stream = 0)
@@ -605,7 +605,7 @@ public:
605605
size_t& temp_storage_bytes,
606606
InputIteratorT d_in,
607607
OutputIteratorT d_out,
608-
int num_segments,
608+
::cuda::std::int64_t num_segments,
609609
BeginOffsetIteratorT d_begin_offsets,
610610
EndOffsetIteratorT d_end_offsets,
611611
cudaStream_t stream = 0)
@@ -744,7 +744,7 @@ public:
744744
size_t& temp_storage_bytes,
745745
InputIteratorT d_in,
746746
OutputIteratorT d_out,
747-
int num_segments,
747+
::cuda::std::int64_t num_segments,
748748
BeginOffsetIteratorT d_begin_offsets,
749749
EndOffsetIteratorT d_end_offsets,
750750
cudaStream_t stream = 0)
@@ -869,7 +869,7 @@ public:
869869
size_t& temp_storage_bytes,
870870
InputIteratorT d_in,
871871
OutputIteratorT d_out,
872-
int num_segments,
872+
::cuda::std::int64_t num_segments,
873873
BeginOffsetIteratorT d_begin_offsets,
874874
EndOffsetIteratorT d_end_offsets,
875875
cudaStream_t stream = 0)

cub/cub/device/dispatch/dispatch_common.cuh

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55

66
#include <cub/config.cuh>
77

8+
#include <cuda/std/type_traits>
9+
10+
#include "cuda/std/__cccl/execution_space.h"
11+
812
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
913
# pragma GCC system_header
1014
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
@@ -40,4 +44,51 @@ enum class SelectImpl
4044
Partition
4145
};
4246

47+
namespace detail
{
// Detection trait for the expression `declval<T>() + declval<U>()`.
// Primary template: chosen when that expression is ill-formed, so the trait is false.
template <typename T, typename U, typename = void>
struct has_plus_operator : ::cuda::std::false_type
{};

// Partial specialization: chosen (via void_t) when the expression is well-formed, so the trait is true.
template <typename T, typename U>
struct has_plus_operator<T, U, ::cuda::std::void_t<decltype(::cuda::std::declval<T>() + ::cuda::std::declval<U>())>>
    : ::cuda::std::true_type
{};

// Convenience variable template for has_plus_operator<T, U>::value.
template <typename T, typename U>
constexpr bool has_plus_operator_v = has_plus_operator<T, U>::value;

// Returns `iter + offset` when IteratorT can be added to OffsetT; otherwise returns `iter` unchanged.
// The choice is made at compile time, so the unsupported `iter + offset` is never instantiated.
template <typename IteratorT, typename OffsetT>
_CCCL_HOST_DEVICE IteratorT advance_iterators_if_supported(IteratorT iter, OffsetT offset)
{
  if constexpr (!has_plus_operator_v<IteratorT, OffsetT>)
  {
    // No usable operator+: hand back the iterator as-is (offset is intentionally unused here).
    static_cast<void>(offset);
    return iter;
  }
  else
  {
    // operator+ is available: advance by the requested offset.
    return iter + offset;
  }
}

// Reports whether every one of the given iterator types can be added to OffsetT.
// Only the argument types are inspected; the argument values are ignored.
template <typename OffsetT, typename... Iterators>
_CCCL_HOST_DEVICE bool all_iterators_support_plus_operator(OffsetT /*offset*/, Iterators... /*iters*/)
{
  // Fold over all iterator types; an empty pack folds to true.
  constexpr bool every_iterator_is_advanceable = (has_plus_operator_v<Iterators, OffsetT> && ...);
  return every_iterator_is_advanceable;
}

} // namespace detail
93+
4394
CUB_NAMESPACE_END

cub/cub/device/dispatch/dispatch_reduce.cuh

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747
#include <cub/detail/launcher/cuda_runtime.cuh>
4848
#include <cub/detail/type_traits.cuh> // for cub::detail::invoke_result_t
49+
#include <cub/device/dispatch/dispatch_common.cuh>
4950
#include <cub/device/dispatch/kernels/reduce.cuh>
5051
#include <cub/device/dispatch/kernels/segmented_reduce.cuh>
5152
#include <cub/device/dispatch/tuning/tuning_reduce.cuh>
@@ -706,7 +707,7 @@ struct DispatchSegmentedReduce
706707
OutputIteratorT d_out;
707708

708709
/// The number of segments that comprise the segmented reduction data
709-
int num_segments;
710+
::cuda::std::int64_t num_segments;
710711

711712
/// Random-access input iterator to the sequence of beginning offsets of
712713
/// length `num_segments`, such that `d_begin_offsets[i]` is the first
@@ -747,7 +748,7 @@ struct DispatchSegmentedReduce
747748
size_t& temp_storage_bytes,
748749
InputIteratorT d_in,
749750
OutputIteratorT d_out,
750-
int num_segments,
751+
::cuda::std::int64_t num_segments,
751752
BeginOffsetIteratorT d_begin_offsets,
752753
EndOffsetIteratorT d_end_offsets,
753754
ReductionOpT reduction_op,
@@ -813,33 +814,61 @@ struct DispatchSegmentedReduce
813814
break;
814815
}
815816

816-
// Log device_reduce_sweep_kernel configuration
817-
#ifdef CUB_DEBUG_LOG
818-
_CubLog("Invoking SegmentedDeviceReduceKernel<<<%d, %d, 0, %lld>>>(), "
819-
"%d items per thread, %d SM occupancy\n",
820-
num_segments,
821-
policy.SegmentedReduce().BlockThreads(),
822-
(long long) stream,
823-
policy.SegmentedReduce().ItemsPerThread(),
824-
segmented_reduce_config.sm_occupancy);
825-
#endif // CUB_DEBUG_LOG
826-
827-
// Invoke DeviceReduceKernel
828-
launcher_factory(num_segments, policy.SegmentedReduce().BlockThreads(), 0, stream)
829-
.doit(segmented_reduce_kernel, d_in, d_out, d_begin_offsets, d_end_offsets, num_segments, reduction_op, init);
817+
const auto num_segments_per_invocation =
818+
static_cast<::cuda::std::int64_t>(::cuda::std::numeric_limits<::cuda::std::int32_t>::max());
819+
const ::cuda::std::int64_t num_invocations = ::cuda::ceil_div(num_segments, num_segments_per_invocation);
830820

831-
// Check for failure to launch
832-
error = CubDebug(cudaPeekAtLastError());
833-
if (cudaSuccess != error)
821+
// If we need multiple passes over the segments but the iterators do not support the + operator, we cannot use the
822+
// streaming approach and have to fail, returning cudaErrorInvalidValue. This is because c.parallel passes
823+
// indirect_arg_t as the iterator type, which does not support the + operator.
824+
// TODO (elstehle): Remove this check once https://github.com/NVIDIA/cccl/issues/4148 is resolved.
825+
if (num_invocations > 1
826+
&& !detail::all_iterators_support_plus_operator(::cuda::std::int64_t{}, d_out, d_begin_offsets, d_end_offsets))
834827
{
835-
break;
828+
return cudaErrorInvalidValue;
836829
}
837830

838-
// Sync the stream if specified to flush runtime errors
839-
error = CubDebug(detail::DebugSyncStream(stream));
840-
if (cudaSuccess != error)
831+
for (::cuda::std::int64_t invocation_index = 0; invocation_index < num_invocations; invocation_index++)
841832
{
842-
break;
833+
const auto current_seg_offset = invocation_index * num_segments_per_invocation;
834+
const auto num_current_segments =
835+
::cuda::std::min(num_segments_per_invocation, num_segments - current_seg_offset);
836+
837+
// Log the segmented reduce kernel launch configuration
838+
#ifdef CUB_DEBUG_LOG
839+
_CubLog("Invoking SegmentedDeviceReduceKernel<<<%ld, %d, 0, %lld>>>(), "
840+
"%d items per thread, %d SM occupancy\n",
841+
num_current_segments,
842+
policy.SegmentedReduce().BlockThreads(),
843+
(long long) stream,
844+
policy.SegmentedReduce().ItemsPerThread(),
845+
segmented_reduce_config.sm_occupancy);
846+
#endif // CUB_DEBUG_LOG
847+
848+
// Invoke the segmented reduce kernel
849+
launcher_factory(
850+
static_cast<::cuda::std::uint32_t>(num_current_segments), policy.SegmentedReduce().BlockThreads(), 0, stream)
851+
.doit(segmented_reduce_kernel,
852+
d_in,
853+
detail::advance_iterators_if_supported(d_out, current_seg_offset),
854+
detail::advance_iterators_if_supported(d_begin_offsets, current_seg_offset),
855+
detail::advance_iterators_if_supported(d_end_offsets, current_seg_offset),
856+
reduction_op,
857+
init);
858+
859+
// Check for failure to launch
860+
error = CubDebug(cudaPeekAtLastError());
861+
if (cudaSuccess != error)
862+
{
863+
break;
864+
}
865+
866+
// Sync the stream if specified to flush runtime errors
867+
error = CubDebug(detail::DebugSyncStream(stream));
868+
if (cudaSuccess != error)
869+
{
870+
break;
871+
}
843872
}
844873
} while (0);
845874

@@ -908,7 +937,7 @@ struct DispatchSegmentedReduce
908937
size_t& temp_storage_bytes,
909938
InputIteratorT d_in,
910939
OutputIteratorT d_out,
911-
int num_segments,
940+
::cuda::std::int64_t num_segments,
912941
BeginOffsetIteratorT d_begin_offsets,
913942
EndOffsetIteratorT d_end_offsets,
914943
ReductionOpT reduction_op,

cub/cub/device/dispatch/kernels/segmented_reduce.cuh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::ReducePolicy::BLOCK_THREADS)
132132
OutputIteratorT d_out,
133133
BeginOffsetIteratorT d_begin_offsets,
134134
EndOffsetIteratorT d_end_offsets,
135-
int /*num_segments*/,
136135
ReductionOpT reduction_op,
137136
InitT init)
138137
{

0 commit comments

Comments
 (0)