From 46823fcbfd0c5ae7f296ea1482dc0d67db7eb8d3 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 13 Apr 2021 06:46:09 -0700 Subject: [PATCH 1/9] Fix bug in Transpose CUDA --- .../providers/cuda/tensor/transpose_impl.cu | 68 ++++++--- .../providers/cpu/tensor/transpose_test.cc | 133 ++++++++++-------- 2 files changed, 119 insertions(+), 82 deletions(-) diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu index 10611c9cd9d3a..7f562a28e816f 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu @@ -79,22 +79,39 @@ Status Transpose3DImpl(cudaStream_t stream, size_t element_size, return Status::OK(); } -template __global__ void Transpose4DKernel(const TArray input_strides, const void* input_data, const TArray output_strides, void* output_data, + unsigned int num_elements_per_thread, bool multiple_elements_per_thread_in_last_dim, CUDA_LONG N) { - // output coordinates will be: blockIdx.y, blockIdx.x, threadIdx.y, threadIdx.x - CUDA_LONG input_index = (blockIdx.y * input_strides[0] + - blockIdx.x * input_strides[1] + - threadIdx.y * input_strides[2]) / - (4 * sizeof(int) / element_size) + - threadIdx.x * input_strides[3]; - - CUDA_LONG output_index = (blockIdx.y * output_strides[0] + - blockIdx.x * output_strides[1] + - threadIdx.y * output_strides[2]) / - (4 * sizeof(int) / element_size) + - threadIdx.x * output_strides[3]; + CUDA_LONG input_index = 0; + CUDA_LONG output_index = 0; + + if (multiple_elements_per_thread_in_last_dim) { + // output coordinates will be: blockIdx.y, blockIdx.x, threadIdx.y, threadIdx.x + input_index = (blockIdx.y * input_strides[0] + + blockIdx.x * input_strides[1] + + threadIdx.y * input_strides[2]) / + num_elements_per_thread + + threadIdx.x * input_strides[3]; + + output_index = (blockIdx.y * output_strides[0] + + blockIdx.x * output_strides[1] + + threadIdx.y * output_strides[2]) / + num_elements_per_thread + + threadIdx.x * output_strides[3]; + } else { + input_index = (blockIdx.y * input_strides[0] + + blockIdx.x * input_strides[1] + + threadIdx.x * input_strides[3]) / + num_elements_per_thread + + threadIdx.y * input_strides[2]; + + output_index = (blockIdx.y * output_strides[0] + + blockIdx.x * output_strides[1] + + threadIdx.x * output_strides[3]) / + num_elements_per_thread + + threadIdx.y * output_strides[2]; + } const int4* v_input = reinterpret_cast(input_data); int4* v_output = reinterpret_cast(output_data); @@ -133,28 +150,37 @@ Status Transpose4DImpl(cudaStream_t stream, size_t element_size, const TArray& output_strides, void* output_data, int N) { unsigned int num_elements_per_thread = 4 * sizeof(int) / static_cast(element_size); // int4 is used in the kernel to access data. dim3 block_size(static_cast(input_shape[3] / num_elements_per_thread), static_cast(input_shape[2])); + bool multiple_elements_per_thread_in_last_dim = true; + + if (block_size.x == 0) { + // Entering this means that input_shape[3] was less than num_elements_per_thread, + // hence have a thread process multiple elements in axis = 2 instead + block_size.x = static_cast(input_shape[3]); + block_size.y = static_cast(input_shape[2] / num_elements_per_thread); + multiple_elements_per_thread_in_last_dim = false; + } dim3 grid_size(static_cast(input_shape[1]), static_cast(input_shape[0])); switch (element_size) { case sizeof(int8_t): - Transpose4DKernel<<>>( + Transpose4DKernel<<>>( input_strides, input_data, - output_strides, output_data, N / num_elements_per_thread); + output_strides, output_data, num_elements_per_thread, multiple_elements_per_thread_in_last_dim, N / num_elements_per_thread); break; case sizeof(int16_t): - Transpose4DKernel<<>>( + Transpose4DKernel<<>>( input_strides, input_data, - output_strides, output_data, N / num_elements_per_thread); + output_strides, output_data, num_elements_per_thread, multiple_elements_per_thread_in_last_dim, N / num_elements_per_thread); break; case sizeof(int32_t): - Transpose4DKernel<<>>( + Transpose4DKernel<<>>( input_strides, input_data, - output_strides, output_data, N / num_elements_per_thread); + output_strides, output_data, num_elements_per_thread, multiple_elements_per_thread_in_last_dim, N / num_elements_per_thread); break; case sizeof(int64_t): - Transpose4DKernel<<>>( + Transpose4DKernel<<>>( input_strides, input_data, - output_strides, output_data, N / num_elements_per_thread); + output_strides, output_data, num_elements_per_thread, multiple_elements_per_thread_in_last_dim, N / num_elements_per_thread); break; default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ", diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index b971d85072f8c..5d740edf31662 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -30,7 +30,7 @@ void TransposeTest(std::vector& input_shape, std::vector& input_vals, std::vector* p_perm, std::vector expected_shape, - std::initializer_list& expected_vals, + std::vector& expected_vals, bool is_tensorrt_supported = true, bool is_openvino_supported = true) { OpTester test("Transpose"); @@ -59,7 +59,7 @@ TEST(TransposeOpTest, TwoDimNoAttr) { 4.0f, 5.0f, 6.0f}; std::vector expected_shape({3, 2}); - auto expected_vals = { + std::vector expected_vals = { 1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f}; @@ -74,7 +74,7 @@ TEST(TransposeOpTest, TwoDimNoAttrStr) { "4", "5", "6"}; std::vector expected_shape({3, 2}); - std::initializer_list expected_vals = { + std::vector expected_vals = { "1", "4", "2", "5", "3", "6"}; @@ -90,9 +90,9 @@ TEST(TransposeOpTest, TwoDim) { std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - auto expected_vals = {1.0f, 4.0f, - 2.0f, 5.0f, - 3.0f, 6.0f}; + std::vector expected_vals = {1.0f, 4.0f, + 2.0f, 5.0f, + 3.0f, 6.0f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); } @@ -104,9 +104,9 @@ TEST(TransposeOpTest, TwoDim_double) { std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - std::initializer_list expected_vals = {1.0, 4.0, - 2.0, 5.0, - 3.0, 6.0}; + std::vector expected_vals = {1.0, 4.0, + 2.0, 5.0, + 3.0, 6.0}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); } @@ -118,9 +118,9 @@ TEST(TransposeOpTest, TwoDim_int32) { std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - std::initializer_list expected_vals = {1, 4, - 2, 5, - 3, 6}; + std::vector expected_vals = {1, 4, + 2, 5, + 3, 6}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); } @@ -133,7 +133,7 @@ TEST(TransposeOpTest, TwoDim_int16) { std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - std::initializer_list expected_vals = { + std::vector expected_vals = { 1, 4, 2, 5, 3, 6}; @@ -149,7 +149,7 @@ TEST(TransposeOpTest, TwoDim_mlfloat16) { std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - std::initializer_list expected_vals = + std::vector expected_vals = {MLFloat16{static_cast(1)}, MLFloat16{static_cast(4)}, MLFloat16{static_cast(2)}, MLFloat16{static_cast(5)}, MLFloat16{static_cast(3)}, MLFloat16{static_cast(6)}}; @@ -164,9 +164,9 @@ TEST(TransposeOpTest, TwoDim_int8) { std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - std::initializer_list expected_vals = {1, 4, - 2, 5, - 3, 6}; + std::vector expected_vals = {1, 4, + 2, 5, + 3, 6}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false); } @@ -179,7 +179,7 @@ TEST(TransposeOpTest, TwoDimStr) { std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - std::initializer_list expected_vals = { + std::vector expected_vals = { "1", "4", "2", "5", "3", "6"}; @@ -205,22 +205,23 @@ TEST(TransposeOpTest, ThreeDim) { std::vector perm = {0, 2, 1}; std::vector expected_shape({4, 3, 2}); - auto expected_vals = { - 1.0f, 4.0f, - 2.0f, 5.0f, - 3.0f, 6.0f, + std::vector + expected_vals = { + 1.0f, 4.0f, + 2.0f, 5.0f, + 3.0f, 6.0f, - 1.1f, 4.1f, - 2.1f, 5.1f, - 3.1f, 6.1f, + 1.1f, 4.1f, + 2.1f, 5.1f, + 3.1f, 6.1f, - 1.2f, 4.2f, - 2.2f, 5.2f, - 3.2f, 6.2f, + 1.2f, 4.2f, + 2.2f, 5.2f, + 3.2f, 6.2f, - 1.3f, 4.3f, - 2.3f, 5.3f, - 3.3f, 6.3f}; + 1.3f, 4.3f, + 2.3f, 5.3f, + 3.3f, 6.3f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false); //TensorRT: illegal error } @@ -243,7 +244,7 @@ TEST(TransposeOpTest, ThreeDimSuffix) { std::vector perm = {1, 0, 2}; std::vector expected_shape({2, 4, 3}); - auto expected_vals = { + std::vector expected_vals = { 1.0f, 2.0f, 3.0f, 1.1f, 2.1f, 3.1f, 1.2f, 2.2f, 3.2f, @@ -274,7 +275,7 @@ TEST(TransposeOpTest, TransposeReshape) { std::vector perm = {1, 3, 2, 4, 0}; std::vector expected_shape({4, 1, 2, 3, 1}); - auto expected_vals = { + std::vector expected_vals = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, @@ -307,7 +308,7 @@ TEST(TransposeOpTest, ThreeDimStr) { std::vector perm = {0, 2, 1}; std::vector expected_shape({4, 3, 2}); - std::initializer_list expected_vals = { + std::vector expected_vals = { "1", "4", "2", "5", "3", "6", @@ -337,7 +338,7 @@ static void NumericNCHW2NHWC() { std::vector perm = {0, 2, 3, 1}; std::vector expected_shape({1, 2, 2, 3}); - std::initializer_list expected_vals = { + std::vector expected_vals = { 1, 5, 9, 2, 6, 10, 3, 7, 11, @@ -361,7 +362,7 @@ TEST(TransposeOpTest, NCHW2NHWCStr) { std::vector perm = {0, 2, 3, 1}; std::vector expected_shape({1, 2, 2, 3}); - std::initializer_list expected_vals = { + std::vector expected_vals = { "1", "5", "9", "2", "6", "10", "3", "7", "11", @@ -388,7 +389,7 @@ static void NumericNHWC2NCHW() { std::vector perm = {0, 3, 1, 2}; std::vector expected_shape({2, 2, 2, 2}); - std::initializer_list expected_vals = { + std::vector expected_vals = { 1, 3, 5, 7, @@ -421,7 +422,7 @@ TEST(TransposeOpTest, NHWC2NCHW_String) { std::vector perm = {0, 3, 1, 2}; std::vector expected_shape({1, 3, 2, 2}); - std::initializer_list expected_vals = { + std::vector expected_vals = { "1", "4", "7", "10", "2", "5", "8", "11", "3", "6", "9", "12"}; @@ -447,7 +448,7 @@ TEST(TransposeOpTest, SingleAxisMovingInwardsBlockCopy) { std::vector perm = {1, 2, 0, 3}; std::vector expected_shape({2, 2, 2, 2}); - std::initializer_list expected_vals = { + std::vector expected_vals = { 1, 2, 9, 10, @@ -471,17 +472,17 @@ TEST(TransposeOpTest, NDim) { 13.0f, 14.0f, 15.0f, 16.0f}; std::vector perm = {1, 0, 2, 3}; - auto expected_vals = {1.0f, 2.0f, 3.0f, 4.0f, - 9.0f, 10.0f, 11.0f, 12.0f, - 5.0f, 6.0f, 7.0f, 8.0f, - 13.0f, 14.0f, 15.0f, 16.0f}; + std::vector expected_vals = {1.0f, 2.0f, 3.0f, 4.0f, + 9.0f, 10.0f, 11.0f, 12.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 13.0f, 14.0f, 15.0f, 16.0f}; TransposeTest(input_shape, input_vals, &perm, input_shape, expected_vals); perm = {1, 0, 3, 2}; - auto expected_vals2 = {1.0f, 3.0f, 2.0f, 4.0f, - 9.0f, 11.0f, 10.0f, 12.0f, - 5.0f, 7.0f, 6.0f, 8.0f, - 13.0f, 15.0f, 14.0f, 16.0f}; + std::vector expected_vals2 = {1.0f, 3.0f, 2.0f, 4.0f, + 9.0f, 11.0f, 10.0f, 12.0f, + 5.0f, 7.0f, 6.0f, 8.0f, + 13.0f, 15.0f, 14.0f, 16.0f}; TransposeTest(input_shape, input_vals, &perm, input_shape, expected_vals2); } @@ -493,11 +494,11 @@ TEST(TransposeOpTest, DoTransposeImpl) { } std::vector perm = {2, 1, 0, 3}; std::vector expected_shape({1, 2, 5, 3}); - auto expected_vals = {0.0f, 1.0f, 2.0f, 6.0f, 7.0f, 8.0f, - 12.0f, 13.0f, 14.0f, 18.0f, 19.0f, 20.0f, - 24.0f, 25.0f, 26.0f, 3.0f, 4.0f, 5.0f, - 9.0f, 10.0f, 11.0f, 15.0f, 16.0f, 17.0f, - 21.0f, 22.0f, 23.0f, 27.0f, 28.0f, 29.0f}; + std::vector expected_vals = {0.0f, 1.0f, 2.0f, 6.0f, 7.0f, 8.0f, + 12.0f, 13.0f, 14.0f, 18.0f, 19.0f, 20.0f, + 24.0f, 25.0f, 26.0f, 3.0f, 4.0f, 5.0f, + 9.0f, 10.0f, 11.0f, 15.0f, 16.0f, 17.0f, + 21.0f, 22.0f, 23.0f, 27.0f, 28.0f, 29.0f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); } @@ -509,11 +510,11 @@ TEST(TransposeOpTest, DoTransposeImplString) { } std::vector perm = {2, 1, 0, 3}; std::vector expected_shape({1, 2, 5, 3}); - std::initializer_list expected_vals = {"n0", "n1", "n2", "n6", "n7", "n8", - "n12", "n13", "n14", "n18", "n19", "n20", - "n24", "n25", "n26", "n3", "n4", "n5", - "n9", "n10", "n11", "n15", "n16", "n17", - "n21", "n22", "n23", "n27", "n28", "n29"}; + std::vector expected_vals = {"n0", "n1", "n2", "n6", "n7", "n8", + "n12", "n13", "n14", "n18", "n19", "n20", + "n24", "n25", "n26", "n3", "n4", "n5", + "n9", "n10", "n11", "n15", "n16", "n17", + "n21", "n22", "n23", "n27", "n28", "n29"}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); } @@ -526,10 +527,10 @@ TEST(TransposeOpTest, DoTransposeEltWise) { 13.0f, 14.0f, 15.0f, 16.0f}; std::vector perm = {1, 0, 3, 2}; - auto expected_vals2 = {1.0f, 3.0f, 2.0f, 4.0f, - 9.0f, 11.0f, 10.0f, 12.0f, - 5.0f, 7.0f, 6.0f, 8.0f, - 13.0f, 15.0f, 14.0f, 16.0f}; + std::vector expected_vals2 = {1.0f, 3.0f, 2.0f, 4.0f, + 9.0f, 11.0f, 10.0f, 12.0f, + 5.0f, 7.0f, 6.0f, 8.0f, + 13.0f, 15.0f, 14.0f, 16.0f}; TransposeTest(input_shape, input_vals, &perm, input_shape, expected_vals2); // Specific test which tests that function DoTransposeEltWise does not @@ -612,5 +613,15 @@ TEST(TransposeOpTest, Transpose0312) { } #endif +TEST(TransposeOpTest, FloatDataPerm0213) { + std::vector input_shape({1, 1, 128, 1}); + std::vector input_vals(128, 1); + + std::vector expected_shape({1, 128, 1, 1}); + + std::vector perm = {0, 2, 1, 3}; + TransposeTest(input_shape, input_vals, &perm, expected_shape, input_vals); +} + } // namespace test } // namespace onnxruntime From e772c3d943cfe0404355ded22828c1ae55cc97b7 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 13 Apr 2021 08:35:06 -0700 Subject: [PATCH 2/9] Add comment --- onnxruntime/test/providers/cpu/tensor/transpose_test.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index 5d740edf31662..41503d3464a76 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -613,6 +613,7 @@ TEST(TransposeOpTest, Transpose0312) { } #endif +// Test crafted with specific shape to trigger bug reported in GH issue : 7316 TEST(TransposeOpTest, FloatDataPerm0213) { std::vector input_shape({1, 1, 128, 1}); std::vector input_vals(128, 1); From cb74fd3e17fc513a0b55ec15fbd2639e104f0ad0 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 13 Apr 2021 08:36:38 -0700 Subject: [PATCH 3/9] Add comment --- onnxruntime/test/providers/cpu/tensor/transpose_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index 41503d3464a76..e968294cfdf68 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -613,7 +613,7 @@ TEST(TransposeOpTest, Transpose0312) { } #endif -// Test crafted with specific shape to trigger bug reported in GH issue : 7316 +// Test crafted with specific shape based on bug reported in GH issue : 7316 TEST(TransposeOpTest, FloatDataPerm0213) { std::vector input_shape({1, 1, 128, 1}); std::vector input_vals(128, 1); From dedaae13a8439546800a73c6309d6ea9937f90c1 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Mon, 24 May 2021 12:16:33 -0700 Subject: [PATCH 4/9] update --- .../core/providers/cuda/tensor/transpose.cc | 37 ++-- .../providers/cuda/tensor/transpose_impl.cu | 192 +++++++++++------- .../providers/cuda/tensor/transpose_impl.h | 26 ++- .../providers/cpu/tensor/transpose_test.cc | 148 +++++++------- 4 files changed, 230 insertions(+), 173 deletions(-) diff --git a/onnxruntime/core/providers/cuda/tensor/transpose.cc b/onnxruntime/core/providers/cuda/tensor/transpose.cc index f5fb7c01473d7..a1107093dd08c 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose.cc +++ b/onnxruntime/core/providers/cuda/tensor/transpose.cc @@ -128,25 +128,25 @@ Status Transpose::DoTranspose(const cudaDeviceProp& prop, new_permutations[j] -= 1; } } - for (auto j = i+1; j < new_rank; j++) { - new_permutations[j-1] = new_permutations[j]; + for (auto j = i + 1; j < new_rank; j++) { + new_permutations[j - 1] = new_permutations[j]; } // update input dims new_input_dims[prev] *= new_input_dims[curr]; new_input_dims[curr] = 1; - for (auto j = static_cast(curr+1); j < new_rank; j++) { - new_input_dims[j-1] = new_input_dims[j]; + for (auto j = static_cast(curr + 1); j < new_rank; j++) { + new_input_dims[j - 1] = new_input_dims[j]; } - new_input_dims[new_rank-1] = 1; + new_input_dims[new_rank - 1] = 1; // update output dims - new_output_dims[i-1] *= new_output_dims[i]; + new_output_dims[i - 1] *= new_output_dims[i]; new_output_dims[i] = 1; - for (auto j = i+1; j < new_rank; j++) { - new_output_dims[j-1] = new_output_dims[j]; + for (auto j = i + 1; j < new_rank; j++) { + new_output_dims[j - 1] = new_output_dims[j]; } - new_output_dims[new_rank-1] = 1; + new_output_dims[new_rank - 1] = 1; new_rank--; } @@ -166,13 +166,26 @@ Status Transpose::DoTranspose(const cudaDeviceProp& prop, if (CanDoTranspose3D(new_rank, new_input_dims, new_permutations)) { return Transpose3DImpl(stream, element_size, input_shape, tmp_input_strides, input.DataRaw(), output.MutableDataRaw(), output.Shape().Size()); - } else if (CanDoTranspose4D(prop, element_size, new_rank, new_input_dims, new_permutations)) { + } else if (CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim( + prop, element_size, new_rank, new_input_dims, new_permutations)) { TArray tmp_output_strides(new_rank); for (auto i = 0; i < new_rank; i++) { tmp_output_strides[i] = new_output_strides[new_permutations[i]]; } - return Transpose4DImpl(stream, element_size, input_shape, tmp_input_strides, input.DataRaw(), - tmp_output_strides, output.MutableDataRaw(), gsl::narrow(output.Shape().Size())); + return Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim( + stream, element_size, input_shape, tmp_input_strides, input.DataRaw(), + tmp_output_strides, output.MutableDataRaw(), gsl::narrow(output.Shape().Size())); + } else if (CanDoTranspose4DParallelizeOneElementPerThread( + prop, element_size, new_rank, new_input_dims, new_permutations)) { + // Trying to see if we can still do (best effort) more optimized transposing + // for the 4-D case before falling back to the generic case + TArray tmp_output_strides(new_rank); + for (auto i = 0; i < new_rank; i++) { + tmp_output_strides[i] = new_output_strides[new_permutations[i]]; + } + return Transpose4DParallelizeOneElementPerThread( + stream, element_size, input_shape, tmp_input_strides, input.DataRaw(), + tmp_output_strides, output.MutableDataRaw(), gsl::narrow(output.Shape().Size())); } // General cases diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu index 7f562a28e816f..df1d90c40a3dc 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu @@ -79,39 +79,23 @@ Status Transpose3DImpl(cudaStream_t stream, size_t element_size, return Status::OK(); } -__global__ void Transpose4DKernel(const TArray input_strides, const void* input_data, - const TArray output_strides, void* output_data, - unsigned int num_elements_per_thread, bool multiple_elements_per_thread_in_last_dim, - CUDA_LONG N) { - CUDA_LONG input_index = 0; - CUDA_LONG output_index = 0; - - if (multiple_elements_per_thread_in_last_dim) { - // output coordinates will be: blockIdx.y, blockIdx.x, threadIdx.y, threadIdx.x - input_index = (blockIdx.y * input_strides[0] + - blockIdx.x * input_strides[1] + - threadIdx.y * input_strides[2]) / - num_elements_per_thread + - threadIdx.x * input_strides[3]; - - output_index = (blockIdx.y * output_strides[0] + - blockIdx.x * output_strides[1] + - threadIdx.y * output_strides[2]) / - num_elements_per_thread + - threadIdx.x * output_strides[3]; - } else { - input_index = (blockIdx.y * input_strides[0] + - blockIdx.x * input_strides[1] + - threadIdx.x * input_strides[3]) / - num_elements_per_thread + - threadIdx.y * input_strides[2]; - - output_index = (blockIdx.y * output_strides[0] + - blockIdx.x * output_strides[1] + - threadIdx.x * output_strides[3]) / - num_elements_per_thread + - threadIdx.y * output_strides[2]; - } +__global__ void Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim( + const TArray input_strides, const void* input_data, + const TArray output_strides, void* output_data, + size_t element_size, + CUDA_LONG N) { + // output coordinates will be: blockIdx.y, blockIdx.x, threadIdx.y, threadIdx.x + CUDA_LONG input_index = (blockIdx.y * input_strides[0] + + blockIdx.x * input_strides[1] + + threadIdx.y * input_strides[2]) / + (4 * sizeof(int) / element_size) + + threadIdx.x * input_strides[3]; + + CUDA_LONG output_index = (blockIdx.y * output_strides[0] + + blockIdx.x * output_strides[1] + + threadIdx.y * output_strides[2]) / + (4 * sizeof(int) / element_size) + + threadIdx.x * output_strides[3]; const int4* v_input = reinterpret_cast(input_data); int4* v_output = reinterpret_cast(output_data); @@ -121,71 +105,123 @@ __global__ void Transpose4DKernel(const TArray input_strides, const voi } } -bool CanDoTranspose4D(const cudaDeviceProp& prop, - size_t element_size, - int32_t rank, - const std::vector& input_dims, - const std::vector& permutations) { +bool CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(const cudaDeviceProp& prop, + size_t element_size, + int32_t rank, + const std::vector& input_dims, + const std::vector& permutations) { if (rank == 4 && // the permutations is not on the last dimension. - permutations[rank - 1] == (rank - 1)) { - // The block size will be set based on the last two dimensions of 4D tensor. + permutations[3] == 3) { + // The block size will be set based on the outer-most two dimensions of 4D tensor. // the number threads per block will be calculated as below. unsigned int num_elements_per_thread = 4 * sizeof(int) / static_cast(element_size); // int4 is used in the kernel to access data. - int64_t num_elements_in_last_two_dimensions = input_dims[rank - 2] * input_dims[rank - 1]; + int64_t num_elements_in_last_two_dimensions = input_dims[2] * input_dims[3]; int64_t num_threads_per_block = num_elements_in_last_two_dimensions / num_elements_per_thread; if (((num_elements_in_last_two_dimensions & (num_elements_per_thread - 1)) == 0) && num_threads_per_block <= prop.maxThreadsPerBlock && num_threads_per_block >= prop.warpSize && - // num_threads_per_block must be aligned with warp size: 32 - ((num_threads_per_block & (prop.warpSize - 1)) == 0)) { + // num_threads_per_block must be a multiple of warp size (32) + ((num_threads_per_block & (prop.warpSize - 1)) == 0) && + // input_dims[3] must be a multiple of `num_elements_per_thread` + ((input_dims[3] % num_elements_per_thread) == 0)) { return true; } } return false; } -Status Transpose4DImpl(cudaStream_t stream, size_t element_size, const TArray& input_shape, const TArray& input_strides, const void* input_data, - const TArray& output_strides, void* output_data, int N) { +Status Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim( + cudaStream_t stream, size_t element_size, + const TArray& input_shape, const TArray& input_strides, + const void* input_data, const TArray& output_strides, + void* output_data, int N) { + if (element_size != sizeof(int8_t) && + element_size != sizeof(int16_t) && + element_size != sizeof(int32_t) && + element_size != sizeof(int64_t)) { + // User will not hit this as this kernel is for fixed element size tensors only + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ", + element_size); + } + unsigned int num_elements_per_thread = 4 * sizeof(int) / static_cast(element_size); // int4 is used in the kernel to access data. dim3 block_size(static_cast(input_shape[3] / num_elements_per_thread), static_cast(input_shape[2])); - bool multiple_elements_per_thread_in_last_dim = true; - - if (block_size.x == 0) { - // Entering this means that input_shape[3] was less than num_elements_per_thread, - // hence have a thread process multiple elements in axis = 2 instead - block_size.x = static_cast(input_shape[3]); - block_size.y = static_cast(input_shape[2] / num_elements_per_thread); - multiple_elements_per_thread_in_last_dim = false; - } dim3 grid_size(static_cast(input_shape[1]), static_cast(input_shape[0])); - switch (element_size) { - case sizeof(int8_t): - Transpose4DKernel<<>>( - input_strides, input_data, - output_strides, output_data, num_elements_per_thread, multiple_elements_per_thread_in_last_dim, N / num_elements_per_thread); - break; - case sizeof(int16_t): - Transpose4DKernel<<>>( - input_strides, input_data, - output_strides, output_data, num_elements_per_thread, multiple_elements_per_thread_in_last_dim, N / num_elements_per_thread); - break; - case sizeof(int32_t): - Transpose4DKernel<<>>( - input_strides, input_data, - output_strides, output_data, num_elements_per_thread, multiple_elements_per_thread_in_last_dim, N / num_elements_per_thread); - break; - case sizeof(int64_t): - Transpose4DKernel<<>>( - input_strides, input_data, - output_strides, output_data, num_elements_per_thread, multiple_elements_per_thread_in_last_dim, N / num_elements_per_thread); - break; - default: - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ", - element_size); + Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim<<>>( + input_strides, input_data, + output_strides, output_data, element_size, N / num_elements_per_thread); + + return Status::OK(); +} + +__global__ void Transpose4DKernelParallelizeOneElementPerThread( + const TArray input_strides, const int8_t* input_data, + const TArray output_strides, int8_t* output_data, + size_t element_size, + CUDA_LONG N) { + CUDA_LONG input_index = blockIdx.y * input_strides[0] + + blockIdx.x * input_strides[1] + + threadIdx.y * input_strides[2] + + threadIdx.x * input_strides[3]; + + CUDA_LONG output_index = blockIdx.y * output_strides[0] + + blockIdx.x * output_strides[1] + + threadIdx.y * output_strides[2] + + threadIdx.x * output_strides[3]; + + if (input_index < N && output_index < N) { + const int8_t* input_data_to_be_copied = input_data + (input_index * element_size); + int8_t* output_data_to_be_copied = output_data + (output_index * element_size); + + // copy over the bytes + memcpy(output_data_to_be_copied, input_data_to_be_copied, element_size); } +} + +bool CanDoTranspose4DParallelizeOneElementPerThread(const cudaDeviceProp& prop, + size_t element_size, + int32_t rank, + const std::vector& input_dims, + const std::vector& permutations) { + if (rank == 4) { + // The block size will be set based on the outer-most two dimensions of 4D tensor. + // the number threads per block will be calculated as below. + int64_t number_of_threads_per_block = input_dims[2] * input_dims[3]; + + if (number_of_threads_per_block <= prop.maxThreadsPerBlock && + number_of_threads_per_block >= prop.warpSize && + // num_threads_per_block must be a multiple of warp size (32) + ((number_of_threads_per_block & (prop.warpSize - 1)) == 0)) { + return true; + } + } + return false; +} + +Status Transpose4DParallelizeOneElementPerThread( + cudaStream_t stream, size_t element_size, + const TArray& input_shape, const TArray& input_strides, + const void* input_data, const TArray& output_strides, + void* output_data, int N) { + if (element_size != sizeof(int8_t) && + element_size != sizeof(int16_t) && + element_size != sizeof(int32_t) && + element_size != sizeof(int64_t)) { + // User will not hit this as this kernel is for fixed element size tensors only + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ", + element_size); + } + + dim3 block_size(static_cast(input_shape[3]), static_cast(input_shape[2])); + dim3 grid_size(static_cast(input_shape[1]), static_cast(input_shape[0])); + + Transpose4DKernelParallelizeOneElementPerThread<<>>( + input_strides, reinterpret_cast(input_data), + output_strides, reinterpret_cast(output_data), + element_size, N); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.h b/onnxruntime/core/providers/cuda/tensor/transpose_impl.h index 1a4d469776d54..a9184d2a16ab3 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.h @@ -11,13 +11,25 @@ namespace cuda { bool CanDoTranspose3D(int32_t rank, const std::vector& input_dims, const std::vector& permutations); Status Transpose3DImpl(cudaStream_t stream, size_t element_size, const TArray& input_shape, const TArray& input_strides, const void* input_data, void* output_data, int64_t N); -bool CanDoTranspose4D(const cudaDeviceProp& prop, - size_t element_size, - int32_t rank, - const std::vector& input_dims, - const std::vector& permutations); -Status Transpose4DImpl(cudaStream_t stream, size_t element_size, const TArray& input_shape, const TArray& input_strides, const void* input_data, - const TArray& output_strides, void* output_data, int N); + +bool CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim(const cudaDeviceProp& prop, + size_t element_size, + int32_t rank, + const std::vector& input_dims, + const std::vector& permutations); +Status Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim(cudaStream_t stream, size_t element_size, const TArray& input_shape, + const TArray& input_strides, const void* input_data, + const TArray& output_strides, void* output_data, int N); + +bool CanDoTranspose4DParallelizeOneElementPerThread(const cudaDeviceProp& prop, + size_t element_size, + int32_t rank, + const std::vector& input_dims, + const std::vector& permutations); +Status Transpose4DParallelizeOneElementPerThread(cudaStream_t stream, size_t element_size, const TArray& input_shape, + const TArray& input_strides, const void* input_data, + const TArray& output_strides, void* output_data, int N); + Status TransposeImpl(cudaStream_t stream, size_t element_size, int32_t shape_rank, const TArray& input_strides, const void* input_data, const TArray& fdm_output_strides, void* output_data, int N); } // namespace cuda diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index e968294cfdf68..515fa120c63fb 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -30,7 +30,7 @@ void TransposeTest(std::vector& input_shape, std::vector& input_vals, std::vector* p_perm, std::vector expected_shape, - std::vector& expected_vals, + std::initializer_list& expected_vals, bool is_tensorrt_supported = true, bool is_openvino_supported = true) { OpTester test("Transpose"); @@ -59,7 +59,7 @@ TEST(TransposeOpTest, TwoDimNoAttr) { 4.0f, 5.0f, 6.0f}; std::vector expected_shape({3, 2}); - std::vector expected_vals = { + auto expected_vals = { 1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f}; @@ -74,7 +74,7 @@ TEST(TransposeOpTest, TwoDimNoAttrStr) { "4", "5", "6"}; std::vector expected_shape({3, 2}); - std::vector expected_vals = { + std::initializer_list expected_vals = { "1", "4", "2", "5", "3", "6"}; @@ -90,9 +90,9 @@ TEST(TransposeOpTest, TwoDim) { std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - std::vector expected_vals = {1.0f, 4.0f, - 2.0f, 5.0f, - 3.0f, 6.0f}; + auto expected_vals = {1.0f, 4.0f, + 2.0f, 5.0f, + 3.0f, 6.0f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); } @@ -104,9 +104,9 @@ TEST(TransposeOpTest, TwoDim_double) { std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - std::vector expected_vals = {1.0, 4.0, - 2.0, 5.0, - 3.0, 6.0}; + std::initializer_list expected_vals = {1.0, 4.0, + 2.0, 5.0, + 3.0, 6.0}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); } @@ -118,9 +118,9 @@ TEST(TransposeOpTest, TwoDim_int32) { std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - std::vector expected_vals = {1, 4, - 2, 5, - 3, 6}; + std::initializer_list expected_vals = {1, 4, + 2, 5, + 3, 6}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); } @@ -133,7 +133,7 @@ TEST(TransposeOpTest, TwoDim_int16) { std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - std::vector expected_vals = { + std::initializer_list expected_vals = { 1, 4, 2, 5, 3, 6}; @@ -149,7 +149,7 @@ TEST(TransposeOpTest, TwoDim_mlfloat16) { std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - std::vector expected_vals = + std::initializer_list expected_vals = {MLFloat16{static_cast(1)}, MLFloat16{static_cast(4)}, MLFloat16{static_cast(2)}, MLFloat16{static_cast(5)}, MLFloat16{static_cast(3)}, MLFloat16{static_cast(6)}}; @@ -164,9 +164,9 @@ TEST(TransposeOpTest, TwoDim_int8) { std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - std::vector expected_vals = {1, 4, - 2, 5, - 3, 6}; + std::initializer_list expected_vals = {1, 4, + 2, 5, + 3, 6}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false); } @@ -179,7 +179,7 @@ TEST(TransposeOpTest, TwoDimStr) { std::vector perm = {1, 0}; std::vector expected_shape({3, 2}); - std::vector expected_vals = { + std::initializer_list expected_vals = { "1", "4", "2", "5", "3", "6"}; @@ -205,23 +205,22 @@ TEST(TransposeOpTest, ThreeDim) { std::vector perm = {0, 2, 1}; std::vector expected_shape({4, 3, 2}); - std::vector - expected_vals = { - 1.0f, 4.0f, - 2.0f, 5.0f, - 3.0f, 6.0f, + auto expected_vals = { + 1.0f, 4.0f, + 2.0f, 5.0f, + 3.0f, 6.0f, - 1.1f, 4.1f, - 2.1f, 5.1f, - 3.1f, 6.1f, + 1.1f, 4.1f, + 2.1f, 5.1f, + 3.1f, 6.1f, - 1.2f, 4.2f, - 2.2f, 5.2f, - 3.2f, 6.2f, + 1.2f, 4.2f, + 2.2f, 5.2f, + 3.2f, 6.2f, - 1.3f, 4.3f, - 2.3f, 5.3f, - 3.3f, 6.3f}; + 1.3f, 4.3f, + 2.3f, 5.3f, + 3.3f, 6.3f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false); //TensorRT: illegal error } @@ -244,7 +243,7 @@ TEST(TransposeOpTest, ThreeDimSuffix) { std::vector perm = {1, 0, 2}; std::vector expected_shape({2, 4, 3}); - std::vector expected_vals = { + auto expected_vals = { 1.0f, 2.0f, 3.0f, 1.1f, 2.1f, 3.1f, 1.2f, 2.2f, 3.2f, @@ -275,7 +274,7 @@ TEST(TransposeOpTest, TransposeReshape) { std::vector perm = {1, 3, 2, 4, 0}; std::vector expected_shape({4, 1, 2, 3, 1}); - std::vector expected_vals = { + auto expected_vals = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, @@ -308,7 +307,7 @@ TEST(TransposeOpTest, ThreeDimStr) { std::vector perm = {0, 2, 1}; std::vector expected_shape({4, 3, 2}); - std::vector expected_vals = { + std::initializer_list expected_vals = { "1", "4", "2", "5", "3", "6", @@ -338,7 +337,7 @@ static void NumericNCHW2NHWC() { std::vector perm = {0, 2, 3, 1}; std::vector expected_shape({1, 2, 2, 3}); - std::vector expected_vals = { + std::initializer_list expected_vals = { 1, 5, 9, 2, 6, 10, 3, 7, 11, @@ -362,7 +361,7 @@ TEST(TransposeOpTest, NCHW2NHWCStr) { std::vector perm = {0, 2, 3, 1}; std::vector expected_shape({1, 2, 2, 3}); - std::vector expected_vals = { + std::initializer_list expected_vals = { "1", "5", "9", "2", "6", "10", "3", "7", "11", @@ -389,7 +388,7 @@ static void NumericNHWC2NCHW() { std::vector perm = {0, 3, 1, 2}; std::vector expected_shape({2, 2, 2, 2}); - std::vector expected_vals = { + std::initializer_list expected_vals = { 1, 3, 5, 7, @@ -422,7 +421,7 @@ TEST(TransposeOpTest, NHWC2NCHW_String) { std::vector perm = {0, 3, 1, 2}; std::vector expected_shape({1, 3, 2, 2}); - std::vector expected_vals = { + std::initializer_list expected_vals = { "1", "4", "7", "10", "2", "5", "8", "11", "3", "6", "9", "12"}; @@ -448,7 +447,7 @@ TEST(TransposeOpTest, SingleAxisMovingInwardsBlockCopy) { std::vector perm = {1, 2, 0, 3}; std::vector expected_shape({2, 2, 2, 2}); - std::vector expected_vals = { + std::initializer_list expected_vals = { 1, 2, 9, 10, @@ -472,17 +471,17 @@ TEST(TransposeOpTest, NDim) { 13.0f, 14.0f, 15.0f, 16.0f}; std::vector perm = {1, 0, 2, 3}; - std::vector expected_vals = {1.0f, 2.0f, 3.0f, 4.0f, - 9.0f, 10.0f, 11.0f, 12.0f, - 5.0f, 6.0f, 7.0f, 8.0f, - 13.0f, 14.0f, 15.0f, 16.0f}; + auto expected_vals = {1.0f, 2.0f, 3.0f, 4.0f, + 9.0f, 10.0f, 11.0f, 12.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 13.0f, 14.0f, 15.0f, 16.0f}; TransposeTest(input_shape, input_vals, &perm, input_shape, expected_vals); perm = {1, 0, 3, 2}; - std::vector expected_vals2 = {1.0f, 3.0f, 2.0f, 4.0f, - 9.0f, 11.0f, 10.0f, 12.0f, - 5.0f, 7.0f, 6.0f, 8.0f, - 13.0f, 15.0f, 14.0f, 16.0f}; + auto expected_vals2 = {1.0f, 3.0f, 2.0f, 4.0f, + 9.0f, 11.0f, 10.0f, 12.0f, + 5.0f, 7.0f, 6.0f, 8.0f, + 13.0f, 15.0f, 14.0f, 16.0f}; TransposeTest(input_shape, input_vals, &perm, input_shape, expected_vals2); } @@ -494,11 +493,11 @@ TEST(TransposeOpTest, DoTransposeImpl) { } std::vector perm = {2, 1, 0, 3}; std::vector expected_shape({1, 2, 5, 3}); - std::vector expected_vals = {0.0f, 1.0f, 2.0f, 6.0f, 7.0f, 8.0f, - 12.0f, 13.0f, 14.0f, 18.0f, 19.0f, 20.0f, - 24.0f, 25.0f, 26.0f, 3.0f, 4.0f, 5.0f, - 9.0f, 10.0f, 11.0f, 15.0f, 16.0f, 17.0f, - 21.0f, 22.0f, 23.0f, 27.0f, 28.0f, 29.0f}; + auto expected_vals = {0.0f, 1.0f, 2.0f, 6.0f, 7.0f, 8.0f, + 12.0f, 13.0f, 14.0f, 18.0f, 19.0f, 20.0f, + 24.0f, 25.0f, 26.0f, 3.0f, 4.0f, 5.0f, + 9.0f, 10.0f, 11.0f, 15.0f, 16.0f, 17.0f, + 21.0f, 22.0f, 23.0f, 27.0f, 28.0f, 29.0f}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); } @@ -510,11 +509,11 @@ TEST(TransposeOpTest, DoTransposeImplString) { } std::vector perm = {2, 1, 0, 3}; std::vector expected_shape({1, 2, 5, 3}); - std::vector expected_vals = {"n0", "n1", "n2", "n6", "n7", "n8", - "n12", "n13", "n14", "n18", "n19", "n20", - "n24", "n25", "n26", "n3", "n4", "n5", - "n9", "n10", "n11", "n15", "n16", "n17", - "n21", "n22", "n23", "n27", "n28", "n29"}; + std::initializer_list expected_vals = {"n0", "n1", "n2", "n6", "n7", "n8", + "n12", "n13", "n14", "n18", "n19", "n20", + "n24", "n25", "n26", "n3", "n4", "n5", + "n9", "n10", "n11", "n15", "n16", "n17", + "n21", "n22", "n23", "n27", "n28", "n29"}; TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals); } @@ -527,10 +526,10 @@ TEST(TransposeOpTest, DoTransposeEltWise) { 13.0f, 14.0f, 15.0f, 16.0f}; std::vector perm = {1, 0, 3, 2}; - std::vector expected_vals2 = {1.0f, 3.0f, 2.0f, 4.0f, - 9.0f, 11.0f, 10.0f, 12.0f, - 5.0f, 7.0f, 6.0f, 8.0f, - 13.0f, 15.0f, 14.0f, 16.0f}; + auto expected_vals2 = {1.0f, 3.0f, 2.0f, 4.0f, + 9.0f, 11.0f, 10.0f, 12.0f, + 5.0f, 7.0f, 6.0f, 8.0f, + 13.0f, 15.0f, 14.0f, 16.0f}; TransposeTest(input_shape, input_vals, &perm, input_shape, expected_vals2); // Specific test which tests that function DoTransposeEltWise does not @@ -591,38 +590,35 @@ static void TestTranspose( test.CompareWithCPU(kGpuExecutionProvider, error_tolerance); } -TEST(TransposeOpTest, Transpose0213) { +TEST(TransposeOpTest, Transpose0213) { // Will trigger Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim() const std::vector X_dims{64, 128, 16, 64}; const std::vector perm{0, 2, 1, 3}; const std::vector Y_dims{64, 16, 128, 64}; TestTranspose(perm, X_dims, Y_dims); } -TEST(TransposeOpTest, Transpose0231) { +TEST(TransposeOpTest, Transpose0213_V2) { // Will trigger Transpose4DParallelizeOneElementPerThread() + const std::vector X_dims{64, 128, 64, 2}; + const std::vector perm{0, 2, 1, 3}; + const std::vector Y_dims{64, 64, 128, 2}; + TestTranspose(perm, X_dims, Y_dims); +} + +TEST(TransposeOpTest, Transpose0231) { // Will trigger Transpose3DImpl() because of "flattening" of dims 2 and 3 into one dim const std::vector X_dims{64, 128, 16, 64}; const std::vector perm{0, 2, 3, 1}; const std::vector Y_dims{64, 16, 64, 128}; TestTranspose(perm, X_dims, Y_dims); } -TEST(TransposeOpTest, Transpose0312) { +TEST(TransposeOpTest, Transpose0312) { // Will trigger Transpose3DImpl() because of "flattening" of dims 1 and 2 into one dim const std::vector X_dims{64, 16, 64, 128}; const std::vector perm{0, 3, 1, 2}; const std::vector Y_dims{64, 128, 16, 64}; TestTranspose(perm, X_dims, Y_dims); } -#endif - -// Test crafted with specific shape based on bug reported in GH issue : 7316 -TEST(TransposeOpTest, FloatDataPerm0213) { - std::vector input_shape({1, 1, 128, 1}); - std::vector input_vals(128, 1); - std::vector expected_shape({1, 128, 1, 1}); - - std::vector perm = {0, 2, 1, 3}; - TransposeTest(input_shape, input_vals, &perm, expected_shape, input_vals); -} +#endif } // namespace test } // namespace onnxruntime From 178b65a53bb02b0f7ff190bbf1091b1a0f0db212 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Mon, 24 May 2021 17:33:57 -0700 Subject: [PATCH 5/9] Fix AMD build --- .../core/providers/rocm/tensor/transpose.cc | 57 ++++++++++++------- 1 file changed, 35 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/providers/rocm/tensor/transpose.cc b/onnxruntime/core/providers/rocm/tensor/transpose.cc index 38b2a9cef1607..61e1147abe1a1 100644 --- a/onnxruntime/core/providers/rocm/tensor/transpose.cc +++ b/onnxruntime/core/providers/rocm/tensor/transpose.cc @@ -62,16 +62,16 @@ Status TransposeWithRocblas(hipStream_t stream, rocblas_handle rocblas_handle, c HipT* output_data = reinterpret_cast(output.MutableData()); ROCBLAS_RETURN_IF_ERROR( rocblasTransposeHelper(stream, - rocblas_handle, - rocblas_operation_transpose, rocblas_operation_transpose, M, N, - &one, - input_data, - N, - &zero, - input_data, - N, - output_data, - M)); + rocblas_handle, + rocblas_operation_transpose, rocblas_operation_transpose, M, N, + &one, + input_data, + N, + &zero, + input_data, + N, + output_data, + M)); return Status::OK(); } @@ -128,25 +128,25 @@ Status Transpose::DoTranspose(const hipDeviceProp_t& prop, new_permutations[j] -= 1; } } - for (auto j = i+1; j < new_rank; j++) { - new_permutations[j-1] = new_permutations[j]; + for (auto j = i + 1; j < new_rank; j++) { + new_permutations[j - 1] = new_permutations[j]; } // update input dims new_input_dims[prev] *= new_input_dims[curr]; new_input_dims[curr] = 1; - for (auto j = static_cast(curr+1); j < new_rank; j++) { - new_input_dims[j-1] = new_input_dims[j]; + for (auto j = static_cast(curr + 1); j < new_rank; j++) { + new_input_dims[j - 1] = new_input_dims[j]; } - new_input_dims[new_rank-1] = 1; + new_input_dims[new_rank - 1] = 1; // update output dims - new_output_dims[i-1] *= new_output_dims[i]; + new_output_dims[i - 1] *= new_output_dims[i]; new_output_dims[i] = 1; - for (auto j = i+1; j < new_rank; j++) { - new_output_dims[j-1] = new_output_dims[j]; + for (auto j = i + 1; j < new_rank; j++) { + new_output_dims[j - 1] = new_output_dims[j]; } - new_output_dims[new_rank-1] = 1; + new_output_dims[new_rank - 1] = 1; new_rank--; } @@ -166,13 +166,26 @@ Status Transpose::DoTranspose(const hipDeviceProp_t& prop, if (CanDoTranspose3D(new_rank, new_input_dims, new_permutations)) { return Transpose3DImpl(stream, element_size, input_shape, tmp_input_strides, input.DataRaw(), output.MutableDataRaw(), output.Shape().Size()); - } else if (CanDoTranspose4D(prop, element_size, new_rank, new_input_dims, new_permutations)) { + } else if (CanDoTranspose4DParallelizeMultipleElementsPerThreadInInnermostDim( + prop, element_size, new_rank, new_input_dims, new_permutations)) { TArray tmp_output_strides(new_rank); for (auto i = 0; i < new_rank; i++) { tmp_output_strides[i] = new_output_strides[new_permutations[i]]; } - return Transpose4DImpl(stream, element_size, input_shape, tmp_input_strides, input.DataRaw(), - tmp_output_strides, output.MutableDataRaw(), output.Shape().Size()); + return Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim( + stream, element_size, input_shape, tmp_input_strides, input.DataRaw(), + tmp_output_strides, output.MutableDataRaw(), gsl::narrow(output.Shape().Size())); + } else if (CanDoTranspose4DParallelizeOneElementPerThread( + prop, element_size, new_rank, new_input_dims, new_permutations)) { + // Trying to see if we can still do (best effort) more optimized transposing + // for the 4-D case before falling back to the generic case + TArray tmp_output_strides(new_rank); + for (auto i = 0; i < new_rank; i++) { + tmp_output_strides[i] = new_output_strides[new_permutations[i]]; + } + return Transpose4DParallelizeOneElementPerThread( + stream, element_size, input_shape, tmp_input_strides, input.DataRaw(), + tmp_output_strides, output.MutableDataRaw(), gsl::narrow(output.Shape().Size())); } // General cases From 480a489abc257bbaa85ece04ad228641d7271ea8 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 25 May 2021 00:06:07 -0700 Subject: [PATCH 6/9] Use local looping instead of memcpy within CUDA kernel --- onnxruntime/core/providers/cuda/tensor/transpose_impl.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu index df1d90c40a3dc..579ebbfb2f910 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu @@ -177,7 +177,9 @@ __global__ void Transpose4DKernelParallelizeOneElementPerThread( int8_t* output_data_to_be_copied = output_data + (output_index * element_size); // copy over the bytes - memcpy(output_data_to_be_copied, input_data_to_be_copied, element_size); + for (size_t iter = 0; iter < element_size; ++iter) { + *output_data_to_be_copied++ = *input_data_to_be_copied++; + } } } From 526c103e5afcc01f3a3f686a0633c7a7816846e5 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Wed, 26 May 2021 21:31:02 -0700 Subject: [PATCH 7/9] PR feedback --- .../providers/cuda/tensor/transpose_impl.cu | 43 +++++++++++++------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu index 579ebbfb2f910..625a2f0d7e6e7 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu @@ -79,6 +79,7 @@ Status Transpose3DImpl(cudaStream_t stream, size_t element_size, return Status::OK(); } +template __global__ void Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim( const TArray input_strides, const void* input_data, const TArray output_strides, void* output_data, @@ -137,22 +138,40 @@ Status Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim( const TArray& input_shape, const TArray& input_strides, const void* input_data, const TArray& output_strides, void* output_data, int N) { - if (element_size != sizeof(int8_t) && - element_size != sizeof(int16_t) && - element_size != sizeof(int32_t) && - element_size != sizeof(int64_t)) { - // User will not hit this as this kernel is for fixed element size tensors only - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ", - element_size); - } - unsigned int num_elements_per_thread = 4 * sizeof(int) / static_cast(element_size); // int4 is used in the kernel to access data. dim3 block_size(static_cast(input_shape[3] / num_elements_per_thread), static_cast(input_shape[2])); dim3 grid_size(static_cast(input_shape[1]), static_cast(input_shape[0])); - Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim<<>>( - input_strides, input_data, - output_strides, output_data, element_size, N / num_elements_per_thread); + switch (element_size) { + case sizeof(int8_t): + Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim + <<>>( + input_strides, input_data, + output_strides, output_data, element_size, N / num_elements_per_thread); + break; + case sizeof(int16_t): + Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim + <<>>( + input_strides, input_data, + output_strides, output_data, element_size, N / num_elements_per_thread); + break; + case sizeof(int32_t): + Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim + <<>>( + input_strides, input_data, + output_strides, output_data, element_size, N / num_elements_per_thread); + break; + case sizeof(int64_t): + Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim + <<>>( + input_strides, input_data, + output_strides, output_data, element_size, N / num_elements_per_thread); + break; + default: + // User will not hit this as this kernel is for fixed element size tensors only + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ", + element_size); + } return Status::OK(); } From 53486a32213579586cb48c243db22adb7b84e03a Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Wed, 26 May 2021 21:44:03 -0700 Subject: [PATCH 8/9] Fix --- onnxruntime/core/providers/cuda/tensor/transpose_impl.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu index 625a2f0d7e6e7..b2bf035e2501e 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu @@ -83,7 +83,6 @@ template __global__ void Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim( const TArray input_strides, const void* input_data, const TArray output_strides, void* output_data, - size_t element_size, CUDA_LONG N) { // output coordinates will be: blockIdx.y, blockIdx.x, threadIdx.y, threadIdx.x CUDA_LONG input_index = (blockIdx.y * input_strides[0] + From a554db11605a1eeda7b58cf035c604eac41f204a Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Wed, 26 May 2021 21:53:36 -0700 Subject: [PATCH 9/9] Fix --- onnxruntime/core/providers/cuda/tensor/transpose_impl.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu index b2bf035e2501e..006dce292f141 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu @@ -146,25 +146,25 @@ Status Transpose4DParallelizeMultipleElementsPerThreadInInnermostDim( Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim <<>>( input_strides, input_data, - output_strides, output_data, element_size, N / num_elements_per_thread); + output_strides, output_data, N / num_elements_per_thread); break; case sizeof(int16_t): Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim <<>>( input_strides, input_data, - output_strides, output_data, element_size, N / num_elements_per_thread); + output_strides, output_data, N / num_elements_per_thread); break; case sizeof(int32_t): Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim <<>>( input_strides, input_data, - output_strides, output_data, element_size, N / num_elements_per_thread); + output_strides, output_data, N / num_elements_per_thread); break; case sizeof(int64_t): Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim <<>>( input_strides, input_data, - output_strides, output_data, element_size, N / num_elements_per_thread); + output_strides, output_data, N / num_elements_per_thread); break; default: // User will not hit this as this kernel is for fixed element size tensors only