Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EC/ROCM: Prod overload issue for HIP complex -v1.2.x #783

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 95 additions & 5 deletions src/components/ec/rocm/kernel/ec_rocm_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "ec_rocm.h"
#include "utils/ucc_math_op.h"
#include <inttypes.h>
#include <hip/hip_complex.h>

#define ROCM_REDUCE_WITH_OP_DEFAULT(NAME, _OP) \
template <typename _Type, typename _AlphaType> \
Expand Down Expand Up @@ -54,6 +55,41 @@
} \
}

/*
 * Defines the kernel template UCC_REDUCE_ROCM_DEFAULT_COMPLEX_<NAME>.
 *
 * The kernel combines task.n_srcs source buffers element-wise into task.dst
 * by applying the binary function _OP (instantiated below in this file with
 * hipCmul / hipCmulf) left to right: _OP(_OP(s[0], s[1]), s[2]) ...
 *
 * Template parameters of the generated kernel:
 *   _Type      - element type used to read task.srcs and write task.dst
 *   _AlphaType - type task.alpha is cast to before the optional scaling
 *
 * Kernel arguments:
 *   task  - reduce descriptor; count, n_srcs, srcs, dst and alpha are read
 *   flags - when UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA is set, every result
 *           element is additionally multiplied by alpha
 *
 * Elements are visited with a grid-stride loop, so any 1-D launch
 * configuration covers all task.count elements. The first two sources are
 * read unconditionally.
 */
#define ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(NAME, _OP)                    \
    template <typename _Type, typename _AlphaType>                             \
    __global__ void UCC_REDUCE_ROCM_DEFAULT_COMPLEX_##NAME(                    \
        ucc_eee_task_reduce_t task, uint16_t flags)                            \
    {                                                                          \
        /* Widen to size_t BEFORE multiplying: the built-in index values */    \
        /* are 32-bit, so their product could otherwise wrap around.     */    \
        size_t        start  = (size_t)blockIdx.x * blockDim.x + threadIdx.x;  \
        size_t        step   = (size_t)blockDim.x * gridDim.x;                 \
        size_t        count  = task.count;                                     \
        int           n_srcs = task.n_srcs;                                    \
        const _Type **s      = (const _Type **)task.srcs;                      \
        _Type *       d      = (_Type *)task.dst;                              \
        size_t        i;                                                       \
                                                                               \
        switch (n_srcs) {                                                      \
        case 2:                                                                \
            /* common case: a single application per element */                \
            for (i = start; i < count; i += step) {                            \
                d[i] = _OP(s[0][i], s[1][i]);                                  \
            }                                                                  \
            break;                                                             \
        default:                                                               \
            for (i = start; i < count; i += step) {                            \
                d[i] = _OP(s[0][i], s[1][i]);                                  \
                /* int index matches the signedness of n_srcs */               \
                for (int j = 2; j < n_srcs; j++) {                             \
                    d[i] = _OP(d[i], s[j][i]);                                 \
                }                                                              \
            }                                                                  \
            break;                                                             \
        }                                                                      \
        if (flags & UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA) {                     \
            /* each thread rescales exactly the elements it wrote above */     \
            for (i = start; i < count; i += step) {                            \
                d[i] = d[i] * (_AlphaType)task.alpha;                          \
            }                                                                  \
        }                                                                      \
    }

#define ROCM_REDUCE_WITH_OP_STRIDED(NAME, _OP) \
template <typename _Type, typename _AlphaType> \
__global__ void UCC_REDUCE_ROCM_STRIDED_##NAME( \
Expand Down Expand Up @@ -99,8 +135,45 @@
} \
}

/*
 * Defines the kernel template UCC_REDUCE_ROCM_STRIDED_COMPLEX_<NAME>.
 *
 * For every element index i the kernel computes
 *   d[i] = _OP(... _OP(_OP(s1[i], s2[i]), s2[i + 1 * ld]) ..., s2[i + k * ld])
 * with k = n_src2 - 1 and ld = stride / sizeof(_Type), i.e. the additional
 * operands are read from s2 at a fixed distance of ld elements each.
 *
 * Template parameters of the generated kernel:
 *   _Type      - element type of s1, s2 and d
 *   _AlphaType - type alpha is cast to before the optional scaling
 *
 * Kernel arguments:
 *   s1, s2     - input buffers (s2 is indexed with the stride)
 *   d          - output buffer
 *   count      - number of elements to produce
 *   stride     - distance in BYTES between consecutive operands in s2;
 *                asserted to be a multiple of sizeof(_Type)
 *   n_src2     - number of operands taken from s2
 *   with_alpha - when true, every result element is multiplied by alpha
 *   alpha      - scaling factor used when with_alpha is true
 *
 * Elements are visited with a grid-stride loop, so any 1-D launch
 * configuration covers all count elements. s2[i] is read unconditionally.
 */
#define ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(NAME, _OP)                    \
    template <typename _Type, typename _AlphaType>                             \
    __global__ void UCC_REDUCE_ROCM_STRIDED_COMPLEX_##NAME(                    \
        const _Type *s1, const _Type *s2, _Type *d, size_t count,              \
        size_t stride, uint16_t n_src2, const bool with_alpha,                 \
        const double alpha)                                                    \
    {                                                                          \
        /* Widen to size_t BEFORE multiplying: the built-in index values */    \
        /* are 32-bit, so their product could otherwise wrap around.     */    \
        size_t start = (size_t)blockIdx.x * blockDim.x + threadIdx.x;          \
        size_t step  = (size_t)blockDim.x * gridDim.x;                         \
        size_t ld    = stride / sizeof(_Type);                                 \
        size_t i;                                                              \
                                                                               \
        ucc_assert_system(stride % sizeof(_Type) == 0);                        \
        switch (n_src2) {                                                      \
        case 1:                                                                \
            /* common case: a single application per element */                \
            for (i = start; i < count; i += step) {                            \
                d[i] = _OP(s1[i], s2[i]);                                      \
            }                                                                  \
            break;                                                             \
        default:                                                               \
            for (i = start; i < count; i += step) {                            \
                d[i] = _OP(s1[i], s2[i]);                                      \
                for (size_t j = 1; j < n_src2; j++) {                          \
                    d[i] = _OP(d[i], s2[i + j * ld]);                          \
                }                                                              \
            }                                                                  \
            break;                                                             \
        }                                                                      \
        if (with_alpha) {                                                      \
            /* each thread rescales exactly the elements it wrote above */     \
            for (i = start; i < count; i += step) {                            \
                d[i] = d[i] * (_AlphaType)alpha;                               \
            }                                                                  \
        }                                                                      \
    }

/*
 * Define the non-strided ("default") reduction kernel templates, one per
 * reduction operator. PROD_DOUBLE and PROD_FLOAT are the complex-product
 * variants, which multiply through hipCmul and hipCmulf respectively instead
 * of the generic DO_OP_PROD.
 */
ROCM_REDUCE_WITH_OP_DEFAULT(SUM, DO_OP_SUM);
ROCM_REDUCE_WITH_OP_DEFAULT(PROD, DO_OP_PROD);
ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(PROD_DOUBLE, hipCmul);
ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(PROD_FLOAT, hipCmulf);
ROCM_REDUCE_WITH_OP_DEFAULT(MIN, DO_OP_MIN);
ROCM_REDUCE_WITH_OP_DEFAULT(MAX, DO_OP_MAX);
ROCM_REDUCE_WITH_OP_DEFAULT(LAND, DO_OP_LAND);
Expand All @@ -112,6 +185,8 @@ ROCM_REDUCE_WITH_OP_DEFAULT(BXOR, DO_OP_BXOR);

/*
 * Define the strided reduction kernel templates, one per reduction operator.
 * PROD_DOUBLE and PROD_FLOAT are the complex-product variants, which multiply
 * through hipCmul and hipCmulf respectively instead of the generic
 * DO_OP_PROD.
 */
ROCM_REDUCE_WITH_OP_STRIDED(SUM, DO_OP_SUM);
ROCM_REDUCE_WITH_OP_STRIDED(PROD, DO_OP_PROD);
ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(PROD_DOUBLE, hipCmul);
ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(PROD_FLOAT, hipCmulf);
ROCM_REDUCE_WITH_OP_STRIDED(MIN, DO_OP_MIN);
ROCM_REDUCE_WITH_OP_STRIDED(MAX, DO_OP_MAX);
ROCM_REDUCE_WITH_OP_STRIDED(LAND, DO_OP_LAND);
Expand All @@ -136,6 +211,21 @@ ROCM_REDUCE_WITH_OP_STRIDED(BXOR, DO_OP_BXOR);
} \
} while (0)

/*
 * Launches the complex-product kernel selected by NAME on stream s with
 * b blocks of t threads and no dynamic shared memory.
 *
 * A task whose task_type is UCC_EE_EXECUTOR_TASK_REDUCE is dispatched to
 * UCC_REDUCE_ROCM_DEFAULT_COMPLEX_<NAME>; every other task type is treated as
 * a strided reduce and dispatched to UCC_REDUCE_ROCM_STRIDED_COMPLEX_<NAME>
 * with the fields of _task->reduce_strided unpacked into kernel arguments.
 */
#define LAUNCH_KERNEL_B(NAME, type, _AlphaType, _task, s, b, t)                \
    do {                                                                       \
        if (_task->task_type != UCC_EE_EXECUTOR_TASK_REDUCE) {                 \
            ucc_eee_task_reduce_strided_t *strided_args =                      \
                &_task->reduce_strided;                                        \
            /* collapse the flag bit into the bool the kernel expects */       \
            const bool scale_by_alpha =                                        \
                (bool)(_task->flags & UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA);    \
            UCC_REDUCE_ROCM_STRIDED_COMPLEX_##NAME<type, _AlphaType>           \
                <<<b, t, 0, s>>>((type *)strided_args->src1,                   \
                                 (type *)strided_args->src2,                   \
                                 (type *)strided_args->dst,                    \
                                 strided_args->count, strided_args->stride,    \
                                 strided_args->n_src2, scale_by_alpha,         \
                                 strided_args->alpha);                         \
        } else {                                                               \
            UCC_REDUCE_ROCM_DEFAULT_COMPLEX_##NAME<type, _AlphaType>           \
                <<<b, t, 0, s>>>(_task->reduce, _task->flags);                 \
        }                                                                      \
    } while (0)

/*
 * Shorthand for LAUNCH_KERNEL_A that forwards `type` twice, so the same type
 * is used for both the second and third LAUNCH_KERNEL_A arguments
 * (presumably element type and alpha type, as in LAUNCH_KERNEL_B - confirm
 * against the LAUNCH_KERNEL_A definition).
 */
#define LAUNCH_KERNEL(NAME, type, _task, s, b, t) \
LAUNCH_KERNEL_A(NAME, type, type, _task, s, b, t)

Expand Down Expand Up @@ -207,15 +297,15 @@ ROCM_REDUCE_WITH_OP_STRIDED(BXOR, DO_OP_BXOR);
} \
} while (0)

#define DT_REDUCE_FLOAT_COMPLEX(type, _alphaType, _task, _op, s, b, t) \
#define DT_REDUCE_FLOAT_COMPLEX(NAME, type, _alphaType, _task, _op, s, b, t) \
do { \
switch (_op) { \
case UCC_OP_AVG: \
case UCC_OP_SUM: \
LAUNCH_KERNEL_A(SUM, type , _alphaType, _task, s, b, t); \
LAUNCH_KERNEL_A(SUM, type , _alphaType, _task, s, b, t); \
break; \
case UCC_OP_PROD: \
LAUNCH_KERNEL_A(PROD, type, _alphaType, _task, s, b, t); \
LAUNCH_KERNEL_B(NAME, type, _alphaType, _task, s, b, t); \
break; \
default: \
ec_error(&ucc_ec_rocm.super, \
Expand Down Expand Up @@ -299,10 +389,10 @@ ucc_status_t ucc_ec_rocm_reduce(ucc_ee_executor_task_args_t *task,
return UCC_ERR_NOT_SUPPORTED;
#endif
case UCC_DT_FLOAT32_COMPLEX:
DT_REDUCE_FLOAT_COMPLEX(hipFloatComplex, float, task, op, stream, bk, th);
DT_REDUCE_FLOAT_COMPLEX(PROD_FLOAT, hipFloatComplex, float, task, op, stream, bk, th);
break;
case UCC_DT_FLOAT64_COMPLEX:
DT_REDUCE_FLOAT_COMPLEX(hipDoubleComplex, double, task, op, stream, bk, th);
DT_REDUCE_FLOAT_COMPLEX(PROD_DOUBLE, hipDoubleComplex, double, task, op, stream, bk, th);
break;
case UCC_DT_BFLOAT16:
ucc_assert(2 == sizeof(hip_bfloat16));
Expand Down