Add support for AMD GPU #6161

Merged · 7 commits · May 24, 2023
94 changes: 46 additions & 48 deletions monai/_extensions/gmm/gmm_cuda.cu
@@ -21,6 +21,13 @@ limitations under the License.
 #define EPSILON 1e-5
 #define BLOCK_SIZE 32
 #define TILE(SIZE, STRIDE) ((((SIZE)-1) / (STRIDE)) + 1)
+#ifdef __HIP_PLATFORM_AMD__
+#define __SHFL_DOWN(a, b) __shfl_down(a, b)
+#define __SHFL_XOR(a, b) __shfl_xor(a, b)
+#else
+#define __SHFL_DOWN(a, b) __shfl_down_sync(0xffffffff, a, b)
+#define __SHFL_XOR(a, b) __shfl_xor_sync(0xffffffff, a, b)
+#endif
 
 template <int warp_count, int load_count>
 __global__ void CovarianceReductionKernel(
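
Note on the macros above: CUDA 9 replaced the original warp-shuffle intrinsics with *_sync variants that take a member mask, but ROCm/HIP provides only the mask-free forms, so the #ifdef __HIP_PLATFORM_AMD__ block dispatches to whichever the compiler supports. Below is a minimal standalone sketch of the same pattern (the kernel and variable names are illustrative, not from this PR; on ROCm the cudaMalloc/cudaMemcpy calls would be their hipified equivalents):

// warp_sum.cu - sketch of the __SHFL_DOWN portability macro in isolation.
#include <cstdio>

#ifdef __HIP_PLATFORM_AMD__
#define __SHFL_DOWN(a, b) __shfl_down(a, b)
#else
#define __SHFL_DOWN(a, b) __shfl_down_sync(0xffffffff, a, b)
#endif

// Sums 32 floats with one logical warp; lane 0 ends up with the total.
__global__ void warp_sum_kernel(const float* in, float* out) {
  float value = in[threadIdx.x];
  // Each step folds the upper half of the active lanes onto the lower half.
  for (int offset = 16; offset > 0; offset >>= 1) {
    value += __SHFL_DOWN(value, offset);
  }
  if (threadIdx.x == 0) {
    *out = value;
  }
}

int main() {
  float h_in[32], h_out = 0.0f;
  for (int i = 0; i < 32; ++i) {
    h_in[i] = 1.0f;
  }
  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  warp_sum_kernel<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", h_out); // expect 32.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}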
@@ -82,13 +89,11 @@ __global__ void CovarianceReductionKernel(
 
   for (int i = 0; i < MATRIX_COMPONENT_COUNT; i++) {
     float matrix_component = matrix[i];
-
-    matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
-    matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
-    matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
-    matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
-    matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
-
+    matrix_component += __SHFL_DOWN(matrix_component, 16);
+    matrix_component += __SHFL_DOWN(matrix_component, 8);
+    matrix_component += __SHFL_DOWN(matrix_component, 4);
+    matrix_component += __SHFL_DOWN(matrix_component, 2);
+    matrix_component += __SHFL_DOWN(matrix_component, 1);
     if (lane_index == 0) {
       s_matrix_component[warp_index] = matrix_component;
     }
@@ -97,23 +102,21 @@
 
   if (warp_index == 0) {
     matrix_component = s_matrix_component[lane_index];
-
     if (warp_count >= 32) {
-      matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
+      matrix_component += __SHFL_DOWN(matrix_component, 16);
     }
     if (warp_count >= 16) {
-      matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
+      matrix_component += __SHFL_DOWN(matrix_component, 8);
     }
     if (warp_count >= 8) {
-      matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
+      matrix_component += __SHFL_DOWN(matrix_component, 4);
     }
     if (warp_count >= 4) {
-      matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
+      matrix_component += __SHFL_DOWN(matrix_component, 2);
     }
     if (warp_count >= 2) {
-      matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
+      matrix_component += __SHFL_DOWN(matrix_component, 1);
     }
-
     if (lane_index == 0) {
       g_batch_matrices[matrix_offset + i] = matrix_component;
     }
@@ -156,13 +159,11 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g
       matrix_component += g_batch_matrices[(matrix_offset + matrix_index) * GMM_COMPONENT_COUNT + index];
     }
   }
-
-  matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
-  matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
-  matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
-  matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
-  matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
-
+  matrix_component += __SHFL_DOWN(matrix_component, 16);
+  matrix_component += __SHFL_DOWN(matrix_component, 8);
+  matrix_component += __SHFL_DOWN(matrix_component, 4);
+  matrix_component += __SHFL_DOWN(matrix_component, 2);
+  matrix_component += __SHFL_DOWN(matrix_component, 1);
   if (lane_index == 0) {
     s_matrix_component[warp_index] = matrix_component;
   }
@@ -171,23 +172,21 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g
 
   if (warp_index == 0) {
     matrix_component = s_matrix_component[lane_index];
-
     if (warp_count >= 32) {
-      matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
+      matrix_component += __SHFL_DOWN(matrix_component, 16);
     }
     if (warp_count >= 16) {
-      matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
+      matrix_component += __SHFL_DOWN(matrix_component, 8);
     }
     if (warp_count >= 8) {
-      matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
+      matrix_component += __SHFL_DOWN(matrix_component, 4);
     }
     if (warp_count >= 4) {
-      matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
+      matrix_component += __SHFL_DOWN(matrix_component, 2);
     }
     if (warp_count >= 2) {
-      matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
+      matrix_component += __SHFL_DOWN(matrix_component, 1);
     }
-
     if (lane_index == 0) {
       float constant = i == 0 ? 0.0f : s_gmm[i] * s_gmm[j];
 
@@ -261,13 +260,11 @@ __global__ void GMMFindSplit(GMMSplit_t* gmmSplit, int gmmK, float* gmm) {
   }
 
   float max_value = eigenvalue;
-
-  max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
-  max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
-  max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
-  max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
-  max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
-
+  max_value = max(max_value, __SHFL_XOR(max_value, 16));
+  max_value = max(max_value, __SHFL_XOR(max_value, 8));
+  max_value = max(max_value, __SHFL_XOR(max_value, 4));
+  max_value = max(max_value, __SHFL_XOR(max_value, 2));
+  max_value = max(max_value, __SHFL_XOR(max_value, 1));
   if (max_value == eigenvalue) {
     GMMSplit_t split;
 
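GMMFindSplit reduces with XOR (butterfly) shuffles rather than down-shuffles because a butterfly exchange leaves every lane holding the reduced value, not just lane 0; each lane can then compare its own eigenvalue against the warp maximum, and the winning lane writes the split. A standalone sketch of the idiom (illustrative names, 32-lane warp assumed; launch as warp_argmax_kernel<<<1, 32>>>(d_in, d_lane)):

// warp_argmax.cu - sketch of the __SHFL_XOR butterfly max-reduction.
#ifdef __HIP_PLATFORM_AMD__
#define __SHFL_XOR(a, b) __shfl_xor(a, b)
#else
#define __SHFL_XOR(a, b) __shfl_xor_sync(0xffffffff, a, b)
#endif

__global__ void warp_argmax_kernel(const float* in, int* out_lane) {
  float value = in[threadIdx.x];
  float max_value = value;
  // XOR exchange: after log2(32) = 5 steps every lane holds the warp max.
  for (int offset = 16; offset > 0; offset >>= 1) {
    max_value = max(max_value, __SHFL_XOR(max_value, offset));
  }
  // Every lane can now test itself against the maximum, so the winning
  // lane identifies itself without an extra broadcast - the same trick
  // GMMFindSplit uses to pick which mixture component to split.
  if (max_value == value) {
    *out_lane = threadIdx.x; // on ties, one of the winning lanes lands last
  }
}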
@@ -347,12 +344,11 @@ __global__ void GMMcommonTerm(float* g_gmm) {
   float gmm_n = threadIdx.x < MIXTURE_SIZE ? g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT] : 0.0f;
 
   float sum = gmm_n;
-
-  sum += __shfl_xor_sync(0xffffffff, sum, 1);
-  sum += __shfl_xor_sync(0xffffffff, sum, 2);
-  sum += __shfl_xor_sync(0xffffffff, sum, 4);
-  sum += __shfl_xor_sync(0xffffffff, sum, 8);
-  sum += __shfl_xor_sync(0xffffffff, sum, 16);
+  sum += __SHFL_XOR(sum, 1);
+  sum += __SHFL_XOR(sum, 2);
+  sum += __SHFL_XOR(sum, 4);
+  sum += __SHFL_XOR(sum, 8);
+  sum += __SHFL_XOR(sum, 16);
 
   if (threadIdx.x < MIXTURE_SIZE) {
     float det = g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT + MATRIX_COMPONENT_COUNT] + EPSILON;
@@ -446,13 +442,14 @@ void GMMInitialize(
   for (unsigned int k = MIXTURE_COUNT; k < gmm_N; k += MIXTURE_COUNT) {
     for (unsigned int i = 0; i < k; ++i) {
       CovarianceReductionKernel<WARPS, LOAD>
-          <<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
+          <<<dim3(block_count, 1, batch_count), BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
     }
 
-    CovarianceFinalizationKernel<WARPS, false><<<{k, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count);
+    CovarianceFinalizationKernel<WARPS, false><<<dim3(k, 1, batch_count), BLOCK>>>(block_gmm_scratch, gmm, block_count);
 
-    GMMFindSplit<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm_split_scratch, k / MIXTURE_COUNT, gmm);
-    GMMDoSplit<<<{TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1, batch_count}, BLOCK_SIZE>>>(
+    GMMFindSplit<<<dim3(1, 1, batch_count), dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(
+        gmm_split_scratch, k / MIXTURE_COUNT, gmm);
+    GMMDoSplit<<<dim3(TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1, batch_count), BLOCK_SIZE>>>(
         gmm_split_scratch, (k / MIXTURE_COUNT) << 4, image, alpha, element_count);
   }
 }
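
The launch-site changes in this hunk and the next are the other half of the port: brace-initialized grid arguments such as {block_count, 1, batch_count} become explicit dim3(...) constructions. nvcc converts the braced list to a dim3 implicitly inside <<<...>>>, while hipcc does not accept it there, which is presumably what motivated this change. A minimal sketch (hypothetical kernel and host function, not from the PR):

// launch_config.cu - sketch of the portable explicit-dim3 launch syntax.
__global__ void fill_kernel(float* data, float value) {
  // One batch per grid z-slice, mirroring the grids used in this file.
  int i = (blockIdx.z * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x;
  data[i] = value;
}

void launch_fill(float* d_data, unsigned int block_count, unsigned int batch_count) {
  // Portable: an explicit dim3 compiles with both nvcc and hipcc.
  dim3 grid(block_count, 1, batch_count);
  fill_kernel<<<grid, 32>>>(d_data, 1.0f);

  // Not portable: nvcc accepts the braced list as an implicit dim3,
  // but hipcc rejects it inside the launch configuration:
  // fill_kernel<<<{block_count, 1, batch_count}, 32>>>(d_data, 1.0f);
}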
@@ -472,12 +469,13 @@ void GMMUpdate(
 
   for (unsigned int i = 0; i < gmm_N; ++i) {
     CovarianceReductionKernel<WARPS, LOAD>
-        <<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
+        <<<dim3(block_count, 1, batch_count), BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
   }
 
-  CovarianceFinalizationKernel<WARPS, true><<<{gmm_N, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count);
+  CovarianceFinalizationKernel<WARPS, true>
+      <<<dim3(gmm_N, 1, batch_count), BLOCK>>>(block_gmm_scratch, gmm, block_count);
 
-  GMMcommonTerm<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm);
+  GMMcommonTerm<<<dim3(1, 1, batch_count), dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm);
 }
 
 void GMMDataTerm(