From f6cf7b865a64f9ea8c7890750143995454a32ffa Mon Sep 17 00:00:00 2001 From: "jack.kirk" Date: Wed, 2 Mar 2022 11:45:40 +0000 Subject: [PATCH 1/6] Implemented fp19 mma using the natural storage type uint32_t. Signed-off-by: jack.kirk --- .../ext/oneapi/matrix/matrix-tensorcore.hpp | 21 +++- .../matrix/matrix-nvptx-fp19-test.cpp | 112 ++++++++++++++++++ 2 files changed, 131 insertions(+), 2 deletions(-) create mode 100644 sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index 5c6df9114b161..cce65f54c0999 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -81,6 +81,10 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2) __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2) __SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8) + // m16n16k8 fp19 +__SYCL_JOINT_MATRIX_OVERLOAD(uint32_t, a, 16, 8, int32_t, 4) +__SYCL_JOINT_MATRIX_OVERLOAD(uint32_t, b, 8, 16, int32_t, 4) + #undef __SYCL_JOINT_MATRIX_OVERLOAD } // namespace experimental::matrix @@ -271,7 +275,17 @@ struct joint_matrix_load_impl< __dmma_m8n8k4_ld_c(res.data, src.get(), stride, get_layout_id()); } - } + } else if constexpr (std::is_same::value) { + int32_t *tileptr = reinterpret_cast(src.get()); + if constexpr (NumRows == 16 && NumCols == 8) { + __mma_tf32_m16n16k8_ld_a(res.data, tileptr, stride, + get_layout_id()); + } + else if constexpr (NumRows == 8 && NumCols == 16) { + __mma_tf32_m16n16k8_ld_b(res.data, tileptr, stride, + get_layout_id()); + } + } } }; @@ -495,7 +509,10 @@ struct joint_matrix_mad_impl< get_layout_pair_id(), 0); } } - } else if constexpr (std::is_same::value) { + } else if constexpr (M == 16 && N == 16 && K == 8) { + __mma_tf32_m16n16k8_mma_f32(D.data, A.data, B.data, C.data, + get_layout_pair_id(), 0); + } else if constexpr 
(std::is_same::value) { __dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data, get_layout_pair_id(), 0); } diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp new file mode 100644 index 0000000000000..4ce05be7a8b1a --- /dev/null +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp @@ -0,0 +1,112 @@ +// REQUIRES: cuda + +// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s + +// IMPORTANT: before updating sm version support beyond sm_86 read the following NOTE! + +// NOTE: Technically the 'wrong' ptx instruction is called by joint_matrix_load/joint_matrix_store in this case: +// notice that the load and store instructions use shape m16n16k16, rather than the correct shape m16n16k8. +// The 'wrong' ptx instruction is used because it returns the correct SASS instructions for all existing supported sm versions: +// sm_80 and sm_86. The Apparent reason for this ptx instruction redundancy is due to the ptx naming convention for the mnk shape triple; +// however we cannot in principle a priori know that future sm versions will behave in the same way and that this redundancy will remain. +// This should be validated before supporting any sm versions beyond sm_86. +// The reason that we choose to use the m16n16k16 instruction is that it allows the significant advantage of being able +// to use a portable interface across Intel and Nvidia backends. + +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +// M, N, K define the sizes of dimensions of the three matrix types (a, b, +// accumulator) used per subgroup operation. +constexpr int M = 16; // number of rows of accumulator, + // number of cols of b. +constexpr int N = 16; // number of cols of accumulator, + // number of rows of a. 
+constexpr int K = 8; // number of cols of a/number of rows of b. + +uint32_t A[M * K]; +uint32_t B[K * N]; +float C[M * N]; +float D[M * N]; + +int main() { + + buffer bufA(A, range<1>(M * K)); + buffer bufB(B, range<1>(K * N)); + buffer bufC(C, range<1>(M * N)); + buffer bufD(D, range<1>(M * N)); + + queue q; + + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + auto accD = bufD.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), N); + //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 8) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), K); + //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.row.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), N); + //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32(i32 %10, i32 %11, i32 %12, i32 %13, i32 %15, i32 %16, i32 %17, i32 %18, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + //CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %20, float %21, float %22, float %23, float %24, float %25, float %26, float %27, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), N); + }); + }); + + 
q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + auto accD = bufD.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), N); + //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 8) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), K); + //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), N); + //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32(i32 %10, i32 %11, i32 %12, i32 %13, i32 %15, i32 %16, i32 %17, i32 %18, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + //CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %20, float %21, float %22, float %23, float %24, float %25, float %26, float %27, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), N); + }); + }); + + return 0; +}; From 35302b5b1d22b11d96dcd1c0021eae7d99ef4ea0 Mon Sep 17 00:00:00 2001 From: "jack.kirk" Date: Wed, 2 Mar 2022 12:37:42 +0000 Subject: [PATCH 2/6] format --- .../ext/oneapi/matrix/matrix-tensorcore.hpp | 15 +++++++------ .../matrix/matrix-nvptx-fp19-test.cpp | 21 ++++++++++++------- 2 files changed, 20 
insertions(+), 16 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index cce65f54c0999..ba23264b87ba4 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -276,16 +276,15 @@ struct joint_matrix_load_impl< get_layout_id()); } } else if constexpr (std::is_same::value) { - int32_t *tileptr = reinterpret_cast(src.get()); - if constexpr (NumRows == 16 && NumCols == 8) { + int32_t *tileptr = reinterpret_cast(src.get()); + if constexpr (NumRows == 16 && NumCols == 8) { __mma_tf32_m16n16k8_ld_a(res.data, tileptr, stride, get_layout_id()); - } - else if constexpr (NumRows == 8 && NumCols == 16) { + } else if constexpr (NumRows == 8 && NumCols == 16) { __mma_tf32_m16n16k8_ld_b(res.data, tileptr, stride, get_layout_id()); - } - } + } + } } }; @@ -509,10 +508,10 @@ struct joint_matrix_mad_impl< get_layout_pair_id(), 0); } } - } else if constexpr (M == 16 && N == 16 && K == 8) { + } else if constexpr (M == 16 && N == 16 && K == 8) { __mma_tf32_m16n16k8_mma_f32(D.data, A.data, B.data, C.data, get_layout_pair_id(), 0); - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { __dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data, get_layout_pair_id(), 0); } diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp index 4ce05be7a8b1a..ac17e37601c80 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp @@ -4,14 +4,19 @@ // IMPORTANT: before updating sm version support beyond sm_86 read the following NOTE! 
-// NOTE: Technically the 'wrong' ptx instruction is called by joint_matrix_load/joint_matrix_store in this case: -// notice that the load and store instructions use shape m16n16k16, rather than the correct shape m16n16k8. -// The 'wrong' ptx instruction is used because it returns the correct SASS instructions for all existing supported sm versions: -// sm_80 and sm_86. The Apparent reason for this ptx instruction redundancy is due to the ptx naming convention for the mnk shape triple; -// however we cannot in principle a priori know that future sm versions will behave in the same way and that this redundancy will remain. -// This should be validated before supporting any sm versions beyond sm_86. -// The reason that we choose to use the m16n16k16 instruction is that it allows the significant advantage of being able -// to use a portable interface across Intel and Nvidia backends. +// NOTE: Technically the 'wrong' ptx instruction is called by +// joint_matrix_load/joint_matrix_store in this case: notice that the load and +// store instructions use shape m16n16k16, rather than the correct shape +// m16n16k8. The 'wrong' ptx instruction is used because it returns the correct +// SASS instructions for all existing supported sm versions: sm_80 and sm_86. +// This ptx instruction redundancy is due to the ptx naming +// convention for the mnk shape triple; however we cannot in principle a priori +// know that future sm versions will behave in the same way and that this +// redundancy will continue as future architecture is released. This should be +// validated before supporting any sm versions beyond sm_86. The reason that we +// choose to use the m16n16k16 instruction is that it allows the significant +// advantage of being able to use a portable interface across Intel and Nvidia +// backends.
#include From 712af980e261c42330603ac2ca410ebaa0d19c75 Mon Sep 17 00:00:00 2001 From: "jack.kirk" Date: Wed, 2 Mar 2022 13:15:57 +0000 Subject: [PATCH 3/6] format --- .../include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp | 6 +++--- .../check_device_code/matrix/matrix-nvptx-fp19-test.cpp | 9 +++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index ba23264b87ba4..012d3eb5006f0 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -81,7 +81,7 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2) __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2) __SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8) - // m16n16k8 fp19 +// m16n16k8 fp19 __SYCL_JOINT_MATRIX_OVERLOAD(uint32_t, a, 16, 8, int32_t, 4) __SYCL_JOINT_MATRIX_OVERLOAD(uint32_t, b, 8, 16, int32_t, 4) @@ -278,10 +278,10 @@ struct joint_matrix_load_impl< } else if constexpr (std::is_same::value) { int32_t *tileptr = reinterpret_cast(src.get()); if constexpr (NumRows == 16 && NumCols == 8) { - __mma_tf32_m16n16k8_ld_a(res.data, tileptr, stride, + __mma_tf32_m16n16k8_ld_a(res.data, tileptr, stride, get_layout_id()); } else if constexpr (NumRows == 8 && NumCols == 16) { - __mma_tf32_m16n16k8_ld_b(res.data, tileptr, stride, + __mma_tf32_m16n16k8_ld_b(res.data, tileptr, stride, get_layout_id()); } } diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp index ac17e37601c80..56deb5c52b6fe 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp @@ -2,7 +2,8 @@ // RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 
-DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s -// IMPORTANT: before updating sm version support beyond sm_86 read the following NOTE! +// IMPORTANT: before updating sm version support beyond sm_86 read the following +// NOTE! // NOTE: Technically the 'wrong' ptx instruction is called by // joint_matrix_load/joint_matrix_store in this case: notice that the load and @@ -26,10 +27,10 @@ using namespace sycl::ext::oneapi::experimental::matrix; // M, N, K define the sizes of dimensions of the three matrix types (a, b, // accumulator) used per subgroup operation. constexpr int M = 16; // number of rows of accumulator, - // number of cols of b. + // number of cols of b. constexpr int N = 16; // number of cols of accumulator, - // number of rows of a. -constexpr int K = 8; // number of cols of a/number of rows of b. + // number of rows of a. +constexpr int K = 8; // number of cols of a/number of rows of b. uint32_t A[M * K]; uint32_t B[K * N]; From 35306433b318befdb407e846be5a7a8eaac2a17a Mon Sep 17 00:00:00 2001 From: "jack.kirk" Date: Wed, 2 Mar 2022 13:46:01 +0000 Subject: [PATCH 4/6] format --- sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp index 56deb5c52b6fe..4a3962c542ee5 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp @@ -66,7 +66,7 @@ int main() { joint_matrix sub_b; - + //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}} joint_matrix_load(sg, sub_c, accC.get_pointer(), N); //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 8) #{{.*}} From 
fa67ff986453a9959ccc244c72911f5d3179bab4 Mon Sep 17 00:00:00 2001 From: "jack.kirk" Date: Thu, 3 Mar 2022 15:48:22 +0000 Subject: [PATCH 5/6] added comment relating uint32_t to fp19 --- sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp index 4a3962c542ee5..c183ae7ba17aa 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp @@ -32,6 +32,7 @@ constexpr int N = 16; // number of cols of accumulator, // number of rows of a. constexpr int K = 8; // number of cols of a/number of rows of b. +// uint32_t is used in this test as the storage type for fp19 uint32_t A[M * K]; uint32_t B[K * N]; float C[M * N]; From bfc68d22e3ef8166aa1d0476a25baf9ce6f59c15 Mon Sep 17 00:00:00 2001 From: "jack.kirk" Date: Thu, 10 Mar 2022 21:34:23 +0000 Subject: [PATCH 6/6] fp19 comments ->tf32 --- sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp | 2 +- .../{matrix-nvptx-fp19-test.cpp => matrix-nvptx-tf32-test.cpp} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename sycl/test/check_device_code/matrix/{matrix-nvptx-fp19-test.cpp => matrix-nvptx-tf32-test.cpp} (99%) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index 012d3eb5006f0..df3b6e97d258e 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -81,7 +81,7 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2) __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2) __SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8) -// m16n16k8 fp19 +// m16n16k8 tf32 __SYCL_JOINT_MATRIX_OVERLOAD(uint32_t, a, 16, 8, int32_t, 4) __SYCL_JOINT_MATRIX_OVERLOAD(uint32_t, b, 8, 
16, int32_t, 4) diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp similarity index 99% rename from sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp rename to sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp index c183ae7ba17aa..709005a45b160 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-fp19-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-tf32-test.cpp @@ -32,7 +32,7 @@ constexpr int N = 16; // number of cols of accumulator, // number of rows of a. constexpr int K = 8; // number of cols of a/number of rows of b. -// uint32_t is used in this test as the storage type for fp19 +// uint32_t is used in this test as the storage type for tf32 uint32_t A[M * K]; uint32_t B[K * N]; float C[M * N];