Fix bug in Transpose CUDA kernel #7329

Merged

merged 11 commits on May 27, 2021
19 changes: 16 additions & 3 deletions onnxruntime/core/providers/cuda/tensor/transpose.cc
@@ -166,13 +166,26 @@ Status Transpose::DoTranspose(const cudaDeviceProp& prop,
if (CanDoTranspose3D(new_rank, new_input_dims, new_permutations)) {
return Transpose3DImpl(stream, element_size, input_shape, tmp_input_strides,
input.DataRaw(), output.MutableDataRaw(), output.Shape().Size());
-} else if (CanDoTranspose4D(prop, element_size, new_rank, new_input_dims, new_permutations)) {
+} else if (CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(
+    prop, element_size, new_rank, new_input_dims, new_permutations)) {
TArray<int64_t> tmp_output_strides(new_rank);
for (auto i = 0; i < new_rank; i++) {
tmp_output_strides[i] = new_output_strides[new_permutations[i]];
}
-return Transpose4DImpl(stream, element_size, input_shape, tmp_input_strides, input.DataRaw(),
-                       tmp_output_strides, output.MutableDataRaw(), gsl::narrow<int>(output.Shape().Size()));
+return Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim(
+    stream, element_size, input_shape, tmp_input_strides, input.DataRaw(),
+    tmp_output_strides, output.MutableDataRaw(), gsl::narrow<int>(output.Shape().Size()));
+} else if (CanDoTranspose4DParallelizeOneElementPerThread(
+    prop, element_size, new_rank, new_input_dims, new_permutations)) {
+  // Trying to see if we can still do (best effort) more optimized transposing
+  // for the 4-D case before falling back to the generic case
+  TArray<int64_t> tmp_output_strides(new_rank);
+  for (auto i = 0; i < new_rank; i++) {
+    tmp_output_strides[i] = new_output_strides[new_permutations[i]];
+  }
+  return Transpose4DParallelizeOneElementPerThread(
+      stream, element_size, input_shape, tmp_input_strides, input.DataRaw(),
+      tmp_output_strides, output.MutableDataRaw(), gsl::narrow<int>(output.Shape().Size()));
}
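
[Editor's note] The `tmp_output_strides` gather above (`tmp_output_strides[i] = new_output_strides[new_permutations[i]]`) is the heart of both 4-D fast paths and is easy to misread. Below is a small host-side sketch, not part of the PR, that checks the index math for perm = {0, 2, 1, 3}, the permutation exercised by the new tests. This permutation is its own inverse, which is what lets the same (n, c, h, w) tuple, enumerated in input order, address both tensors through the plain input strides and the gathered output strides.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const std::vector<int64_t> in_dims{64, 128, 16, 64};
  const std::vector<size_t> perm{0, 2, 1, 3};  // output axis j = input axis perm[j]

  auto strides = [](const std::vector<int64_t>& d) {  // row-major pitches
    std::vector<int64_t> s(4);
    s[3] = 1;
    for (int i = 2; i >= 0; --i) s[i] = s[i + 1] * d[i + 1];
    return s;
  };
  std::vector<int64_t> out_dims(4);
  for (int j = 0; j < 4; ++j) out_dims[j] = in_dims[perm[j]];
  const auto out_strides = strides(out_dims);  // {131072, 8192, 64, 1}

  // The gather performed in DoTranspose above.
  std::vector<int64_t> tmp_out_strides(4);
  for (int i = 0; i < 4; ++i) tmp_out_strides[i] = out_strides[perm[i]];  // {131072, 64, 8192, 1}

  // For a sample of input coordinates, decode the produced output index and
  // confirm it lands at output[n][h][c][w], the transpose of input[n][c][h][w].
  for (int64_t n = 0; n < in_dims[0]; n += 21)
    for (int64_t c = 0; c < in_dims[1]; c += 17)
      for (int64_t h = 0; h < in_dims[2]; h += 5)
        for (int64_t w = 0; w < in_dims[3]; w += 13) {
          int64_t out_idx = n * tmp_out_strides[0] + c * tmp_out_strides[1] +
                            h * tmp_out_strides[2] + w * tmp_out_strides[3];
          int64_t r = out_idx;
          int64_t o0 = r / out_strides[0]; r %= out_strides[0];
          int64_t o1 = r / out_strides[1]; r %= out_strides[1];
          int64_t o2 = r / out_strides[2];
          int64_t o3 = r % out_strides[2];
          assert(o0 == n && o1 == h && o2 == c && o3 == w);
        }
  return 0;
}
```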

// General cases
142 changes: 102 additions & 40 deletions onnxruntime/core/providers/cuda/tensor/transpose_impl.cu
@@ -79,10 +79,11 @@ Status Transpose3DImpl(cudaStream_t stream, size_t element_size,
return Status::OK();
}

-template <int element_size>
-__global__ void Transpose4DKernel(const TArray<int64_t> input_strides, const void* input_data,
-                                  const TArray<int64_t> output_strides, void* output_data,
-                                  CUDA_LONG N) {
+__global__ void Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim(
+    const TArray<int64_t> input_strides, const void* input_data,
+    const TArray<int64_t> output_strides, void* output_data,
+    size_t element_size,
+    CUDA_LONG N) {
// output coordinates will be: blockIdx.y, blockIdx.x, threadIdx.y, threadIdx.x
CUDA_LONG input_index = (blockIdx.y * input_strides[0] +
blockIdx.x * input_strides[1] +
@@ -104,63 +105,124 @@
}
}

-bool CanDoTranspose4D(const cudaDeviceProp& prop,
-                      size_t element_size,
-                      int32_t rank,
-                      const std::vector<int64_t>& input_dims,
-                      const std::vector<size_t>& permutations) {
+bool CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(const cudaDeviceProp& prop,
+                                                                        size_t element_size,
+                                                                        int32_t rank,
+                                                                        const std::vector<int64_t>& input_dims,
+                                                                        const std::vector<size_t>& permutations) {
if (rank == 4 &&
// the permutations is not on the last dimension.
-    permutations[rank - 1] == (rank - 1)) {
-  // The block size will be set based on the last two dimensions of 4D tensor.
+    permutations[3] == 3) {
+  // The block size will be set based on the outer-most two dimensions of 4D tensor.
// the number threads per block will be calculated as below.
unsigned int num_elements_per_thread = 4 * sizeof(int) / static_cast<unsigned int>(element_size); // int4 is used in the kernel to access data.
-  int64_t num_elements_in_last_two_dimensions = input_dims[rank - 2] * input_dims[rank - 1];
+  int64_t num_elements_in_last_two_dimensions = input_dims[2] * input_dims[3];
int64_t num_threads_per_block = num_elements_in_last_two_dimensions / num_elements_per_thread;

if (((num_elements_in_last_two_dimensions & (num_elements_per_thread - 1)) == 0) &&
num_threads_per_block <= prop.maxThreadsPerBlock &&
num_threads_per_block >= prop.warpSize &&
-      // num_threads_per_block must be aligned with warp size: 32
-      ((num_threads_per_block & (prop.warpSize - 1)) == 0)) {
+      // num_threads_per_block must be a multiple of warp size (32)
+      ((num_threads_per_block & (prop.warpSize - 1)) == 0) &&
+      // input_dims[3] must be a multiple of `num_elements_per_thread`
+      ((input_dims[3] % num_elements_per_thread) == 0)) {
return true;
}
}
return false;
}
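
[Editor's note] A worked example of the eligibility math above, assuming 4-byte (float) elements: num_elements_per_thread = 4 * sizeof(int) / 4 = 4; for the Transpose0213 test shape {64, 128, 16, 64} the innermost two dims hold 16 * 64 = 1024 elements, giving 1024 / 4 = 256 threads per block, which is a warp multiple within device limits, and input_dims[3] = 64 is divisible by 4. The host-side sketch below (not the PR's code; device limits hard-coded to common values) replays the checks, including the `input_dims[3] % num_elements_per_thread` condition this PR adds. It is exactly that condition which now routes the {64, 128, 64, 2} test shape away from the vectorized kernel, whose int4 accesses would otherwise straddle rows of the innermost dimension.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Host-side replay of CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim,
// with maxThreadsPerBlock/warpSize hard-coded instead of read from cudaDeviceProp.
bool can_do(size_t element_size, const std::vector<int64_t>& dims,
            const std::vector<size_t>& perm,
            int max_threads_per_block = 1024, int warp_size = 32) {
  if (dims.size() != 4 || perm[3] != 3) return false;  // innermost axis must stay put
  unsigned int elems_per_thread = 4 * sizeof(int) / static_cast<unsigned int>(element_size);  // int4 loads
  int64_t inner = dims[2] * dims[3];
  int64_t threads = inner / elems_per_thread;
  return (inner % elems_per_thread) == 0 &&
         threads <= max_threads_per_block &&
         threads >= warp_size &&
         (threads % warp_size) == 0 &&
         (dims[3] % elems_per_thread) == 0;  // the check this PR adds
}

int main() {
  // Transpose0213: eligible for the vectorized kernel with 4-byte elements.
  printf("%d\n", (int)can_do(4, {64, 128, 16, 64}, {0, 2, 1, 3}));  // 1
  // Transpose0213_V2: dims[3] = 2 is not a multiple of 4, so the new check
  // rejects it and the one-element-per-thread fallback is used instead.
  printf("%d\n", (int)can_do(4, {64, 128, 64, 2}, {0, 2, 1, 3}));   // 0
  return 0;
}
```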

-Status Transpose4DImpl(cudaStream_t stream, size_t element_size, const TArray<int64_t>& input_shape, const TArray<int64_t>& input_strides, const void* input_data,
-                       const TArray<int64_t>& output_strides, void* output_data, int N) {
+Status Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim(
+    cudaStream_t stream, size_t element_size,
+    const TArray<int64_t>& input_shape, const TArray<int64_t>& input_strides,
+    const void* input_data, const TArray<int64_t>& output_strides,
+    void* output_data, int N) {
+  if (element_size != sizeof(int8_t) &&
+      element_size != sizeof(int16_t) &&
+      element_size != sizeof(int32_t) &&
+      element_size != sizeof(int64_t)) {
+    // User will not hit this as this kernel is for fixed element size tensors only
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ",
+                           element_size);
+  }
+
unsigned int num_elements_per_thread = 4 * sizeof(int) / static_cast<unsigned int>(element_size); // int4 is used in the kernel to access data.
dim3 block_size(static_cast<unsigned int>(input_shape[3] / num_elements_per_thread), static_cast<unsigned int>(input_shape[2]));
dim3 grid_size(static_cast<unsigned int>(input_shape[1]), static_cast<unsigned int>(input_shape[0]));

-  switch (element_size) {
-    case sizeof(int8_t):
-      Transpose4DKernel<sizeof(int8_t)><<<grid_size, block_size, 0, stream>>>(
-          input_strides, input_data,
-          output_strides, output_data, N / num_elements_per_thread);
-      break;
-    case sizeof(int16_t):
-      Transpose4DKernel<sizeof(int16_t)><<<grid_size, block_size, 0, stream>>>(
-          input_strides, input_data,
-          output_strides, output_data, N / num_elements_per_thread);
-      break;
-    case sizeof(int32_t):
-      Transpose4DKernel<sizeof(int32_t)><<<grid_size, block_size, 0, stream>>>(
-          input_strides, input_data,
-          output_strides, output_data, N / num_elements_per_thread);
-      break;
-    case sizeof(int64_t):
-      Transpose4DKernel<sizeof(int64_t)><<<grid_size, block_size, 0, stream>>>(
-          input_strides, input_data,
-          output_strides, output_data, N / num_elements_per_thread);
-      break;
-    default:
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ",
-                             element_size);
+  Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim<<<grid_size, block_size, 0, stream>>>(
+      input_strides, input_data,
+      output_strides, output_data, element_size, N / num_elements_per_thread);

return Status::OK();
}
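
[Editor's note] The launch geometry maps the two outer dims to the grid and the two inner dims to the block, with each thread moving num_elements_per_thread consecutive innermost elements through one 16-byte (int4-sized) access. A sketch for the Transpose0213 shape with 4-byte elements (illustration only, not the PR's code):

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Geometry computed exactly as in the launcher above, for input_shape = {64, 128, 16, 64}.
int main() {
  const long long shape[4] = {64, 128, 16, 64};
  const size_t element_size = 4;
  unsigned int ept = 4 * sizeof(int) / static_cast<unsigned int>(element_size);  // = 4
  dim3 block(static_cast<unsigned int>(shape[3] / ept), static_cast<unsigned int>(shape[2]));
  dim3 grid(static_cast<unsigned int>(shape[1]), static_cast<unsigned int>(shape[0]));
  printf("block=(%u,%u) threads=%u, grid=(%u,%u)\n",
         block.x, block.y, block.x * block.y, grid.x, grid.y);  // block=(16,16) 256, grid=(128,64)
  return 0;
}
```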

+__global__ void Transpose4DKernelParallelizeOneElementPerThread(
+    const TArray<int64_t> input_strides, const int8_t* input_data,
+    const TArray<int64_t> output_strides, int8_t* output_data,
+    size_t element_size,
+    CUDA_LONG N) {
+  CUDA_LONG input_index = blockIdx.y * input_strides[0] +
+                          blockIdx.x * input_strides[1] +
+                          threadIdx.y * input_strides[2] +
+                          threadIdx.x * input_strides[3];
+
+  CUDA_LONG output_index = blockIdx.y * output_strides[0] +
+                           blockIdx.x * output_strides[1] +
+                           threadIdx.y * output_strides[2] +
+                           threadIdx.x * output_strides[3];
+
+  if (input_index < N && output_index < N) {
+    const int8_t* input_data_to_be_copied = input_data + (input_index * element_size);
+    int8_t* output_data_to_be_copied = output_data + (output_index * element_size);
+
+    // copy over the bytes
+    memcpy(output_data_to_be_copied, input_data_to_be_copied, element_size);
+  }
+}
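
[Editor's note] The fallback kernel above gives up vectorization entirely: each thread owns exactly one logical element, derives its linear input and output offsets from the strides, and byte-copies element_size bytes, so no divisibility constraint on the innermost dimension is required. A CPU analogue with the same index math, as an editor's sketch with made-up names:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// CPU analogue of Transpose4DKernelParallelizeOneElementPerThread: one "thread"
// per (blockIdx.y, blockIdx.x, threadIdx.y, threadIdx.x) coordinate, byte-wise copy.
void transpose_4d_reference(const int64_t in_strides[4], const uint8_t* in,
                            const int64_t out_strides[4], uint8_t* out,
                            const int64_t dims[4], size_t element_size) {
  for (int64_t y = 0; y < dims[0]; ++y)           // blockIdx.y
    for (int64_t x = 0; x < dims[1]; ++x)         // blockIdx.x
      for (int64_t ty = 0; ty < dims[2]; ++ty)    // threadIdx.y
        for (int64_t tx = 0; tx < dims[3]; ++tx) {  // threadIdx.x
          int64_t in_idx = y * in_strides[0] + x * in_strides[1] +
                           ty * in_strides[2] + tx * in_strides[3];
          int64_t out_idx = y * out_strides[0] + x * out_strides[1] +
                            ty * out_strides[2] + tx * out_strides[3];
          std::memcpy(out + out_idx * element_size, in + in_idx * element_size, element_size);
        }
}

int main() {
  // Tiny check: {1, 2, 3, 2} floats, perm {0, 2, 1, 3}.
  const int64_t dims[4] = {1, 2, 3, 2};
  const int64_t in_strides[4] = {12, 6, 2, 1};
  const int64_t out_strides_gathered[4] = {12, 2, 4, 1};  // output strides {12, 4, 2, 1} gathered by perm
  std::vector<float> in(12), out(12);
  for (int i = 0; i < 12; ++i) in[i] = static_cast<float>(i);
  transpose_4d_reference(in_strides, reinterpret_cast<const uint8_t*>(in.data()),
                         out_strides_gathered, reinterpret_cast<uint8_t*>(out.data()),
                         dims, sizeof(float));
  // input[0][1][2][1] (= 11.f) must land at output[0][2][1][1], linear index 11.
  assert(out[11] == 11.f);
  return 0;
}
```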

+bool CanDoTranspose4DParallelizeOneElementPerThread(const cudaDeviceProp& prop,
+                                                    size_t element_size,
+                                                    int32_t rank,
+                                                    const std::vector<int64_t>& input_dims,
+                                                    const std::vector<size_t>& permutations) {
+  if (rank == 4) {
+    // The block size will be set based on the outer-most two dimensions of 4D tensor.
+    // the number threads per block will be calculated as below.
+    int64_t number_of_threads_per_block = input_dims[2] * input_dims[3];
+
+    if (number_of_threads_per_block <= prop.maxThreadsPerBlock &&
+        number_of_threads_per_block >= prop.warpSize &&
+        // num_threads_per_block must be a multiple of warp size (32)
+        ((number_of_threads_per_block & (prop.warpSize - 1)) == 0)) {
+      return true;
+    }
+  }
+  return false;
+}
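
[Editor's note] Unlike the vectorized variant, this predicate places no constraint on the permutation or the element size; only the block-size arithmetic matters. A host replay under the same hard-coded device limits as the earlier sketch:

```cpp
#include <cstdint>
#include <cstdio>

// Host replay of CanDoTranspose4DParallelizeOneElementPerThread (editor's sketch):
// block = dims[2] * dims[3] threads; only the usual block-size constraints apply.
bool can_do_one_per_thread(const int64_t dims[4],
                           int max_threads_per_block = 1024, int warp_size = 32) {
  int64_t threads = dims[2] * dims[3];
  return threads <= max_threads_per_block && threads >= warp_size &&
         (threads % warp_size) == 0;
}

int main() {
  const int64_t a[4] = {64, 128, 64, 2};  // Transpose0213_V2: 128 threads -> eligible
  const int64_t b[4] = {64, 128, 25, 4};  // 100 threads: not a warp multiple -> generic path
  printf("%d %d\n", (int)can_do_one_per_thread(a), (int)can_do_one_per_thread(b));  // 1 0
  return 0;
}
```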

+Status Transpose4DParallelizeOneElementPerThread(
+    cudaStream_t stream, size_t element_size,
+    const TArray<int64_t>& input_shape, const TArray<int64_t>& input_strides,
+    const void* input_data, const TArray<int64_t>& output_strides,
+    void* output_data, int N) {
+  if (element_size != sizeof(int8_t) &&
+      element_size != sizeof(int16_t) &&
+      element_size != sizeof(int32_t) &&
+      element_size != sizeof(int64_t)) {
+    // User will not hit this as this kernel is for fixed element size tensors only
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ",
+                           element_size);
+  }
+
+  dim3 block_size(static_cast<unsigned int>(input_shape[3]), static_cast<unsigned int>(input_shape[2]));
+  dim3 grid_size(static_cast<unsigned int>(input_shape[1]), static_cast<unsigned int>(input_shape[0]));
+
+  Transpose4DKernelParallelizeOneElementPerThread<<<grid_size, block_size, 0, stream>>>(
+      input_strides, reinterpret_cast<const int8_t*>(input_data),
+      output_strides, reinterpret_cast<int8_t*>(output_data),
+      element_size, N);
+
+  return Status::OK();
+}
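
[Editor's note] Launch geometry for the fallback path, shown for the Transpose0213_V2 shape {64, 128, 64, 2} (illustration only). Note that N is the full element count here, not divided by an elements-per-thread factor, since every thread moves exactly one element:

```cpp
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const long long shape[4] = {64, 128, 64, 2};
  dim3 block(static_cast<unsigned int>(shape[3]), static_cast<unsigned int>(shape[2]));  // (2, 64) = 128 threads
  dim3 grid(static_cast<unsigned int>(shape[1]), static_cast<unsigned int>(shape[0]));   // (128, 64)
  printf("block=(%u,%u) grid=(%u,%u)\n", block.x, block.y, grid.x, grid.y);
  return 0;
}
```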

26 changes: 19 additions & 7 deletions onnxruntime/core/providers/cuda/tensor/transpose_impl.h
@@ -11,13 +11,25 @@ namespace cuda {
bool CanDoTranspose3D(int32_t rank, const std::vector<int64_t>& input_dims, const std::vector<size_t>& permutations);
Status Transpose3DImpl(cudaStream_t stream, size_t element_size, const TArray<int64_t>& input_shape, const TArray<int64_t>& input_strides, const void* input_data,
void* output_data, int64_t N);
-bool CanDoTranspose4D(const cudaDeviceProp& prop,
-                      size_t element_size,
-                      int32_t rank,
-                      const std::vector<int64_t>& input_dims,
-                      const std::vector<size_t>& permutations);
-Status Transpose4DImpl(cudaStream_t stream, size_t element_size, const TArray<int64_t>& input_shape, const TArray<int64_t>& input_strides, const void* input_data,
-                       const TArray<int64_t>& output_strides, void* output_data, int N);

+bool CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(const cudaDeviceProp& prop,
+                                                                        size_t element_size,
+                                                                        int32_t rank,
+                                                                        const std::vector<int64_t>& input_dims,
+                                                                        const std::vector<size_t>& permutations);
+Status Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim(cudaStream_t stream, size_t element_size, const TArray<int64_t>& input_shape,
+                                                                     const TArray<int64_t>& input_strides, const void* input_data,
+                                                                     const TArray<int64_t>& output_strides, void* output_data, int N);
+
+bool CanDoTranspose4DParallelizeOneElementPerThread(const cudaDeviceProp& prop,
+                                                    size_t element_size,
+                                                    int32_t rank,
+                                                    const std::vector<int64_t>& input_dims,
+                                                    const std::vector<size_t>& permutations);
+Status Transpose4DParallelizeOneElementPerThread(cudaStream_t stream, size_t element_size, const TArray<int64_t>& input_shape,
+                                                 const TArray<int64_t>& input_strides, const void* input_data,
+                                                 const TArray<int64_t>& output_strides, void* output_data, int N);
+
Status TransposeImpl(cudaStream_t stream, size_t element_size, int32_t shape_rank, const TArray<int64_t>& input_strides,
const void* input_data, const TArray<fast_divmod>& fdm_output_strides, void* output_data, int N);
} // namespace cuda
57 changes: 35 additions & 22 deletions onnxruntime/core/providers/rocm/tensor/transpose.cc
@@ -62,16 +62,16 @@ Status TransposeWithRocblas(hipStream_t stream, rocblas_handle rocblas_handle,
HipT* output_data = reinterpret_cast<HipT*>(output.MutableData<T>());
ROCBLAS_RETURN_IF_ERROR(
rocblasTransposeHelper(stream,
-                       rocblas_handle,
-                       rocblas_operation_transpose, rocblas_operation_transpose, M, N,
-                       &one,
-                       input_data,
-                       N,
-                       &zero,
-                       input_data,
-                       N,
-                       output_data,
-                       M));
+                           rocblas_handle,
+                           rocblas_operation_transpose, rocblas_operation_transpose, M, N,
+                           &one,
+                           input_data,
+                           N,
+                           &zero,
+                           input_data,
+                           N,
+                           output_data,
+                           M));
return Status::OK();
}
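
[Editor's note] `rocblasTransposeHelper` is an ORT wrapper; assuming it forwards to rocBLAS's geam-style API (as the CUDA counterpart does with `cublas<t>geam`), the call above computes output = 1 * input^T + 0 * input^T for the 2-D case. A hedged sketch of what such a helper reduces to for float data, treating the row-major M x N input as a column-major N x M matrix:

```cpp
#include <rocblas/rocblas.h>

// Sketch (assumption, not ORT's actual helper): out-of-place transpose of a
// row-major MxN matrix via rocblas_sgeam, matching the argument pattern above.
// The input buffer is a column-major NxM matrix A (lda = N); C = 1*A^T + 0*A^T
// is a column-major MxN result (ldc = M), i.e. the row-major NxM transpose.
rocblas_status transpose_f32(rocblas_handle handle, int M, int N,
                             const float* input, float* output) {
  const float one = 1.f, zero = 0.f;
  return rocblas_sgeam(handle,
                       rocblas_operation_transpose, rocblas_operation_transpose,
                       M, N,
                       &one, input, N,
                       &zero, input, N,
                       output, M);
}
```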

@@ -128,25 +128,25 @@ Status Transpose::DoTranspose(const hipDeviceProp_t& prop,
new_permutations[j] -= 1;
}
}
-    for (auto j = i+1; j < new_rank; j++) {
-      new_permutations[j-1] = new_permutations[j];
+    for (auto j = i + 1; j < new_rank; j++) {
+      new_permutations[j - 1] = new_permutations[j];
}

// update input dims
new_input_dims[prev] *= new_input_dims[curr];
new_input_dims[curr] = 1;
-    for (auto j = static_cast<int32_t>(curr+1); j < new_rank; j++) {
-      new_input_dims[j-1] = new_input_dims[j];
+    for (auto j = static_cast<int32_t>(curr + 1); j < new_rank; j++) {
+      new_input_dims[j - 1] = new_input_dims[j];
}
-    new_input_dims[new_rank-1] = 1;
+    new_input_dims[new_rank - 1] = 1;

// update output dims
-    new_output_dims[i-1] *= new_output_dims[i];
+    new_output_dims[i - 1] *= new_output_dims[i];
new_output_dims[i] = 1;
-    for (auto j = i+1; j < new_rank; j++) {
-      new_output_dims[j-1] = new_output_dims[j];
+    for (auto j = i + 1; j < new_rank; j++) {
+      new_output_dims[j - 1] = new_output_dims[j];
}
-    new_output_dims[new_rank-1] = 1;
+    new_output_dims[new_rank - 1] = 1;

new_rank--;
}
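
[Editor's note] This block is the dimension-flattening pass: whenever two axes remain adjacent after permutation, they are merged into one, shrinking the rank. It is why the Transpose0231 and Transpose0312 tests below end up in Transpose3DImpl rather than a 4-D kernel. A simplified standalone sketch of the idea (the in-place bookkeeping above is equivalent):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Whenever two axes are adjacent in BOTH the input and the permuted output,
// merge them into one axis. Editor's illustration, not the PR's code.
int main() {
  std::vector<int64_t> dims{64, 128, 16, 64};
  std::vector<size_t> perm{0, 2, 3, 1};

  for (size_t i = 1; i < perm.size();) {
    if (perm[i] == perm[i - 1] + 1) {        // axes perm[i-1], perm[i] stay adjacent
      size_t removed = perm[i];
      dims[perm[i - 1]] *= dims[removed];    // merge into the earlier axis
      dims.erase(dims.begin() + removed);
      perm.erase(perm.begin() + i);
      for (auto& p : perm)
        if (p > removed) --p;                // re-number the remaining axes
    } else {
      ++i;
    }
  }
  // Prints: 64 128 1024 | 0 2 1  -> rank 3, handled by Transpose3DImpl.
  for (auto d : dims) printf("%lld ", (long long)d);
  printf("| ");
  for (auto p : perm) printf("%zu ", p);
  printf("\n");
  return 0;
}
```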
@@ -166,13 +166,26 @@ Status Transpose::DoTranspose(const hipDeviceProp_t& prop,
if (CanDoTranspose3D(new_rank, new_input_dims, new_permutations)) {
return Transpose3DImpl(stream, element_size, input_shape, tmp_input_strides,
input.DataRaw(), output.MutableDataRaw(), output.Shape().Size());
-} else if (CanDoTranspose4D(prop, element_size, new_rank, new_input_dims, new_permutations)) {
+} else if (CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(
+    prop, element_size, new_rank, new_input_dims, new_permutations)) {
TArray<int64_t> tmp_output_strides(new_rank);
for (auto i = 0; i < new_rank; i++) {
tmp_output_strides[i] = new_output_strides[new_permutations[i]];
}
-return Transpose4DImpl(stream, element_size, input_shape, tmp_input_strides, input.DataRaw(),
-                       tmp_output_strides, output.MutableDataRaw(), output.Shape().Size());
+return Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim(
+    stream, element_size, input_shape, tmp_input_strides, input.DataRaw(),
+    tmp_output_strides, output.MutableDataRaw(), gsl::narrow<int>(output.Shape().Size()));
+} else if (CanDoTranspose4DParallelizeOneElementPerThread(
+    prop, element_size, new_rank, new_input_dims, new_permutations)) {
+  // Trying to see if we can still do (best effort) more optimized transposing
+  // for the 4-D case before falling back to the generic case
+  TArray<int64_t> tmp_output_strides(new_rank);
+  for (auto i = 0; i < new_rank; i++) {
+    tmp_output_strides[i] = new_output_strides[new_permutations[i]];
+  }
+  return Transpose4DParallelizeOneElementPerThread(
+      stream, element_size, input_shape, tmp_input_strides, input.DataRaw(),
+      tmp_output_strides, output.MutableDataRaw(), gsl::narrow<int>(output.Shape().Size()));
}

// General cases
14 changes: 11 additions & 3 deletions onnxruntime/test/providers/cpu/tensor/transpose_test.cc
@@ -590,26 +590,34 @@ static void TestTranspose(
test.CompareWithCPU(kGpuExecutionProvider, error_tolerance);
}

-TEST(TransposeOpTest, Transpose0213) {
+TEST(TransposeOpTest, Transpose0213) {  // Will trigger Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim()
const std::vector<int64_t> X_dims{64, 128, 16, 64};
const std::vector<int64_t> perm{0, 2, 1, 3};
const std::vector<int64_t> Y_dims{64, 16, 128, 64};
TestTranspose(perm, X_dims, Y_dims);
}

-TEST(TransposeOpTest, Transpose0231) {
+TEST(TransposeOpTest, Transpose0213_V2) {  // Will trigger Transpose4DParallelizeOneElementPerThread()
+  const std::vector<int64_t> X_dims{64, 128, 64, 2};
+  const std::vector<int64_t> perm{0, 2, 1, 3};
+  const std::vector<int64_t> Y_dims{64, 64, 128, 2};
+  TestTranspose(perm, X_dims, Y_dims);
+}
+
+TEST(TransposeOpTest, Transpose0231) {  // Will trigger Transpose3DImpl() because of "flattening" of dims 2 and 3 into one dim
const std::vector<int64_t> X_dims{64, 128, 16, 64};
const std::vector<int64_t> perm{0, 2, 3, 1};
const std::vector<int64_t> Y_dims{64, 16, 64, 128};
TestTranspose(perm, X_dims, Y_dims);
}

-TEST(TransposeOpTest, Transpose0312) {
+TEST(TransposeOpTest, Transpose0312) {  // Will trigger Transpose3DImpl() because of "flattening" of dims 1 and 2 into one dim
const std::vector<int64_t> X_dims{64, 16, 64, 128};
const std::vector<int64_t> perm{0, 3, 1, 2};
const std::vector<int64_t> Y_dims{64, 128, 16, 64};
TestTranspose(perm, X_dims, Y_dims);
}
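
[Editor's note] Taken together with the dispatch logic above, and assuming 4-byte elements: Transpose0213 ({64, 128, 16, 64}) passes every check, including the new input_dims[3] % 4 == 0 condition, and takes the vectorized 4-D kernel; Transpose0213_V2 ({64, 128, 64, 2}) fails that new check (2 % 4 != 0) but its 64 * 2 = 128 threads per block satisfy the fallback predicate, exercising the one-element-per-thread kernel; Transpose0231 and Transpose0312 flatten adjacent axes down to rank 3 and land in Transpose3DImpl, as the test comments note.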

#endif

} // namespace test