Skip to content
1 change: 0 additions & 1 deletion csrc/includes/conversion_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

#include "ds_kernel_utils.h"

#include <cuda_fp16.h>
#include <stdint.h>

#ifdef BF16_AVAILABLE
Expand Down
2 changes: 2 additions & 0 deletions csrc/includes/ds_kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ used throughout the codebase.
constexpr int hw_warp_size = 64;
#define HALF_PRECISION_AVAILABLE = 1
#include <hip/hip_cooperative_groups.h>
#include <hip/hip_fp16.h>

#else // !__HIP_PLATFORM_HCC__

Expand All @@ -37,6 +38,7 @@ constexpr int hw_warp_size = 32;
#endif // __CUDA_ARCH__ >= 800

#include <cooperative_groups.h>
#include <cuda_fp16.h>

#endif //__HIP_PLATFORM_HCC__

Expand Down
76 changes: 44 additions & 32 deletions csrc/includes/reduction_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ of reduce should be straightforward (can just wrap the sum reduction) and
would be a good extension of the header.
*/

/*
Rank of the calling thread's warp within its thread block. The (x, y, z)
thread index is flattened to a linear rank first, so the result is valid
for 1-D, 2-D, and 3-D block shapes alike.
*/
DS_D_INLINE int _warp_rank()
{
    const int flat_tid =
        threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
    return flat_tid / hw_warp_size;
}

/* Float element reduce implementations */
template <>
DS_D_INLINE float element<ROpType::Add>(const float lhs, const float rhs)
Expand Down Expand Up @@ -273,22 +280,34 @@ DS_D_INLINE __half init<ROpType::Max>()
/* Identity element for the Add reduction over packed half2: (+0.0, +0.0). */
template <>
DS_D_INLINE __half2 init<ROpType::Add>()
{
#ifdef __HIP_PLATFORM_HCC__
    // Integer 0x0000 converts to the value 0.0 in each lane, which equals
    // the desired +0.0 bit pattern, so value- vs. bit-initialization agree here.
    return __half2{_Float16_2{0x0000, 0x0000}};
#else
    // 0x0000 is the IEEE-754 half-precision bit pattern for +0.0;
    // __half2_raw reinterprets the bits rather than converting the integer.
    constexpr __half2_raw zero = {0x0000, 0x0000};
    return __half2(zero);
#endif
}

/* Identity element for the Min reduction over packed half2: (+inf, +inf). */
template <>
DS_D_INLINE __half2 init<ROpType::Min>()
{
#ifdef __HIP_PLATFORM_HCC__
    // NOTE(review): the CUDA branch reinterprets 0x7C00 as the +inf bit
    // pattern via __half2_raw, but this brace-init appears to value-convert
    // the integer 0x7C00 (= 31744) to _Float16, yielding 31744.0 rather than
    // +inf — confirm _Float16_2 initialization semantics on HIP.
    return __half2{_Float16_2{0x7C00, 0x7C00}};
#else
    // 0x7C00 is the IEEE-754 half-precision bit pattern for +infinity.
    constexpr __half2_raw inf = {0x7C00, 0x7C00};
    return __half2(inf);
#endif
}

/* Identity element for the Max reduction over packed half2: (-inf, -inf). */
template <>
DS_D_INLINE __half2 init<ROpType::Max>()
{
#ifdef __HIP_PLATFORM_HCC__
    // NOTE(review): the CUDA branch reinterprets 0xFC00 as the -inf bit
    // pattern via __half2_raw, but this brace-init appears to value-convert
    // the integer 0xFC00 (= 64512) to _Float16, yielding +64512.0 rather
    // than -inf, which would break Max reductions on HIP — confirm
    // _Float16_2 initialization semantics.
    return __half2{_Float16_2{0xFC00, 0xFC00}};
#else
    // 0xFC00 is the IEEE-754 half-precision bit pattern for -infinity.
    constexpr __half2_raw neg_inf = {0xFC00, 0xFC00};
    return __half2(neg_inf);
#endif
}

template <ROpType Op, typename T>
Expand Down Expand Up @@ -379,48 +398,46 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
Implementation for primary block reduction that serves both `block` and
`partitioned_block`.

`local_warp_rank` refers to the warp's location within the partition, so
for an unpartitioned threadblock this will be equivalent to
`warp_arg.meta_group_rank()`.

Similarly, the warp offset is the `local_warp_rank` of the warp with the
lowest rank in the partition. In the case of an 8 warp block with a
4 warp reduction, this would map to [0, 0, 0, 0, 4, 4, 4, 4].

Partition size is the number of warps per partition (equal to the thread
block in the default case). This enables us to only perform the warp reduction
when able to.
Total warps refers to the reduction width, not the number of warps in
the block (which may exceed the reduction width if the block is
partitioned or if we use a conservative bound at compile time).
*/
template <int total_warps, ROpType... Ops>
DS_D_INLINE void _block(cg::thread_block& tb,
cg::thread_block_tile<hw_warp_size>& warp_arg,
float* data,
int warp_offset)
float* data)
{
constexpr int elems = sizeof...(Ops);
// Separated for now in case this no longer is true
constexpr int bytes = sizeof(float);
// Unused when `partition_size == 1` or total_warps == 1
__shared__ float reduce_buffer[max_warps * elems];

#ifdef __HIP_PLATFORM_HCC__
const int total_threads = blockDim.x * blockDim.y * blockDim.z;
const int running_warps = total_threads / hw_warp_size;
#else
const int running_warps = warp_arg.meta_group_size();
#endif

// Always perform warp-scope reduction
_warp<Ops...>(warp_arg, data);

// If max_warps == 1 let's skip the runtime check
if (warp_arg.meta_group_size() > 1 && total_warps != 1) {
if (total_warps != 1) {
if (warp_arg.thread_rank() == 0) {
#pragma unroll
for (int i = 0; i < elems; i++) {
mem_access::store_shared<bytes>(
reduce_buffer + elems * warp_arg.meta_group_rank() + i, data + i);
mem_access::store_shared<bytes>(reduce_buffer + elems * _warp_rank() + i, data + i);
}
}

// Synchronization inside block-uniform conditional is safe
tb.sync();

if (warp_arg.meta_group_rank() == 0) {
if (warp_arg.thread_rank() < warp_arg.meta_group_size()) {
if (_warp_rank() == 0) {
if (warp_arg.thread_rank() < running_warps) {
#pragma unroll
for (int i = 0; i < elems; i++) {
mem_access::load_shared<bytes>(
Expand All @@ -444,8 +461,7 @@ DS_D_INLINE void _block(cg::thread_block& tb,

#pragma unroll
for (int i = 0; i < elems; i++) {
mem_access::load_shared<bytes>(data + i,
reduce_buffer + warp_arg.meta_group_rank() * elems + i);
mem_access::load_shared<bytes>(data + i, reduce_buffer + _warp_rank() * elems + i);
}
}
}
Expand All @@ -460,7 +476,7 @@ us to obfuscate the details of the partitioned implementation.
/*
Single-value convenience wrapper over the generalized `_block` reduction:
reduces `val` across the thread block with reduction operator `Op`, writing
the reduced result back into `val`. Mirrors the multi-value overloads below.
*/
template <ROpType Op, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val)
{
    float data[1] = {val};
    _block<warp_bound, Op>(tb, warp, data);
    val = data[0];
}

template <ROpType Op1, ROpType Op2, int warp_bound>
Expand All @@ -470,7 +486,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
float& val2)
{
float data[2] = {val1, val2};
_block<warp_bound, Op1, Op2>(tb, warp, data, 0);
_block<warp_bound, Op1, Op2>(tb, warp, data);
val1 = data[0];
val2 = data[1];
}
Expand All @@ -483,7 +499,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
float& val3)
{
float data[3] = {val1, val2, val3};
_block<warp_bound, Op1, Op2, Op3>(tb, warp, data, 0);
_block<warp_bound, Op1, Op2, Op3>(tb, warp, data);
val1 = data[0];
val2 = data[1];
val3 = data[2];
Expand All @@ -498,7 +514,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
float& val4)
{
float data[4] = {val1, val2, val3, val4};
_block<warp_bound, Op1, Op2, Op3, Op4>(tb, warp, data, 0);
_block<warp_bound, Op1, Op2, Op3, Op4>(tb, warp, data);
val1 = data[0];
val2 = data[1];
val3 = data[2];
Expand All @@ -518,8 +534,7 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
_warp<Op, num_threads>(warp, &val);
} else {
constexpr int num_warps = num_threads / hw_warp_size;
const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
_block<num_warps, Op>(tb, warp, &val, warp_offset);
_block<num_warps, Op>(tb, warp, &val);
}
}

Expand All @@ -535,8 +550,7 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
_warp<Op1, Op2, num_threads>(warp, data);
} else {
constexpr int num_warps = num_threads / hw_warp_size;
const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
_block<num_warps, Op1, Op2>(tb, warp, data, warp_offset);
_block<num_warps, Op1, Op2>(tb, warp, data);
}

val1 = data[0];
Expand All @@ -556,8 +570,7 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
_warp<Op1, Op2, Op3, num_threads>(warp, data);
} else {
constexpr int num_warps = num_threads / hw_warp_size;
const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
_block<num_warps, Op1, Op2, Op3>(tb, warp, data, warp_offset);
_block<num_warps, Op1, Op2, Op3>(tb, warp, data);
}

val1 = data[0];
Expand All @@ -579,8 +592,7 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
_warp<Op1, Op2, Op3, Op4, num_threads>(warp, data);
} else {
constexpr int num_warps = num_threads / hw_warp_size;
const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
_block<num_warps, Op1, Op2, Op3, Op4>(tb, warp, data, warp_offset);
_block<num_warps, Op1, Op2, Op3, Op4>(tb, warp, data);
}

val1 = data[0];
Expand Down