Fix bug in Transpose CUDA kernel (#7329)
hariharans29 authored May 27, 2021
1 parent 883923a commit 7380219
Showing 5 changed files with 190 additions and 62 deletions.
19 changes: 16 additions & 3 deletions onnxruntime/core/providers/cuda/tensor/transpose.cc
@@ -166,13 +166,26 @@ Status Transpose::DoTranspose(const cudaDeviceProp& prop,
if (CanDoTranspose3D(new_rank, new_input_dims, new_permutations)) {
return Transpose3DImpl(stream, element_size, input_shape, tmp_input_strides,
input.DataRaw(), output.MutableDataRaw(), output.Shape().Size());
} else if (CanDoTranspose4D(prop, element_size, new_rank, new_input_dims, new_permutations)) {
} else if (CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(
prop, element_size, new_rank, new_input_dims, new_permutations)) {
TArray<int64_t> tmp_output_strides(new_rank);
for (auto i = 0; i < new_rank; i++) {
tmp_output_strides[i] = new_output_strides[new_permutations[i]];
}
return Transpose4DImpl(stream, element_size, input_shape, tmp_input_strides, input.DataRaw(),
tmp_output_strides, output.MutableDataRaw(), gsl::narrow<int>(output.Shape().Size()));
return Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim(
stream, element_size, input_shape, tmp_input_strides, input.DataRaw(),
tmp_output_strides, output.MutableDataRaw(), gsl::narrow<int>(output.Shape().Size()));
} else if (CanDoTranspose4DParallelizeOneElementPerThread(
prop, element_size, new_rank, new_input_dims, new_permutations)) {
// Best-effort attempt at a more optimized transpose for the 4-D case
// before falling back to the generic implementation
TArray<int64_t> tmp_output_strides(new_rank);
for (auto i = 0; i < new_rank; i++) {
tmp_output_strides[i] = new_output_strides[new_permutations[i]];
}
return Transpose4DParallelizeOneElementPerThread(
stream, element_size, input_shape, tmp_input_strides, input.DataRaw(),
tmp_output_strides, output.MutableDataRaw(), gsl::narrow<int>(output.Shape().Size()));
}

// General cases
136 changes: 109 additions & 27 deletions onnxruntime/core/providers/cuda/tensor/transpose_impl.cu
@@ -80,9 +80,10 @@ Status Transpose3DImpl(cudaStream_t stream, size_t element_size,
}

template <int element_size>
__global__ void Transpose4DKernel(const TArray<int64_t> input_strides, const void* input_data,
const TArray<int64_t> output_strides, void* output_data,
CUDA_LONG N) {
__global__ void Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim(
const TArray<int64_t> input_strides, const void* input_data,
const TArray<int64_t> output_strides, void* output_data,
CUDA_LONG N) {
// output coordinates will be: blockIdx.y, blockIdx.x, threadIdx.y, threadIdx.x
CUDA_LONG input_index = (blockIdx.y * input_strides[0] +
blockIdx.x * input_strides[1] +
@@ -104,66 +105,147 @@ __global__ void Transpose4DKernel(const TArray<int64_t> input_strides, const voi
}
}

bool CanDoTranspose4D(const cudaDeviceProp& prop,
size_t element_size,
int32_t rank,
const std::vector<int64_t>& input_dims,
const std::vector<size_t>& permutations) {
bool CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(const cudaDeviceProp& prop,
size_t element_size,
int32_t rank,
const std::vector<int64_t>& input_dims,
const std::vector<size_t>& permutations) {
if (rank == 4 &&
// the permutation leaves the last dimension in place.
permutations[rank - 1] == (rank - 1)) {
// The block size will be set based on the last two dimensions of 4D tensor.
permutations[3] == 3) {
// The block size will be set based on the inner-most two dimensions of the 4-D tensor.
// The number of threads per block is calculated as below.
unsigned int num_elements_per_thread = 4 * sizeof(int) / static_cast<unsigned int>(element_size); // int4 is used in the kernel to access data.
int64_t num_elements_in_last_two_dimensions = input_dims[rank - 2] * input_dims[rank - 1];
int64_t num_elements_in_last_two_dimensions = input_dims[2] * input_dims[3];
int64_t num_threads_per_block = num_elements_in_last_two_dimensions / num_elements_per_thread;

if (((num_elements_in_last_two_dimensions & (num_elements_per_thread - 1)) == 0) &&
num_threads_per_block <= prop.maxThreadsPerBlock &&
num_threads_per_block >= prop.warpSize &&
// num_threads_per_block must be aligned with warp size: 32
((num_threads_per_block & (prop.warpSize - 1)) == 0)) {
// num_threads_per_block must be a multiple of warp size (32)
((num_threads_per_block & (prop.warpSize - 1)) == 0) &&
// input_dims[3] must be a multiple of `num_elements_per_thread`
((input_dims[3] % num_elements_per_thread) == 0)) {
return true;
}
}
return false;
}
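
The newly added requirement that input_dims[3] be a multiple of num_elements_per_thread appears to be the substance of the fix: the previous check only required the product of the last two dimensions to be a multiple, so a tensor whose inner-most dimension is not itself a multiple of the per-thread element count could still be routed to the vectorized kernel. A minimal host-side sketch of the updated checks, assuming a device with maxThreadsPerBlock = 1024 and warpSize = 32, made-up shapes, and a permutation that leaves axis 3 in place (illustration only, not ORT code):

#include <cstdint>
#include <cstdio>

// Sketch only: reproduces the eligibility checks above with assumed device
// limits and element_size = 4 (float), i.e. 4 elements per int4 load.
int main() {
  const int64_t shapes[2][4] = {{8, 16, 32, 64},   // inner-most dim is a multiple of 4
                                {8, 16, 64, 6}};   // inner-most dim is NOT a multiple of 4
  const int64_t ept = 4;  // num_elements_per_thread = 4 * sizeof(int) / element_size

  for (const auto& d : shapes) {
    const int64_t last_two = d[2] * d[3];
    const int64_t tpb = last_two / ept;
    const bool old_check = ((last_two & (ept - 1)) == 0) &&
                           tpb <= 1024 && tpb >= 32 && ((tpb & 31) == 0);
    const bool new_check = old_check && (d[3] % ept) == 0;
    printf("dims (%lld, %lld, %lld, %lld): old check %s, new check %s\n",
           static_cast<long long>(d[0]), static_cast<long long>(d[1]),
           static_cast<long long>(d[2]), static_cast<long long>(d[3]),
           old_check ? "passes" : "fails", new_check ? "passes" : "fails");
  }
  return 0;
}

The first shape passes both versions of the check (512 threads per block). The second passes the old check (96 threads per block) but is rejected by the new per-row condition, so it now falls through to the one-element-per-thread kernel added below, which is consistent with the bug this commit addresses.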

Status Transpose4DImpl(cudaStream_t stream, size_t element_size, const TArray<int64_t>& input_shape, const TArray<int64_t>& input_strides, const void* input_data,
const TArray<int64_t>& output_strides, void* output_data, int N) {
Status Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim(
cudaStream_t stream, size_t element_size,
const TArray<int64_t>& input_shape, const TArray<int64_t>& input_strides,
const void* input_data, const TArray<int64_t>& output_strides,
void* output_data, int N) {
unsigned int num_elements_per_thread = 4 * sizeof(int) / static_cast<unsigned int>(element_size); // int4 is used in the kernel to access data.
dim3 block_size(static_cast<unsigned int>(input_shape[3] / num_elements_per_thread), static_cast<unsigned int>(input_shape[2]));
dim3 grid_size(static_cast<unsigned int>(input_shape[1]), static_cast<unsigned int>(input_shape[0]));

switch (element_size) {
case sizeof(int8_t):
Transpose4DKernel<sizeof(int8_t)><<<grid_size, block_size, 0, stream>>>(
input_strides, input_data,
output_strides, output_data, N / num_elements_per_thread);
Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim<sizeof(int8_t)>
<<<grid_size, block_size, 0, stream>>>(
input_strides, input_data,
output_strides, output_data, N / num_elements_per_thread);
break;
case sizeof(int16_t):
Transpose4DKernel<sizeof(int16_t)><<<grid_size, block_size, 0, stream>>>(
input_strides, input_data,
output_strides, output_data, N / num_elements_per_thread);
Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim<sizeof(int16_t)>
<<<grid_size, block_size, 0, stream>>>(
input_strides, input_data,
output_strides, output_data, N / num_elements_per_thread);
break;
case sizeof(int32_t):
Transpose4DKernel<sizeof(int32_t)><<<grid_size, block_size, 0, stream>>>(
input_strides, input_data,
output_strides, output_data, N / num_elements_per_thread);
Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim<sizeof(int32_t)>
<<<grid_size, block_size, 0, stream>>>(
input_strides, input_data,
output_strides, output_data, N / num_elements_per_thread);
break;
case sizeof(int64_t):
Transpose4DKernel<sizeof(int64_t)><<<grid_size, block_size, 0, stream>>>(
input_strides, input_data,
output_strides, output_data, N / num_elements_per_thread);
Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim<sizeof(int64_t)>
<<<grid_size, block_size, 0, stream>>>(
input_strides, input_data,
output_strides, output_data, N / num_elements_per_thread);
break;
default:
// User will not hit this as this kernel is for fixed element size tensors only
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ",
element_size);
}

return Status::OK();
}
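
For the same made-up shape {8, 16, 32, 64} with 4-byte elements, the wrapper above builds a block covering the inner-most two axes of the shape it receives and a grid covering the outer-most two, and divides N by num_elements_per_thread since each thread moves one int4's worth of elements. A small sketch of that arithmetic (illustration only, not ORT code):

#include <cstdio>

int main() {
  const long long shape[4] = {8, 16, 32, 64};
  const unsigned int ept = 4;  // elements handled per thread via int4 loads

  const unsigned int block_x = static_cast<unsigned int>(shape[3] / ept);  // 16
  const unsigned int block_y = static_cast<unsigned int>(shape[2]);        // 32
  const unsigned int grid_x = static_cast<unsigned int>(shape[1]);         // 16
  const unsigned int grid_y = static_cast<unsigned int>(shape[0]);         // 8

  printf("block (%u, %u) = %u threads, grid (%u, %u) = %u blocks\n",
         block_x, block_y, block_x * block_y, grid_x, grid_y, grid_x * grid_y);
  // block (16, 32) = 512 threads, grid (16, 8) = 128 blocks
  return 0;
}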

__global__ void Transpose4DKernelParallelizeOneElementPerThread(
const TArray<int64_t> input_strides, const int8_t* input_data,
const TArray<int64_t> output_strides, int8_t* output_data,
size_t element_size,
CUDA_LONG N) {
CUDA_LONG input_index = blockIdx.y * input_strides[0] +
blockIdx.x * input_strides[1] +
threadIdx.y * input_strides[2] +
threadIdx.x * input_strides[3];

CUDA_LONG output_index = blockIdx.y * output_strides[0] +
blockIdx.x * output_strides[1] +
threadIdx.y * output_strides[2] +
threadIdx.x * output_strides[3];

if (input_index < N && output_index < N) {
const int8_t* input_data_to_be_copied = input_data + (input_index * element_size);
int8_t* output_data_to_be_copied = output_data + (output_index * element_size);

// copy over the bytes
for (size_t iter = 0; iter < element_size; ++iter) {
*output_data_to_be_copied++ = *input_data_to_be_copied++;
}
}
}
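
For reference, the index arithmetic and byte-wise copy performed by this kernel correspond to the following CPU loop; this is a hedged sketch for validation only, where dims holds the four extents the launch below walks with its grid (dims[1], dims[0]) and block (dims[3], dims[2]):

#include <cstddef>
#include <cstdint>

// Sketch only: CPU reference for the per-element byte copy above.
// Axis order matches the kernel: axis 0 -> blockIdx.y, axis 1 -> blockIdx.x,
// axis 2 -> threadIdx.y, axis 3 -> threadIdx.x.
void Transpose4DReference(const int64_t dims[4],
                          const int64_t input_strides[4],
                          const int64_t output_strides[4],
                          const int8_t* input_data, int8_t* output_data,
                          size_t element_size) {
  for (int64_t d0 = 0; d0 < dims[0]; ++d0)
    for (int64_t d1 = 0; d1 < dims[1]; ++d1)
      for (int64_t d2 = 0; d2 < dims[2]; ++d2)
        for (int64_t d3 = 0; d3 < dims[3]; ++d3) {
          const int64_t input_index = d0 * input_strides[0] + d1 * input_strides[1] +
                                      d2 * input_strides[2] + d3 * input_strides[3];
          const int64_t output_index = d0 * output_strides[0] + d1 * output_strides[1] +
                                       d2 * output_strides[2] + d3 * output_strides[3];
          // copy one element, byte by byte, as the kernel does
          for (size_t b = 0; b < element_size; ++b)
            output_data[output_index * element_size + b] =
                input_data[input_index * element_size + b];
        }
}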

bool CanDoTranspose4DParallelizeOneElementPerThread(const cudaDeviceProp& prop,
size_t element_size,
int32_t rank,
const std::vector<int64_t>& input_dims,
const std::vector<size_t>& permutations) {
if (rank == 4) {
// The block size will be set based on the inner-most two dimensions of the 4-D tensor.
// The number of threads per block is calculated as below.
int64_t number_of_threads_per_block = input_dims[2] * input_dims[3];

if (number_of_threads_per_block <= prop.maxThreadsPerBlock &&
number_of_threads_per_block >= prop.warpSize &&
// num_threads_per_block must be a multiple of warp size (32)
((number_of_threads_per_block & (prop.warpSize - 1)) == 0)) {
return true;
}
}
return false;
}
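
Unlike the vectorized path, this check places no constraint on the permutation (beyond rank 4) or on int4 alignment; it only requires that the inner-most two dimensions form a thread block the device can run. A small worked example under the same assumed limits (maxThreadsPerBlock = 1024, warpSize = 32), again with made-up numbers:

#include <cstdint>
#include <cstdio>

// Sketch only: the rank-4 fallback check above with made-up numbers.
int main() {
  const int64_t dims[4] = {8, 16, 64, 6};  // rejected by the int4 path since 6 % 4 != 0
  const int64_t tpb = dims[2] * dims[3];   // 384 threads per block
  const bool ok = tpb <= 1024 && tpb >= 32 && ((tpb & 31) == 0);
  printf("one-element-per-thread path eligible: %s (%lld threads per block)\n",
         ok ? "yes" : "no", static_cast<long long>(tpb));  // yes, 384
  return 0;
}

A permutation such as {0, 3, 1, 2}, which moves the inner-most axis, is also acceptable here, since the kernel above copies each element byte by byte rather than relying on vectorized loads along axis 3.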

Status Transpose4DParallelizeOneElementPerThread(
cudaStream_t stream, size_t element_size,
const TArray<int64_t>& input_shape, const TArray<int64_t>& input_strides,
const void* input_data, const TArray<int64_t>& output_strides,
void* output_data, int N) {
if (element_size != sizeof(int8_t) &&
element_size != sizeof(int16_t) &&
element_size != sizeof(int32_t) &&
element_size != sizeof(int64_t)) {
// User will not hit this as this kernel is for fixed element size tensors only
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ",
element_size);
}

dim3 block_size(static_cast<unsigned int>(input_shape[3]), static_cast<unsigned int>(input_shape[2]));
dim3 grid_size(static_cast<unsigned int>(input_shape[1]), static_cast<unsigned int>(input_shape[0]));

Transpose4DKernelParallelizeOneElementPerThread<<<grid_size, block_size, 0, stream>>>(
input_strides, reinterpret_cast<const int8_t*>(input_data),
output_strides, reinterpret_cast<int8_t*>(output_data),
element_size, N);

return Status::OK();
}

template <typename T>
__global__ void TransposeKernel(int32_t shape_rank, const TArray<int64_t> input_strides,
const T* input_data, const TArray<fast_divmod> output_strides, T* output_data, CUDA_LONG N) {
26 changes: 19 additions & 7 deletions onnxruntime/core/providers/cuda/tensor/transpose_impl.h
@@ -11,13 +11,25 @@ namespace cuda {
bool CanDoTranspose3D(int32_t rank, const std::vector<int64_t>& input_dims, const std::vector<size_t>& permutations);
Status Transpose3DImpl(cudaStream_t stream, size_t element_size, const TArray<int64_t>& input_shape, const TArray<int64_t>& input_strides, const void* input_data,
void* output_data, int64_t N);
bool CanDoTranspose4D(const cudaDeviceProp& prop,
size_t element_size,
int32_t rank,
const std::vector<int64_t>& input_dims,
const std::vector<size_t>& permutations);
Status Transpose4DImpl(cudaStream_t stream, size_t element_size, const TArray<int64_t>& input_shape, const TArray<int64_t>& input_strides, const void* input_data,
const TArray<int64_t>& output_strides, void* output_data, int N);

bool CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(const cudaDeviceProp& prop,
size_t element_size,
int32_t rank,
const std::vector<int64_t>& input_dims,
const std::vector<size_t>& permutations);
Status Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim(cudaStream_t stream, size_t element_size, const TArray<int64_t>& input_shape,
const TArray<int64_t>& input_strides, const void* input_data,
const TArray<int64_t>& output_strides, void* output_data, int N);

bool CanDoTranspose4DParallelizeOneElementPerThread(const cudaDeviceProp& prop,
size_t element_size,
int32_t rank,
const std::vector<int64_t>& input_dims,
const std::vector<size_t>& permutations);
Status Transpose4DParallelizeOneElementPerThread(cudaStream_t stream, size_t element_size, const TArray<int64_t>& input_shape,
const TArray<int64_t>& input_strides, const void* input_data,
const TArray<int64_t>& output_strides, void* output_data, int N);

Status TransposeImpl(cudaStream_t stream, size_t element_size, int32_t shape_rank, const TArray<int64_t>& input_strides,
const void* input_data, const TArray<fast_divmod>& fdm_output_strides, void* output_data, int N);
} // namespace cuda
57 changes: 35 additions & 22 deletions onnxruntime/core/providers/rocm/tensor/transpose.cc
@@ -62,16 +62,16 @@ Status TransposeWithRocblas(hipStream_t stream, rocblas_handle rocblas_handle, c
HipT* output_data = reinterpret_cast<HipT*>(output.MutableData<T>());
ROCBLAS_RETURN_IF_ERROR(
rocblasTransposeHelper(stream,
rocblas_handle,
rocblas_operation_transpose, rocblas_operation_transpose, M, N,
&one,
input_data,
N,
&zero,
input_data,
N,
output_data,
M));
rocblas_handle,
rocblas_operation_transpose, rocblas_operation_transpose, M, N,
&one,
input_data,
N,
&zero,
input_data,
N,
output_data,
M));
return Status::OK();
}

@@ -128,25 +128,25 @@ Status Transpose::DoTranspose(const hipDeviceProp_t& prop,
new_permutations[j] -= 1;
}
}
for (auto j = i+1; j < new_rank; j++) {
new_permutations[j-1] = new_permutations[j];
for (auto j = i + 1; j < new_rank; j++) {
new_permutations[j - 1] = new_permutations[j];
}

// update input dims
new_input_dims[prev] *= new_input_dims[curr];
new_input_dims[curr] = 1;
for (auto j = static_cast<int32_t>(curr+1); j < new_rank; j++) {
new_input_dims[j-1] = new_input_dims[j];
for (auto j = static_cast<int32_t>(curr + 1); j < new_rank; j++) {
new_input_dims[j - 1] = new_input_dims[j];
}
new_input_dims[new_rank-1] = 1;
new_input_dims[new_rank - 1] = 1;

// update output dims
new_output_dims[i-1] *= new_output_dims[i];
new_output_dims[i - 1] *= new_output_dims[i];
new_output_dims[i] = 1;
for (auto j = i+1; j < new_rank; j++) {
new_output_dims[j-1] = new_output_dims[j];
for (auto j = i + 1; j < new_rank; j++) {
new_output_dims[j - 1] = new_output_dims[j];
}
new_output_dims[new_rank-1] = 1;
new_output_dims[new_rank - 1] = 1;

new_rank--;
}
@@ -166,13 +166,26 @@ Status Transpose::DoTranspose(const hipDeviceProp_t& prop,
if (CanDoTranspose3D(new_rank, new_input_dims, new_permutations)) {
return Transpose3DImpl(stream, element_size, input_shape, tmp_input_strides,
input.DataRaw(), output.MutableDataRaw(), output.Shape().Size());
} else if (CanDoTranspose4D(prop, element_size, new_rank, new_input_dims, new_permutations)) {
} else if (CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(
prop, element_size, new_rank, new_input_dims, new_permutations)) {
TArray<int64_t> tmp_output_strides(new_rank);
for (auto i = 0; i < new_rank; i++) {
tmp_output_strides[i] = new_output_strides[new_permutations[i]];
}
return Transpose4DImpl(stream, element_size, input_shape, tmp_input_strides, input.DataRaw(),
tmp_output_strides, output.MutableDataRaw(), output.Shape().Size());
return Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim(
stream, element_size, input_shape, tmp_input_strides, input.DataRaw(),
tmp_output_strides, output.MutableDataRaw(), gsl::narrow<int>(output.Shape().Size()));
} else if (CanDoTranspose4DParallelizeOneElementPerThread(
prop, element_size, new_rank, new_input_dims, new_permutations)) {
// Best-effort attempt at a more optimized transpose for the 4-D case
// before falling back to the generic implementation
TArray<int64_t> tmp_output_strides(new_rank);
for (auto i = 0; i < new_rank; i++) {
tmp_output_strides[i] = new_output_strides[new_permutations[i]];
}
return Transpose4DParallelizeOneElementPerThread(
stream, element_size, input_shape, tmp_input_strides, input.DataRaw(),
tmp_output_strides, output.MutableDataRaw(), gsl::narrow<int>(output.Shape().Size()));
}

// General cases