Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add supports for AMD GPU #6161

Merged
merged 7 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions Dockerfile.amd
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# To build with a different base image
# please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag.
#ARG ARCH=ROCM
ARG PYTORCH_IMAGE=rocm/pytorch:rocm5.4_ubuntu20.04_py3.8_pytorch_1.12.1
FROM ${PYTORCH_IMAGE}

LABEL maintainer="monai.contact@gmail.com"

WORKDIR /opt/monai

# install full deps
COPY requirements.txt requirements-min.txt requirements-dev.txt /tmp/
RUN cp /tmp/requirements.txt /tmp/req.bak \
&& sed -i '/cucim/d' /tmp/requirements-dev.txt \
&& awk '!/torch/' /tmp/requirements.txt > /tmp/tmp && mv /tmp/tmp /tmp/requirements.txt \
&& python -m pip install --upgrade --no-cache-dir pip \
&& python -m pip install --no-cache-dir -r /tmp/requirements-dev.txt


# compile ext and remove temp files
# TODO: remark for issue [revise the dockerfile #1276](https://github.com/Project-MONAI/MONAI/issues/1276)
# please specify exact files and folders to be copied -- else, basically always, the Docker build process cannot cache
# this or anything below it and always will build from at most here; one file change leads to no caching from here on...

COPY LICENSE CHANGELOG.md CODE_OF_CONDUCT.md CONTRIBUTING.md README.md versioneer.py setup.py setup.cfg runtests.sh MANIFEST.in ./
COPY tests ./tests
COPY monai ./monai
RUN BUILD_MONAI=1 FORCE_CUDA=1 python setup.py develop \
&& rm -rf build __pycache__
WORKDIR /opt/monai
95 changes: 76 additions & 19 deletions monai/_extensions/gmm/gmm_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,19 @@ __global__ void CovarianceReductionKernel(

for (int i = 0; i < MATRIX_COMPONENT_COUNT; i++) {
float matrix_component = matrix[i];

#ifdef __HIP_PLATFORM_AMD__
matrix_component += __shfl_down(matrix_component, 16);
wyli marked this conversation as resolved.
Show resolved Hide resolved
matrix_component += __shfl_down(matrix_component, 8);
matrix_component += __shfl_down(matrix_component, 4);
matrix_component += __shfl_down(matrix_component, 2);
matrix_component += __shfl_down(matrix_component, 1);
#else
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);

#endif
if (lane_index == 0) {
s_matrix_component[warp_index] = matrix_component;
}
Expand All @@ -97,7 +103,23 @@ __global__ void CovarianceReductionKernel(

if (warp_index == 0) {
matrix_component = s_matrix_component[lane_index];

#ifdef __HIP_PLATFORM_AMD__
if (warp_count >= 32) {
matrix_component += __shfl_down(matrix_component, 16);
}
if (warp_count >= 16) {
matrix_component += __shfl_down(matrix_component, 8);
}
if (warp_count >= 8) {
matrix_component += __shfl_down(matrix_component, 4);
}
if (warp_count >= 4) {
matrix_component += __shfl_down(matrix_component, 2);
}
if (warp_count >= 2) {
matrix_component += __shfl_down(matrix_component, 1);
}
#else
if (warp_count >= 32) {
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
}
Expand All @@ -113,7 +135,7 @@ __global__ void CovarianceReductionKernel(
if (warp_count >= 2) {
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
}

#endif
if (lane_index == 0) {
g_batch_matrices[matrix_offset + i] = matrix_component;
}
Expand Down Expand Up @@ -156,13 +178,19 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g
matrix_component += g_batch_matrices[(matrix_offset + matrix_index) * GMM_COMPONENT_COUNT + index];
}
}

#ifdef __HIP_PLATFORM_AMD__
matrix_component += __shfl_down(matrix_component, 16);
matrix_component += __shfl_down(matrix_component, 8);
matrix_component += __shfl_down(matrix_component, 4);
matrix_component += __shfl_down(matrix_component, 2);
matrix_component += __shfl_down(matrix_component, 1);
#else
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);

#endif
if (lane_index == 0) {
s_matrix_component[warp_index] = matrix_component;
}
Expand All @@ -171,7 +199,23 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g

if (warp_index == 0) {
matrix_component = s_matrix_component[lane_index];

#ifdef __HIP_PLATFORM_AMD__
if (warp_count >= 32) {
matrix_component += __shfl_down(matrix_component, 16);
}
if (warp_count >= 16) {
matrix_component += __shfl_down(matrix_component, 8);
}
if (warp_count >= 8) {
matrix_component += __shfl_down(matrix_component, 4);
}
if (warp_count >= 4) {
matrix_component += __shfl_down(matrix_component, 2);
}
if (warp_count >= 2) {
matrix_component += __shfl_down(matrix_component, 1);
}
#else
if (warp_count >= 32) {
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
}
Expand All @@ -187,7 +231,7 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g
if (warp_count >= 2) {
matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
}

#endif
if (lane_index == 0) {
float constant = i == 0 ? 0.0f : s_gmm[i] * s_gmm[j];

Expand Down Expand Up @@ -261,13 +305,19 @@ __global__ void GMMFindSplit(GMMSplit_t* gmmSplit, int gmmK, float* gmm) {
}

float max_value = eigenvalue;

#ifdef __HIP_PLATFORM_AMD__
max_value = max(max_value, __shfl_xor(max_value, 16));
max_value = max(max_value, __shfl_xor(max_value, 8));
max_value = max(max_value, __shfl_xor(max_value, 4));
max_value = max(max_value, __shfl_xor(max_value, 2));
max_value = max(max_value, __shfl_xor(max_value, 1));
#else
max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));

#endif
if (max_value == eigenvalue) {
GMMSplit_t split;

Expand Down Expand Up @@ -347,13 +397,20 @@ __global__ void GMMcommonTerm(float* g_gmm) {
float gmm_n = threadIdx.x < MIXTURE_SIZE ? g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT] : 0.0f;

float sum = gmm_n;

#ifdef __HIP_PLATFORM_AMD__
sum += __shfl_xor(sum, 1);
sum += __shfl_xor(sum, 2);
sum += __shfl_xor(sum, 4);
sum += __shfl_xor(sum, 8);
sum += __shfl_xor(sum, 16);

#else
sum += __shfl_xor_sync(0xffffffff, sum, 1);
sum += __shfl_xor_sync(0xffffffff, sum, 2);
sum += __shfl_xor_sync(0xffffffff, sum, 4);
sum += __shfl_xor_sync(0xffffffff, sum, 8);
sum += __shfl_xor_sync(0xffffffff, sum, 16);

#endif
if (threadIdx.x < MIXTURE_SIZE) {
float det = g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT + MATRIX_COMPONENT_COUNT] + EPSILON;
float commonTerm = det > 0.0f ? gmm_n / (sqrtf(det) * sum) : gmm_n / sum;
Expand Down Expand Up @@ -446,13 +503,13 @@ void GMMInitialize(
for (unsigned int k = MIXTURE_COUNT; k < gmm_N; k += MIXTURE_COUNT) {
for (unsigned int i = 0; i < k; ++i) {
CovarianceReductionKernel<WARPS, LOAD>
<<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
<<<dim3(block_count, 1, batch_count), BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
}

CovarianceFinalizationKernel<WARPS, false><<<{k, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count);
CovarianceFinalizationKernel<WARPS, false><<<dim3(k, 1, batch_count), BLOCK>>>(block_gmm_scratch, gmm, block_count);

GMMFindSplit<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm_split_scratch, k / MIXTURE_COUNT, gmm);
GMMDoSplit<<<{TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1, batch_count}, BLOCK_SIZE>>>(
GMMFindSplit<<<dim3(1, 1, batch_count), dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm_split_scratch, k / MIXTURE_COUNT, gmm);
GMMDoSplit<<<dim3(TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1, batch_count), BLOCK_SIZE>>>(
gmm_split_scratch, (k / MIXTURE_COUNT) << 4, image, alpha, element_count);
}
}
Expand All @@ -472,12 +529,12 @@ void GMMUpdate(

for (unsigned int i = 0; i < gmm_N; ++i) {
CovarianceReductionKernel<WARPS, LOAD>
<<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
<<<dim3(block_count, 1, batch_count), BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
}

CovarianceFinalizationKernel<WARPS, true><<<{gmm_N, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count);
CovarianceFinalizationKernel<WARPS, true><<<dim3(gmm_N, 1, batch_count), BLOCK>>>(block_gmm_scratch, gmm, block_count);

GMMcommonTerm<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm);
GMMcommonTerm<<<dim3(1, 1, batch_count), dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm);
}

void GMMDataTerm(
Expand Down