diff --git a/sycl/include/sycl/accessor.hpp b/sycl/include/sycl/accessor.hpp index c134c5cb0814e..439248a94f80c 100644 --- a/sycl/include/sycl/accessor.hpp +++ b/sycl/include/sycl/accessor.hpp @@ -2124,23 +2124,12 @@ class __SYCL_EBO __SYCL_SPECIAL_CLASS __SYCL_TYPE(accessor) accessor : template > -#if SYCL_LANGUAGE_VERSION >= 202001 - std::add_pointer_t get_pointer() const noexcept -#else - DataT *get_pointer() const -#endif - { + (AccessTarget_ == access::target::host_task) || + (AccessTarget_ == access::target::device)>> + std::add_pointer_t get_pointer() const noexcept { return getPointerAdjusted(); } - template < - access::target AccessTarget_ = AccessTarget, - typename = std::enable_if_t> - global_ptr get_pointer() const { - return global_ptr(getPointerAdjusted()); - } - template > diff --git a/sycl/include/sycl/ext/intel/esimd/detail/util.hpp b/sycl/include/sycl/ext/intel/esimd/detail/util.hpp index 29c37ff65d43f..d34039cbeff10 100644 --- a/sycl/include/sycl/ext/intel/esimd/detail/util.hpp +++ b/sycl/include/sycl/ext/intel/esimd/detail/util.hpp @@ -182,9 +182,15 @@ template class ForHelper { /// Returns the address referenced by the accessor \p Acc and /// the byte offset \p Offset. template -T *accessorToPointer(AccessorTy Acc, OffsetTy Offset = 0) { - auto BytePtr = reinterpret_cast(Acc.get_pointer().get()) + Offset; - return reinterpret_cast(BytePtr); +auto accessorToPointer(AccessorTy Acc, OffsetTy Offset = 0) { + using QualCharPtrType = + std::conditional_t, + const char *, char *>; + using QualTPtrType = + std::conditional_t, + const T *, T *>; + auto BytePtr = reinterpret_cast(Acc.get_pointer()) + Offset; + return reinterpret_cast(BytePtr); } #endif // __ESIMD_FORCE_STATELESS_MEM diff --git a/sycl/include/sycl/ext/intel/experimental/esimd/memory.hpp b/sycl/include/sycl/ext/intel/experimental/esimd/memory.hpp index dad883309b01a..94e50ec8925d6 100644 --- a/sycl/include/sycl/ext/intel/experimental/esimd/memory.hpp +++ b/sycl/include/sycl/ext/intel/experimental/esimd/memory.hpp @@ -776,8 +776,7 @@ __ESIMD_API std::enable_if_t, lsc_gather(AccessorTy acc, __ESIMD_NS::simd offsets, __ESIMD_NS::simd_mask pred = 1) { #ifdef __ESIMD_FORCE_STATELESS_MEM - return lsc_gather(acc.get_pointer().get(), offsets, - pred); + return lsc_gather(acc.get_pointer(), offsets, pred); #else detail::check_lsc_vector_size(); detail::check_lsc_data_size(); @@ -829,8 +828,8 @@ lsc_gather(AccessorTy acc, __ESIMD_NS::simd offsets, __ESIMD_NS::simd_mask pred, __ESIMD_NS::simd old_values) { #ifdef __ESIMD_FORCE_STATELESS_MEM - return lsc_gather(acc.get_pointer().get(), offsets, - pred, old_values); + return lsc_gather(acc.get_pointer(), offsets, pred, + old_values); #else detail::check_lsc_vector_size(); detail::check_lsc_data_size(); diff --git a/sycl/include/sycl/group_algorithm.hpp b/sycl/include/sycl/group_algorithm.hpp index 424a7800d7bb7..367e5e649c642 100644 --- a/sycl/include/sycl/group_algorithm.hpp +++ b/sycl/include/sycl/group_algorithm.hpp @@ -96,7 +96,10 @@ using native_op_list = template struct is_native_op { static constexpr bool value = - is_contained>::value || + is_contained>>::value || + is_contained>>::value || is_contained>::value; }; @@ -123,9 +126,9 @@ struct is_complex // ---- is_arithmetic_or_complex template -using is_arithmetic_or_complex = - std::integral_constant::value || - sycl::detail::is_arithmetic::value>; +using is_arithmetic_or_complex = std::integral_constant< + bool, sycl::detail::is_complex>::value || + sycl::detail::is_arithmetic::value>; template struct is_vector_arithmetic_or_complex diff --git a/sycl/include/sycl/multi_ptr.hpp b/sycl/include/sycl/multi_ptr.hpp index 175dc7e1489a9..0ceb59d41b0c7 100644 --- a/sycl/include/sycl/multi_ptr.hpp +++ b/sycl/include/sycl/multi_ptr.hpp @@ -113,6 +113,16 @@ class multi_ptr { : m_Pointer(ptr) {} multi_ptr(std::nullptr_t) : m_Pointer(nullptr) {} + // Implicit conversion from multi_ptr to multi_ptr + template , + typename = typename std::enable_if_t< + std::is_const_v && + std::is_same_v>>> + explicit multi_ptr( + multi_ptr MPtr) + : m_Pointer(MPtr.get_decorated()) {} + // Only if Space is in // {global_space, ext_intel_global_device_space, generic_space} template < @@ -126,8 +136,7 @@ class multi_ptr { multi_ptr(accessor Accessor) - : multi_ptr( - detail::cast_AS(Accessor.get_pointer().get())) {} + : multi_ptr(Accessor.template get_multi_ptr()) {} // Only if Space == local_space || generic_space template > multi_ptr(local_accessor Accessor) - : m_Pointer(detail::cast_AS(Accessor.get_pointer())) {} + : multi_ptr(Accessor.template get_multi_ptr()) {} // The following constructors are necessary to create multi_ptr from accessor. @@ -177,8 +186,8 @@ class multi_ptr { multi_ptr(accessor, Dimensions, Mode, access::target::device, isPlaceholder, PropertyListT> Accessor) - : multi_ptr( - detail::cast_AS(Accessor.get_pointer().get())) {} + : m_Pointer(Accessor.template get_multi_ptr() + .get_decorated()) {} // Only if Space == local_space || generic_space and element type is const template , Dimensions> Accessor) - : m_Pointer(detail::cast_AS(Accessor.get_pointer())) {} + : multi_ptr(Accessor.template get_multi_ptr()) {} // Assignment and access operators multi_ptr &operator=(const multi_ptr &) = default; @@ -441,8 +450,7 @@ class multi_ptr { multi_ptr(accessor Accessor) - : multi_ptr( - detail::cast_AS(Accessor.get_pointer().get())) {} + : multi_ptr(Accessor.template get_multi_ptr()) {} // Only if Space == local_space template < @@ -463,7 +471,7 @@ class multi_ptr { typename = typename std::enable_if_t< RelaySpace == Space && Space == access::address_space::local_space>> multi_ptr(local_accessor Accessor) - : m_Pointer(detail::cast_AS(Accessor.get_pointer())) {} + : multi_ptr(Accessor.template get_multi_ptr()) {} // Assignment operators multi_ptr &operator=(const multi_ptr &) = default; @@ -567,8 +575,7 @@ class multi_ptr { multi_ptr(accessor Accessor) - : multi_ptr( - detail::cast_AS(Accessor.get_pointer().get())) {} + : multi_ptr(Accessor.template get_multi_ptr()) {} // Only if Space == local_space template < @@ -589,7 +596,7 @@ class multi_ptr { typename = typename std::enable_if_t< RelaySpace == Space && Space == access::address_space::local_space>> multi_ptr(local_accessor Accessor) - : m_Pointer(detail::cast_AS(Accessor.get_pointer())) {} + : multi_ptr(Accessor.template get_multi_ptr()) {} // Assignment operators multi_ptr &operator=(const multi_ptr &) = default; @@ -760,7 +767,7 @@ class multi_ptr { multi_ptr(accessor Accessor) { - m_Pointer = detail::cast_AS(Accessor.get_pointer().get()); + m_Pointer = detail::cast_AS(Accessor.get_pointer()); } // Only if Space == local_space || generic_space diff --git a/sycl/test-e2e/Basic/multi_ptr.hpp b/sycl/test-e2e/Basic/multi_ptr.hpp index 502a1180f4316..ebd8363ecff21 100644 --- a/sycl/test-e2e/Basic/multi_ptr.hpp +++ b/sycl/test-e2e/Basic/multi_ptr.hpp @@ -31,7 +31,7 @@ template struct point { }; template -void innerFunc(id<1> wiID, global_ptr ptr_1, +void innerFunc(id<1> wiID, global_ptr ptr_1, global_ptr ptr_2, global_ptr ptr_3, global_ptr ptr_4, @@ -110,9 +110,8 @@ template void testMultPtr() { private_data[i] = 0; localAccessor[wiID.get_local_id()] = 0; - auto ptr_1 = - multi_ptr( - accessorData_1); + auto ptr_1 = multi_ptr(accessorData_1); auto ptr_2 = multi_ptr( accessorData_2); @@ -136,11 +135,13 @@ template void testMultPtr() { // Construct extension pointer from accessors. auto dev_ptr = - multi_ptr(accessorData_1); - static_assert(std::is_same_v, - decltype(dev_ptr)>, - "Incorrect type for dev_ptr."); + static_assert( + std::is_same_v, + decltype(dev_ptr)>, + "Incorrect type for dev_ptr."); // General conversions in multi_ptr class T *RawPtr = nullptr; @@ -148,7 +149,7 @@ template void testMultPtr() { address_space_cast(RawPtr); - global_ptr ptr_7(accessorData_1); + global_ptr ptr_7(accessorData_1); global_ptr ptr_8 = address_space_cast private_val = 0; auto ptr_1 = - multi_ptr, access::address_space::global_space, + multi_ptr, access::address_space::global_space, IsDecorated>(accessorData_1); auto ptr_2 = multi_ptr, access::address_space::local_space, IsDecorated>(accessorData_2); auto ptr_3 = - multi_ptr, + multi_ptr, access::address_space::ext_intel_global_device_space, IsDecorated>(accessorData_3); auto ptr_4 = diff --git a/sycl/test-e2e/Basic/multi_ptr_legacy.hpp b/sycl/test-e2e/Basic/multi_ptr_legacy.hpp index b6161d3390a61..7303535c1b624 100644 --- a/sycl/test-e2e/Basic/multi_ptr_legacy.hpp +++ b/sycl/test-e2e/Basic/multi_ptr_legacy.hpp @@ -8,7 +8,7 @@ #include #include -#include +#include #include using namespace sycl; @@ -30,7 +30,7 @@ template struct point { }; template -void innerFunc(id<1> wiID, global_ptr ptr_1, global_ptr ptr_2, +void innerFunc(id<1> wiID, global_ptr ptr_1, global_ptr ptr_2, local_ptr local_ptr) { T t = ptr_1[wiID.get(0)]; local_ptr[wiID.get(0)] = t; @@ -64,31 +64,33 @@ template void testMultPtr() { cgh.parallel_for>( nd_range<1>{10, 10}, [=](nd_item<1> wiID) { - auto ptr_1 = make_ptr( - accessorData_1.get_pointer()); + accessorData_1 + .template get_multi_ptr()); auto ptr_2 = make_ptr( - accessorData_2.get_pointer()); + accessorData_2 + .template get_multi_ptr()); auto local_ptr = make_ptr( localAccessor.get_pointer()); // Construct extension pointer from accessors. auto dev_ptr = - multi_ptr( accessorData_1); - static_assert( - std::is_same_v, decltype(dev_ptr)>, - "Incorrect type for dev_ptr."); + static_assert(std::is_same_v, + decltype(dev_ptr)>, + "Incorrect type for dev_ptr."); // General conversions in multi_ptr class T *RawPtr = nullptr; global_ptr ptr_4(RawPtr); ptr_4 = RawPtr; - global_ptr ptr_5(accessorData_1); + global_ptr ptr_5(accessorData_1); global_ptr ptr_6((void *)RawPtr); @@ -144,9 +146,11 @@ template void testMultPtrArrowOperator() { cgh.parallel_for>( sycl::nd_range<1>{1, 1}, [=](sycl::nd_item<1>) { - auto ptr_1 = make_ptr, access::address_space::global_space, - access::decorated::legacy>( - accessorData_1.get_pointer()); + auto ptr_1 = + make_ptr, access::address_space::global_space, + access::decorated::legacy>( + accessorData_1.template get_multi_ptr< + sycl::access::decorated::legacy>()); auto ptr_2 = make_ptr, access::address_space::constant_space, access::decorated::legacy>( @@ -155,7 +159,7 @@ template void testMultPtrArrowOperator() { access::decorated::legacy>( accessorData_3.get_pointer()); auto ptr_4 = - make_ptr, + make_ptr, access::address_space::ext_intel_global_device_space, access::decorated::legacy>( accessorData_4.get_pointer()); diff --git a/sycl/test-e2e/GroupAlgorithm/SYCL2020/all_of.cpp b/sycl/test-e2e/GroupAlgorithm/SYCL2020/all_of.cpp index 0e3cdc9b2686a..9f3e14d15500f 100644 --- a/sycl/test-e2e/GroupAlgorithm/SYCL2020/all_of.cpp +++ b/sycl/test-e2e/GroupAlgorithm/SYCL2020/all_of.cpp @@ -33,7 +33,9 @@ void test(queue q, InputContainer input, OutputContainer output, int lid = it.get_local_id(0); out[0] = all_of_group(g, pred(in[lid])); out[1] = all_of_group(g, in[lid], pred); - out[2] = joint_all_of(g, in.get_pointer(), in.get_pointer() + N, pred); + out[2] = joint_all_of( + g, in.template get_multi_ptr(), + in.template get_multi_ptr() + N, pred); }); }); } diff --git a/sycl/test-e2e/GroupAlgorithm/SYCL2020/any_of.cpp b/sycl/test-e2e/GroupAlgorithm/SYCL2020/any_of.cpp index 08d82bed93b47..5ca2f9802d10b 100644 --- a/sycl/test-e2e/GroupAlgorithm/SYCL2020/any_of.cpp +++ b/sycl/test-e2e/GroupAlgorithm/SYCL2020/any_of.cpp @@ -41,7 +41,9 @@ void test(queue q, InputContainer input, OutputContainer output, int lid = it.get_local_id(0); out[0] = any_of_group(g, pred(in[lid])); out[1] = any_of_group(g, in[lid], pred); - out[2] = joint_any_of(g, in.get_pointer(), in.get_pointer() + N, pred); + out[2] = joint_any_of( + g, in.template get_multi_ptr(), + in.template get_multi_ptr() + N, pred); }); }); } diff --git a/sycl/test-e2e/GroupAlgorithm/SYCL2020/exclusive_scan.cpp b/sycl/test-e2e/GroupAlgorithm/SYCL2020/exclusive_scan.cpp index 2723acbf4616a..fed771d74c208 100644 --- a/sycl/test-e2e/GroupAlgorithm/SYCL2020/exclusive_scan.cpp +++ b/sycl/test-e2e/GroupAlgorithm/SYCL2020/exclusive_scan.cpp @@ -92,8 +92,10 @@ void test(queue q, InputContainer input, OutputContainer output, accessor out{out_buf, cgh, sycl::write_only, sycl::no_init}; cgh.parallel_for(nd_range<1>(G, G), [=](nd_item<1> it) { group<1> g = it.get_group(); - joint_exclusive_scan(g, in.get_pointer(), in.get_pointer() + N, - out.get_pointer(), binary_op); + joint_exclusive_scan( + g, in.template get_multi_ptr(), + in.template get_multi_ptr() + N, + out.template get_multi_ptr(), binary_op); }); }); } @@ -109,8 +111,11 @@ void test(queue q, InputContainer input, OutputContainer output, accessor out{out_buf, cgh, sycl::write_only, sycl::no_init}; cgh.parallel_for(nd_range<1>(G, G), [=](nd_item<1> it) { group<1> g = it.get_group(); - joint_exclusive_scan(g, in.get_pointer(), in.get_pointer() + N, - out.get_pointer(), init, binary_op); + joint_exclusive_scan( + g, in.template get_multi_ptr(), + in.template get_multi_ptr() + N, + out.template get_multi_ptr(), init, + binary_op); }); }); } diff --git a/sycl/test-e2e/GroupAlgorithm/SYCL2020/inclusive_scan.cpp b/sycl/test-e2e/GroupAlgorithm/SYCL2020/inclusive_scan.cpp index bf3a859bad645..b4e4999338ab7 100644 --- a/sycl/test-e2e/GroupAlgorithm/SYCL2020/inclusive_scan.cpp +++ b/sycl/test-e2e/GroupAlgorithm/SYCL2020/inclusive_scan.cpp @@ -92,8 +92,10 @@ void test(queue q, InputContainer input, OutputContainer output, accessor out{out_buf, cgh, sycl::write_only, sycl::no_init}; cgh.parallel_for(nd_range<1>(G, G), [=](nd_item<1> it) { group<1> g = it.get_group(); - joint_inclusive_scan(g, in.get_pointer(), in.get_pointer() + N, - out.get_pointer(), binary_op); + joint_inclusive_scan( + g, in.template get_multi_ptr(), + in.template get_multi_ptr() + N, + out.template get_multi_ptr(), binary_op); }); }); } @@ -109,8 +111,11 @@ void test(queue q, InputContainer input, OutputContainer output, accessor out{out_buf, cgh, sycl::write_only, sycl::no_init}; cgh.parallel_for(nd_range<1>(G, G), [=](nd_item<1> it) { group<1> g = it.get_group(); - joint_inclusive_scan(g, in.get_pointer(), in.get_pointer() + N, - out.get_pointer(), binary_op, init); + joint_inclusive_scan( + g, in.template get_multi_ptr(), + in.template get_multi_ptr() + N, + out.template get_multi_ptr(), binary_op, + init); }); }); } diff --git a/sycl/test-e2e/GroupAlgorithm/SYCL2020/none_of.cpp b/sycl/test-e2e/GroupAlgorithm/SYCL2020/none_of.cpp index 24a07354b981f..5f1f003961716 100644 --- a/sycl/test-e2e/GroupAlgorithm/SYCL2020/none_of.cpp +++ b/sycl/test-e2e/GroupAlgorithm/SYCL2020/none_of.cpp @@ -39,7 +39,9 @@ void test(queue q, InputContainer input, OutputContainer output, int lid = it.get_local_id(0); out[0] = none_of_group(g, pred(in[lid])); out[1] = none_of_group(g, in[lid], pred); - out[2] = joint_none_of(g, in.get_pointer(), in.get_pointer() + N, pred); + out[2] = joint_none_of( + g, in.template get_multi_ptr(), + in.template get_multi_ptr() + N, pred); }); }); } diff --git a/sycl/test-e2e/GroupAlgorithm/SYCL2020/reduce.cpp b/sycl/test-e2e/GroupAlgorithm/SYCL2020/reduce.cpp index 28cfb9d796025..bfdd61f2448b6 100644 --- a/sycl/test-e2e/GroupAlgorithm/SYCL2020/reduce.cpp +++ b/sycl/test-e2e/GroupAlgorithm/SYCL2020/reduce.cpp @@ -37,10 +37,14 @@ void test(queue q, InputContainer input, OutputContainer output, int lid = it.get_local_id(0); out[0] = reduce_over_group(g, in[lid], binary_op); out[1] = reduce_over_group(g, in[lid], init, binary_op); - out[2] = joint_reduce(g, in.get_pointer(), in.get_pointer() + N, - binary_op); - out[3] = joint_reduce(g, in.get_pointer(), in.get_pointer() + N, - init, binary_op); + out[2] = joint_reduce( + g, in.template get_multi_ptr(), + in.template get_multi_ptr() + N, + binary_op); + out[3] = joint_reduce( + g, in.template get_multi_ptr(), + in.template get_multi_ptr() + N, init, + binary_op); }); }); } diff --git a/sycl/test-e2e/GroupAlgorithm/exclusive_scan_sycl2020.cpp b/sycl/test-e2e/GroupAlgorithm/exclusive_scan_sycl2020.cpp index 3da08c1c40f71..187b4c820f1c0 100644 --- a/sycl/test-e2e/GroupAlgorithm/exclusive_scan_sycl2020.cpp +++ b/sycl/test-e2e/GroupAlgorithm/exclusive_scan_sycl2020.cpp @@ -108,8 +108,10 @@ void test(queue q, InputContainer input, OutputContainer output, accessor out{out_buf, cgh, sycl::write_only, sycl::no_init}; cgh.parallel_for(nd_range<1>(G, G), [=](nd_item<1> it) { group<1> g = it.get_group(); - joint_exclusive_scan(g, in.get_pointer(), in.get_pointer() + N, - out.get_pointer(), binary_op); + joint_exclusive_scan( + g, in.template get_multi_ptr(), + in.template get_multi_ptr() + N, + out.template get_multi_ptr(), binary_op); }); }); } @@ -133,8 +135,11 @@ void test(queue q, InputContainer input, OutputContainer output, accessor out{out_buf, cgh, sycl::write_only, sycl::no_init}; cgh.parallel_for(nd_range<1>(G, G), [=](nd_item<1> it) { group<1> g = it.get_group(); - joint_exclusive_scan(g, in.get_pointer(), in.get_pointer() + N, - out.get_pointer(), init, binary_op); + joint_exclusive_scan( + g, in.template get_multi_ptr(), + in.template get_multi_ptr() + N, + out.template get_multi_ptr(), init, + binary_op); }); }); } diff --git a/sycl/test-e2e/GroupAlgorithm/inclusive_scan_sycl2020.cpp b/sycl/test-e2e/GroupAlgorithm/inclusive_scan_sycl2020.cpp index 065aaabcd4b9d..c75012dea2492 100644 --- a/sycl/test-e2e/GroupAlgorithm/inclusive_scan_sycl2020.cpp +++ b/sycl/test-e2e/GroupAlgorithm/inclusive_scan_sycl2020.cpp @@ -108,8 +108,9 @@ void test(queue q, InputContainer input, OutputContainer output, accessor out{out_buf, cgh, sycl::write_only, sycl::no_init}; cgh.parallel_for(nd_range<1>(G, G), [=](nd_item<1> it) { group<1> g = it.get_group(); - joint_inclusive_scan(g, in.get_pointer(), in.get_pointer() + N, - out.get_pointer(), binary_op); + joint_inclusive_scan( + g, global_ptr(in), global_ptr(in) + N, + out.template get_multi_ptr(), binary_op); }); }); } @@ -133,8 +134,10 @@ void test(queue q, InputContainer input, OutputContainer output, accessor out{out_buf, cgh, sycl::write_only, sycl::no_init}; cgh.parallel_for(nd_range<1>(G, G), [=](nd_item<1> it) { group<1> g = it.get_group(); - joint_inclusive_scan(g, in.get_pointer(), in.get_pointer() + N, - out.get_pointer(), binary_op, init); + joint_inclusive_scan( + g, global_ptr(in), global_ptr(in) + N, + out.template get_multi_ptr(), binary_op, + init); }); }); } diff --git a/sycl/test-e2e/GroupAlgorithm/reduce_sycl2020.cpp b/sycl/test-e2e/GroupAlgorithm/reduce_sycl2020.cpp index 2baa0b3a12dd5..5e91c050c78e3 100644 --- a/sycl/test-e2e/GroupAlgorithm/reduce_sycl2020.cpp +++ b/sycl/test-e2e/GroupAlgorithm/reduce_sycl2020.cpp @@ -36,14 +36,16 @@ void test(queue q, InputContainer input, OutputContainer output, int lid = it.get_local_id(0); out[0] = reduce_over_group(g, in[lid], binary_op); out[1] = reduce_over_group(g, in[lid], init, binary_op); - out[2] = joint_reduce(g, in.get_pointer(), in.get_pointer() + N, - binary_op); - out[3] = joint_reduce(g, in.get_pointer(), in.get_pointer() + N, - init, binary_op); - out[4] = joint_reduce(sg, in.get_pointer(), in.get_pointer() + N, - binary_op); - out[5] = joint_reduce(sg, in.get_pointer(), in.get_pointer() + N, - init, binary_op); + out[2] = joint_reduce(g, global_ptr(in), + global_ptr(in) + N, binary_op); + out[3] = + joint_reduce(g, global_ptr(in), + global_ptr(in) + N, init, binary_op); + out[4] = joint_reduce(sg, global_ptr(in), + global_ptr(in) + N, binary_op); + out[5] = + joint_reduce(sg, global_ptr(in), + global_ptr(in) + N, init, binary_op); }); }); } diff --git a/sycl/test-e2e/KernelFusion/internalize_vfunc.cpp b/sycl/test-e2e/KernelFusion/internalize_vfunc.cpp index 0c78fa91cc2bb..9f598740e0577 100644 --- a/sycl/test-e2e/KernelFusion/internalize_vfunc.cpp +++ b/sycl/test-e2e/KernelFusion/internalize_vfunc.cpp @@ -46,11 +46,14 @@ int main() { cgh.parallel_for(numVec, [=](id<1> i) { size_t offset = i; vec in1; - in1.load(offset, accIn1.get_pointer()); + in1.load(offset, + accIn1.template get_multi_ptr()); vec in2; - in2.load(offset, accIn2.get_pointer()); + in2.load(offset, + accIn2.template get_multi_ptr()); auto tmp = in1 + in2; - tmp.store(offset, accTmp.get_pointer()); + tmp.store(offset, + accTmp.template get_multi_ptr()); }); }); @@ -62,11 +65,14 @@ int main() { cgh.parallel_for(numVec, [=](id<1> i) { size_t offset = i; vec tmp; - tmp.load(offset, accTmp.get_pointer()); + tmp.load(offset, + accTmp.template get_multi_ptr()); vec in3; - in3.load(offset, accIn3.get_pointer()); + in3.load(offset, + accIn3.template get_multi_ptr()); auto out = tmp * in3; - out.store(offset, accOut.get_pointer()); + out.store(offset, + accOut.template get_multi_ptr()); }); }); diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_bf16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_bf16_impl.hpp index ec00fee588c67..095df9842b487 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_bf16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_bf16_impl.hpp @@ -60,10 +60,11 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + make_bf16(2); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(read_only), ref); @@ -93,10 +94,11 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - make_bf16(2); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(read_only), ref); @@ -126,10 +128,11 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * make_bf16(3.0); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(read_only), ref); @@ -159,10 +162,11 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / make_bf16(2.0); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(read_only), ref); @@ -210,10 +214,11 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, } } } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(read_only), ref); diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_half_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_half_impl.hpp index 62dc0bf55b359..fb1ed03757c4b 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_half_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_half_impl.hpp @@ -46,10 +46,11 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + static_cast(2); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(read_only), ref); @@ -79,10 +80,11 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - static_cast(2); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(read_only), ref); @@ -112,10 +114,11 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * static_cast(3.0); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(read_only), ref); @@ -145,10 +148,11 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / static_cast(2.0); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(read_only), ref); @@ -197,10 +201,11 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, } } } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(read_only), ref); diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_int8_impl.hpp index d9e7871451165..0694f8ab737a5 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_int8_impl.hpp @@ -45,10 +45,11 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + 2; } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(read_only), ref); @@ -78,10 +79,11 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - 2; } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(read_only), ref); @@ -111,10 +113,11 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * 3; } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(read_only), ref); @@ -144,10 +147,11 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / 2; } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(read_only), ref); @@ -192,10 +196,11 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, } } } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(read_only), ref); diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_int8_packed_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_int8_packed_impl.hpp index c93aa34d57f1e..6171b64aa99d7 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_int8_packed_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_all_ops_int8_packed_impl.hpp @@ -45,10 +45,11 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] + 2; } - joint_matrix_store(sg, sub_b, - accA.get_pointer() + (sg_startx * TM) * N * 4 + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::row_major); + joint_matrix_store( + sg, sub_b, + accA.template get_multi_ptr() + + (sg_startx * TM) * N * 4 + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufB.get_host_access(read_only), ref); @@ -78,10 +79,11 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] - 2; } - joint_matrix_store(sg, sub_b, - accA.get_pointer() + (sg_startx * TM) * N * 4 + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::row_major); + joint_matrix_store( + sg, sub_b, + accA.template get_multi_ptr() + + (sg_startx * TM) * N * 4 + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufB.get_host_access(read_only), ref); @@ -111,10 +113,11 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] * 3; } - joint_matrix_store(sg, sub_b, - accA.get_pointer() + (sg_startx * TM) * N * 4 + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::row_major); + joint_matrix_store( + sg, sub_b, + accA.template get_multi_ptr() + + (sg_startx * TM) * N * 4 + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufB.get_host_access(read_only), ref); @@ -144,10 +147,11 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] / 2; } - joint_matrix_store(sg, sub_b, - accA.get_pointer() + (sg_startx * TM) * N * 4 + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::row_major); + joint_matrix_store( + sg, sub_b, + accA.template get_multi_ptr() + + (sg_startx * TM) * N * 4 + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufB.get_host_access(read_only), ref); @@ -192,10 +196,11 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, } } } - joint_matrix_store(sg, sub_b, - accA.get_pointer() + (sg_startx * TM) * N * 4 + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::row_major); + joint_matrix_store( + sg, sub_b, + accA.template get_multi_ptr() + + (sg_startx * TM) * N * 4 + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::row_major); }); // parallel for }).wait(); assert_ops_ref(bufB.get_host_access(read_only), ref); diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp index 3c34ca2d40f8d..1d1fbb9d01dbe 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp @@ -47,10 +47,11 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { joint_matrix sub_b(sg); - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (global_idx * (TK / 4) * N) + - sg_starty / SG_SZ * TN * 4, - N, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (global_idx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4, + N, matrix_layout::packed_b); // calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_b // (tK/4) int32_t sum_local_rows[M] = {0}; // 8 local rows, M total diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp index 25703224bf3a6..e38e8521437b9 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp @@ -59,29 +59,34 @@ void matrix_multiply(big_matrix &C, // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] *= 2; } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp index d37911400318d..0f57377c571ac 100644 --- a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp @@ -83,29 +83,34 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_b( sg); joint_matrix sub_c(sg); - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K; k += TK) { - joint_matrix_load(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * K + k, K, - matrix_layout::row_major); + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k, + K, matrix_layout::row_major); // Assume we alreay in vnni format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k) * (N) + - sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k) * (N) + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] += 5.0; } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp index 36052ec1a015a..9868aef0d92e2 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp @@ -59,25 +59,30 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_b( sg); joint_matrix sub_c(sg); - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K; k += TK) { - joint_matrix_load(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * K + k, K, - matrix_layout::row_major); + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k, + K, matrix_layout::row_major); // Assume we alreay in vnni format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k) * (N) + - sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k) * (N) + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp index 8045e038b1754..e777013205bf4 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp @@ -51,25 +51,30 @@ void matrix_multiply(big_matrix &C, big_matrix &A, joint_matrix sub_b(sg); joint_matrix sub_c(sg); - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K / TK; k += 1) { // joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 2) * (N * 2) + - sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index b36cfaa29ddd6..bf08b7431fab0 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -46,24 +46,29 @@ void matrix_multiply(big_matrix &C, big_matrix &A, joint_matrix sub_b(sg); joint_matrix sub_c(sg); - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K / TK; k += 1) { // joint_matrix_load( - sg, sub_a, accA.get_pointer() + (k * TK) * M + sg_startx * TM, + sg, sub_a, + accA.template get_multi_ptr() + + (k * TK) * M + sg_startx * TM, M, matrix_layout::col_major); - joint_matrix_load(sg, sub_b, - accB.get_pointer() + - (sg_starty / SG_SZ * TN) * K + k * TK, - K, matrix_layout::col_major); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (sg_starty / SG_SZ * TN) * K + k * TK, + K, matrix_layout::col_major); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp index e4476e33ec08a..56f89beaf8449 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp @@ -51,25 +51,30 @@ void matrix_multiply(big_matrix &C, big_matrix &A, joint_matrix sub_b(sg); joint_matrix sub_c(sg); - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K / TK; k += 1) { // joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 2) * (N * 2) + - sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index eafa912d3fb77..7e64f68f778de 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -46,24 +46,29 @@ void matrix_multiply(big_matrix &C, big_matrix &A, joint_matrix sub_b(sg); joint_matrix sub_c(sg); - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK) * (N) + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK) * (N) + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp index 7d1001de112c6..b6d49b453665f 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp @@ -56,25 +56,30 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_b(sg); joint_matrix sub_c(sg); - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 2) * (N * 2) + - sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index 2bb4352de7cb2..5fc9e3c2dfd02 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -63,18 +63,22 @@ void matrix_multiply(big_matrix &C, joint_matrix_fill(sg, sub_c, 0); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (k * TK) * M + sg_startx * TM, + sg, sub_a, + accA.template get_multi_ptr() + + (k * TK) * M + sg_startx * TM, M, matrix_layout::col_major); - joint_matrix_load(sg, sub_b, - accB.get_pointer() + - (sg_starty / SG_SZ * TN) * K + k * TK, - K, matrix_layout::col_major); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (sg_starty / SG_SZ * TN) * K + k * TK, + K, matrix_layout::col_major); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp index bf1582759778b..8189ce6f11205 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp @@ -55,19 +55,23 @@ void matrix_multiply(big_matrix &C, joint_matrix_fill(sg, sub_c, 0); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // VNNI transform is done automatically at this level - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp index cb4ceacb0f0c3..4ac252224e520 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp @@ -80,25 +80,30 @@ void matrix_multiply(big_matrix &C, myparams2::joint_matrix_b sub_b(sg); myparams2::joint_matrix_c sub_c(sg); - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp index d83332bfb44d6..cd9089d6de105 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp @@ -60,19 +60,23 @@ void matrix_multiply(big_matrix &C, joint_matrix_fill(sg, sub_c, 0); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp index 5fa2818533eb7..2655f68a2b4f9 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp @@ -59,25 +59,30 @@ void matrix_multiply(big_matrix &C, // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp index d0949f0d866f0..9b3b518c4fbe8 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp @@ -61,25 +61,30 @@ void matrix_multiply(big_matrix &C, // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp index 2f6e8ef6a55e5..4f3a4968797bd 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp @@ -59,25 +59,30 @@ void matrix_multiply(big_matrix &C, // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_bf16_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_bf16_impl.hpp index 33bcb5286abe0..ef0dec5cedbed 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_bf16_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_bf16_impl.hpp @@ -58,8 +58,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -93,8 +93,8 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -127,8 +127,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -162,8 +162,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -215,8 +215,8 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp index 96031d7b0f887..1f8d927d06a17 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp @@ -49,8 +49,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -84,8 +84,8 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -119,8 +119,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -154,8 +154,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -208,8 +208,8 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp index 45f91d4c67334..3feae0260fd8a 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp @@ -48,8 +48,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -83,8 +83,8 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -118,8 +118,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -153,8 +153,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -203,8 +203,9 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp index 515ddf7b9677f..3f04866fb24ba 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp @@ -50,8 +50,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, - accA.get_pointer() + (sg_startx * TM) * N * 4 + - sg_starty / SG_SZ * TN * 4, + accA.template get_multi_ptr() + + (sg_startx * TM) * N * 4 + sg_starty / SG_SZ * TN * 4, N * 4); }); // parallel for }).wait(); @@ -87,8 +87,8 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, - accA.get_pointer() + (sg_startx * TM) * N * 4 + - sg_starty / SG_SZ * TN * 4, + accA.template get_multi_ptr() + + (sg_startx * TM) * N * 4 + sg_starty / SG_SZ * TN * 4, N * 4); }); // parallel for }).wait(); @@ -124,8 +124,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, - accA.get_pointer() + (sg_startx * TM) * N * 4 + - sg_starty / SG_SZ * TN * 4, + accA.template get_multi_ptr() + + (sg_startx * TM) * N * 4 + sg_starty / SG_SZ * TN * 4, N * 4); }); // parallel for }).wait(); @@ -161,8 +161,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, - accA.get_pointer() + (sg_startx * TM) * N * 4 + - sg_starty / SG_SZ * TN * 4, + accA.template get_multi_ptr() + + (sg_startx * TM) * N * 4 + sg_starty / SG_SZ * TN * 4, N * 4); }); // parallel for }).wait(); @@ -213,8 +213,9 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, - accA.get_pointer() + (sg_startx * TM) * N * 4 + - sg_starty / SG_SZ * TN * 4, + accA.template get_multi_ptr() + multi_ptr() + + (sg_startx * TM) * N * 4 + sg_starty / SG_SZ * TN * 4, N * 4); }); // parallel for }).wait(); diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp index 77a84c6533b23..d7d1d7a747c1e 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp @@ -51,8 +51,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -86,8 +86,8 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -120,8 +120,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -155,8 +155,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); @@ -205,8 +205,8 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp index 7062d5dadc037..5dd2e1e4807f8 100644 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp @@ -49,10 +49,11 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { ext::intel::experimental::matrix::layout::packed> sub_b; - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (global_idx * (TK / 4) * N) + - sg_starty / SG_SZ * TN * 4, - N); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (global_idx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4, + N); // calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_b // (tK/4) int32_t sum_local_rows[M] = {0}; // 8 local rows, M total diff --git a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp index abf23d5b3719f..67f32c967f789 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp @@ -57,18 +57,22 @@ void matrix_multiply(big_matrix &C, sub_b; joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K); - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_slice_c = @@ -76,10 +80,11 @@ void matrix_multiply(big_matrix &C, for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] *= 2; } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp index 3355512a0ecac..8e2865de207b4 100644 --- a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp @@ -84,18 +84,23 @@ void matrix_multiply(big_matrix &C, ext::intel::experimental::matrix::layout::packed> sub_b; joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); for (int k = 0; k < K; k += TK) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k, K); + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k, + K); // Assume we alreay in vnni format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k) * (N) + - sg_starty / SG_SZ * TN * 2, - N * 2); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k) * (N) + sg_starty / SG_SZ * TN * 2, + N * 2); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_slice_c = @@ -103,10 +108,11 @@ void matrix_multiply(big_matrix &C, for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] += 5.0; } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/get_coord_bf16_gemm_impl.hpp b/sycl/test-e2e/Matrix/get_coord_bf16_gemm_impl.hpp index 0d72dd6a04679..c154ff9f7d36a 100644 --- a/sycl/test-e2e/Matrix/get_coord_bf16_gemm_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_bf16_gemm_impl.hpp @@ -77,24 +77,29 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_b; joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { // joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K); - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 2) * (N * 2) + - sg_starty / SG_SZ * TN * 2, - N * 2); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, + N * 2); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); float sum_local_rows[M] = {0}; // 8 local rows, M total auto data = diff --git a/sycl/test-e2e/Matrix/get_coord_bf16_matA_impl.hpp b/sycl/test-e2e/Matrix/get_coord_bf16_matA_impl.hpp index 3d10598fa27a1..e1a0b2411973c 100644 --- a/sycl/test-e2e/Matrix/get_coord_bf16_matA_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_bf16_matA_impl.hpp @@ -117,7 +117,7 @@ void matrix_sum_rows(queue q, big_matrix &A, nd_range<2> &r) { sub_a; joint_matrix_load( - sg, sub_a, accA.get_pointer() + (global_idx * TM * K) + TK, + sg, sub_a, accA.template get_multi_ptr() + (global_idx * TM * K) + TK, K); // calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_a diff --git a/sycl/test-e2e/Matrix/get_coord_bf16_matB_impl.hpp b/sycl/test-e2e/Matrix/get_coord_bf16_matB_impl.hpp index 8326b986b1f05..76a8968239ced 100644 --- a/sycl/test-e2e/Matrix/get_coord_bf16_matB_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_bf16_matB_impl.hpp @@ -139,10 +139,11 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { ext::intel::experimental::matrix::layout::packed> sub_b; - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (global_idx * (TK / 4) * N) + - sg_starty / SG_SZ * TN * 4, - N); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (global_idx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4, + N); int32_t sum_local_cols[N] = {0}; // 4 local cols, N total // sub_b has 32x16 elements, 32 elements per WI, 4 per WI per row diff --git a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp index 578cea797c6ef..a64a51da5182d 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp @@ -61,25 +61,30 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_b; joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K); - joint_matrix_load(sg, sub_b, - accB.get_pointer() + - (k * TK / vnniFactor) * (N * vnniFactor) + - sg_starty / SG_SZ * TN * vnniFactor, - N * vnniFactor); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / vnniFactor) * (N * vnniFactor) + + sg_starty / SG_SZ * TN * vnniFactor, + N * vnniFactor); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_apply_bf16_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_apply_bf16_impl.hpp index 50752ad774e9f..ceb90ff03b77f 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_apply_bf16_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_apply_bf16_impl.hpp @@ -49,7 +49,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + accA.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N); }); // parallel for }).wait(); diff --git a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp index cb7efd95d8532..4303239cefe32 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp @@ -73,11 +73,11 @@ void matrix_verify_lambda(queue q, sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - joint_matrix_store(sg, sub_c, - accC.get_pointer() + - (sg_startx * M) * (N * nWGperDim) + - sg_starty / SG_SZ * N, - (N * nWGperDim), layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, + (N * nWGperDim), layout::row_major); }); // parallel for }); } @@ -156,11 +156,11 @@ void matrix_verify_op(queue q, big_matrix &C, sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - joint_matrix_store(sg, sub_c, - accC.get_pointer() + - (sg_startx * M) * (N * nWGperDim) + - sg_starty / SG_SZ * N, - (N * nWGperDim), layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, + (N * nWGperDim), layout::row_major); }); // parallel for }) .wait(); diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp index 4b689f31c799e..cc0196660744a 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp @@ -51,25 +51,30 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_b; joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { // joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 2) * (N * 2) + - sg_starty / SG_SZ * TN * 2, - N * 2); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, + N * 2); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index fcaed2239e642..d34dd5bc2d71e 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -49,23 +49,29 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_b; joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (k * TK) * M + sg_startx * TM, + sg, sub_a, + accA.template get_multi_ptr() + + (k * TK) * M + sg_startx * TM, M); joint_matrix_load( sg, sub_b, - accB.get_pointer() + (sg_starty / SG_SZ * TN) * K + k * TK, K); + accB.template get_multi_ptr() + + (sg_starty / SG_SZ * TN) * K + k * TK, + K); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp index 9f1e615e33ce4..4dfb4b929041c 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp @@ -51,24 +51,29 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_b; joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { // joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K); - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 2) * (N * 2) + - sg_starty / SG_SZ * TN * 2, - N * 2); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, + N * 2); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index 6610970ee4e30..48eebf4dc749c 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -49,24 +49,29 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_b; joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K); - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK) * (N) + - sg_starty / SG_SZ * TN, - N); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK) * (N) + sg_starty / SG_SZ * TN, + N); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp index 425038705c917..a19263c403226 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp @@ -145,18 +145,24 @@ void test(queue &q) { M, N> sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (m * M) * Big_N + n * N, - Big_N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (m * M) * Big_N + n * N, + Big_N, layout::row_major); // k = row/col id of current submatrix of BIG A/B matrices for (int k = 0; k < Sub_Tiles_K; k++) { - joint_matrix_load(sg, sub_a, - accA.get_pointer() + (k * K) + (m * M * Big_K), - Big_K); - - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * K * Big_N) + (n * N), - Big_N); + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (k * K) + (m * M * Big_K), + Big_K); + + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * K * Big_N) + (n * N), + Big_N); // round values to correct precision if using tf32 if constexpr (std::is_same::value) { @@ -172,9 +178,11 @@ void test(queue &q) { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accD.get_pointer() + (m * M) * Big_N + n * N, - Big_N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accD.template get_multi_ptr() + + (m * M) * Big_N + n * N, + Big_N, layout::row_major); }); }); q.wait(); diff --git a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp index bada4fdc7c09f..453d217a6a61d 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp @@ -56,24 +56,29 @@ void matrix_multiply(big_matrix &C, sub_b; joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * K + k * TK, K); - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 2) * (N * 2) + - sg_starty / SG_SZ * TN * 2, - N * 2); + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, + N * 2); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index d9e01baba219b..51c3de6b25bc8 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -56,17 +56,22 @@ void matrix_multiply(big_matrix &C, joint_matrix_fill(sg, sub_c, 0); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (k * TK) * M + sg_startx * TM, + sg, sub_a, + accA.template get_multi_ptr() + + (k * TK) * M + sg_startx * TM, M); joint_matrix_load( sg, sub_b, - accB.get_pointer() + (sg_starty / SG_SZ * TN) * K + k * TK, K); + accB.template get_multi_ptr() + + (sg_starty / SG_SZ * TN) * K + k * TK, + K); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp index e8973ba154efb..9bf1636074f09 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp @@ -57,18 +57,23 @@ void matrix_multiply(big_matrix &C, joint_matrix_fill(sg, sub_c, 0); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K); // VNNI transform is done automatically at this level joint_matrix_load( sg, sub_b, - accB.get_pointer() + (k * TK) * N + sg_starty / SG_SZ * TN, N); + accB.template get_multi_ptr() + + (k * TK) * N + sg_starty / SG_SZ * TN, + N); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp index 55dd36e019854..5e4c8250b3b4b 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp @@ -82,25 +82,30 @@ void matrix_multiply(big_matrix &C, sub_b; myparams2::joint_matrix_accumulator sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp index 172789f085e02..4042cc1730d7f 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp @@ -60,18 +60,22 @@ void matrix_multiply(big_matrix &C, joint_matrix_fill(sg, sub_c, 0); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K); - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp index 2b6cb314f1676..faeb2ca7b12b1 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp @@ -57,24 +57,29 @@ void matrix_multiply(big_matrix &C, sub_b; joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K); - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp index 0c6787ecd0ba9..cd4b5d05672ae 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp @@ -56,17 +56,23 @@ void matrix_multiply(big_matrix &C, layout::row_major> sub_b; joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); joint_matrix_fill(sg, sub_a, 42); for (int k = 0; k < K; k += TK) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k, K); + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k, + K); joint_matrix_load( sg, sub_b, - accB.get_pointer() + (k) * (N) + sg_starty / SG_SZ * TN, N); + accB.template get_multi_ptr() + + (k) * (N) + sg_starty / SG_SZ * TN, + N); // If no rounding to tf32 function is called, joint_matrix_mad // function will work on truncated floats. joint_matrix_apply(sg, sub_a, @@ -81,10 +87,11 @@ void matrix_multiply(big_matrix &C, auto wi_slice_a = sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); joint_matrix_apply(sg, sub_a, [=](float x) { x *= 2; }); - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp index 694787f408a5e..5eb63fac8075d 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp @@ -59,25 +59,30 @@ void matrix_multiply(big_matrix &C, sub_b; joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp index ce42d45d0107c..62bad8422833e 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp @@ -57,25 +57,30 @@ void matrix_multiply(big_matrix &C, sub_b; joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/UserDefinedReductions/user_defined_reductions.cpp b/sycl/test-e2e/UserDefinedReductions/user_defined_reductions.cpp index 05258442b5ec0..987abff248e95 100644 --- a/sycl/test-e2e/UserDefinedReductions/user_defined_reductions.cpp +++ b/sycl/test-e2e/UserDefinedReductions/user_defined_reductions.cpp @@ -122,8 +122,9 @@ void test(queue q, InputContainer input, OutputContainer output, sycl::ext::oneapi::experimental::group_with_scratchpad( it.get_group(), sycl::span(&scratch[0], temp_memory_size)); - InputT *first = in.get_pointer(); - InputT *last = first + N; + const InputT *first = + in.template get_multi_ptr(); + const InputT *last = first + N; // check reduce_over_group w/o init out[0] = sycl::ext::oneapi::experimental::reduce_over_group( handle, in[it.get_global_id(0)], binary_op); diff --git a/sycl/test-e2e/UserDefinedReductions/user_defined_reductions_wg_size_larger_than_data_size.cpp b/sycl/test-e2e/UserDefinedReductions/user_defined_reductions_wg_size_larger_than_data_size.cpp index 7edeeab387f50..c3da25a7edecf 100644 --- a/sycl/test-e2e/UserDefinedReductions/user_defined_reductions_wg_size_larger_than_data_size.cpp +++ b/sycl/test-e2e/UserDefinedReductions/user_defined_reductions_wg_size_larger_than_data_size.cpp @@ -45,8 +45,11 @@ void test(queue q, InputContainer input, OutputContainer output, auto scratch = sycl::local_accessor(temp_memory_size, cgh); cgh.parallel_for( nd_range<1>(workgroup_size, workgroup_size), [=](nd_item<1> it) { - InputT *segment_begin = in.get_pointer(); - InputT *segment_end = in.get_pointer() + segment_size; + const InputT *segment_begin = + in.template get_multi_ptr(); + const InputT *segment_end = + in.template get_multi_ptr() + + segment_size; auto handle = sycl::ext::oneapi::experimental::group_with_scratchpad( it.get_group(), sycl::span(&scratch[0], temp_memory_size)); diff --git a/sycl/test-e2e/XPTI/kernel/content.cpp b/sycl/test-e2e/XPTI/kernel/content.cpp index 94a04d464566a..7003c2440f256 100644 --- a/sycl/test-e2e/XPTI/kernel/content.cpp +++ b/sycl/test-e2e/XPTI/kernel/content.cpp @@ -50,9 +50,12 @@ int main() { cgh.parallel_for( nd_range<3>({128, 4, 2}, {32, 2, 1}, {16, 1, 0}), [=](nd_item<3> it) { auto sg = it.get_sub_group(); - joint_exclusive_scan(sg, in.get_pointer(), - in.get_pointer() + sg.get_local_id(), - out.get_pointer(), std::plus<>{}); + joint_exclusive_scan( + sg, in.template get_multi_ptr(), + in.template get_multi_ptr() + + sg.get_local_id(), + out.template get_multi_ptr(), + std::plus<>{}); }); }); } diff --git a/sycl/test/basic_tests/accessor/accessor_get_pointer.cpp b/sycl/test/basic_tests/accessor/accessor_get_pointer.cpp index acb63f9ea33d8..ec24a3444b39b 100644 --- a/sycl/test/basic_tests/accessor/accessor_get_pointer.cpp +++ b/sycl/test/basic_tests/accessor/accessor_get_pointer.cpp @@ -12,16 +12,23 @@ void test_get_multi_ptr(handler &cgh, buffer &buffer) { using target_local_accessor_t = accessor; using local_accessor_t = local_accessor; + using accessor_t = + accessor; auto acc = buffer.get_access(cgh); auto target_local_acc = target_local_accessor_t({size}, cgh); auto local_acc = local_accessor_t({size}, cgh); + auto device_acc = + buffer.get_access(cgh); auto acc_ptr = acc.get_pointer(); auto target_local_ptr = target_local_acc.get_pointer(); auto local_pointer = local_acc.get_pointer(); + auto device_acc_ptr = device_acc.get_pointer(); static_assert(std::is_same_v>); static_assert(std::is_same_v>); static_assert( std::is_same_v>); -} \ No newline at end of file + static_assert( + std::is_same_v>); +} diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp index d6d59135e6b56..f7da4ab4f83cd 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp @@ -47,21 +47,27 @@ int main() { // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::row_major); // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.bf16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.bf16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.bf16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.bf16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float {{.*}} // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::row_major); }); cgh.parallel_for( @@ -77,21 +83,27 @@ int main() { // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::col_major); // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.bf16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.bf16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.bf16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.bf16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float {{.*}} // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::col_major); }); cgh.parallel_for( @@ -107,21 +119,27 @@ int main() { // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::row_major); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.bf16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.bf16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.bf16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.bf16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float {{.*}} // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::row_major); }); cgh.parallel_for( @@ -137,21 +155,27 @@ int main() { // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::col_major); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.bf16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.bf16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.bf16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.bf16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float {{.*}} // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::col_major); }); cgh.parallel_for( @@ -167,21 +191,27 @@ int main() { // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::row_major); // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.bf16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.bf16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.bf16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.bf16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float {{.*}} // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::row_major); }); cgh.parallel_for( @@ -197,21 +227,27 @@ int main() { // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::col_major); // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.bf16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.bf16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.bf16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.bf16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float {{.*}} // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::col_major); }); }); diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp index 203642a7fc674..13094c3d78484 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp @@ -57,21 +57,27 @@ int main() { //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1f64(double addrspace(1)* %{{.*}}, i32 8) //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8) - joint_matrix_load(sg, sub_c, accC.get_pointer(), N, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + N, layout::row_major); //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1f64(double addrspace(1)* %{{.*}}, i32 4) //CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 4) - joint_matrix_load(sg, sub_a, accA.get_pointer(), K); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + K); //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1f64(double addrspace(1)* %{{.*}}, i32 8) //CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8) - joint_matrix_load(sg, sub_b, accB.get_pointer(), N); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + N); //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double %{{.*}}, double %{{.*}}, double %{{.*}}, double {{.*}} //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1f64(double addrspace(1)* %{{.*}}, double %{{.*}}, double %{{.*}}, i32 8) //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) - joint_matrix_store(sg, sub_c, accD.get_pointer(), N, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + N, layout::row_major); }); cgh.parallel_for( @@ -87,21 +93,27 @@ int main() { //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1f64(double addrspace(1)* %{{.*}}, i32 8) //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8) - joint_matrix_load(sg, sub_c, accC.get_pointer(), M, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + M, layout::col_major); //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1f64(double addrspace(1)* %{{.*}}, i32 8) //CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 8) - joint_matrix_load(sg, sub_a, accA.get_pointer(), M); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + M); //CHECK: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1f64(double addrspace(1)* %{{.*}}, i32 4) //CHECK-OPAQUE: tail call double @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, i32 4) - joint_matrix_load(sg, sub_b, accB.get_pointer(), K); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + K); //CHECK: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double %{{.*}}, double %{{.*}}, double %{{.*}}, double {{.*}} //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1f64(double addrspace(1)* %{{.*}}, double %{{.*}}, double %{{.*}}, i32 8) //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) - joint_matrix_store(sg, sub_c, accD.get_pointer(), M, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + M, layout::col_major); }); }); diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp index 5ba83f853ae93..7f4c5e30ab3bf 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp @@ -46,21 +46,27 @@ int main() { // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::row_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float {{.*}} // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::row_major); }); cgh.parallel_for( @@ -76,21 +82,27 @@ int main() { // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::col_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float {{.*}} // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::col_major); }); cgh.parallel_for( @@ -106,21 +118,27 @@ int main() { // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::row_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float {{.*}} // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::row_major); }); cgh.parallel_for( @@ -136,21 +154,27 @@ int main() { // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::col_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float {{.*}} // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::col_major); }); cgh.parallel_for( @@ -166,21 +190,27 @@ int main() { // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::row_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float {{.*}} // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::row_major); }); cgh.parallel_for( @@ -196,21 +226,27 @@ int main() { // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::col_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float {{.*}} // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::col_major); }); }); diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp index 2304454fbaf6d..d651b6ad9b926 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp @@ -46,21 +46,27 @@ int main() { // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::row_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> {{.*}} // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0i32(i32* %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::row_major); }); cgh.parallel_for( @@ -76,21 +82,27 @@ int main() { // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::col_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> {{.*}} // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0i32(i32* %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::col_major); }); cgh.parallel_for( @@ -106,21 +118,27 @@ int main() { // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::row_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> {{.*}} // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0i32(i32* %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::row_major); }); cgh.parallel_for( @@ -136,21 +154,27 @@ int main() { // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::col_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> {{.*}} // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0i32(i32* %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::col_major); }); cgh.parallel_for( @@ -166,21 +190,27 @@ int main() { // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::row_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> {{.*}} // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0i32(i32* %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::row_major); }); cgh.parallel_for( @@ -196,21 +226,27 @@ int main() { // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::col_major); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.f16.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.f16.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> {{.*}} // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0i32(i32* %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::col_major); }); }); diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp index 1bbbab4b16629..da075d5bbfd87 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp @@ -46,21 +46,27 @@ int main() { // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::row_major); // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.s8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.s8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.s8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.s8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 {{.*}} // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::row_major); }); cgh.parallel_for( @@ -76,21 +82,27 @@ int main() { // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::col_major); // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.s8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.s8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.s8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.s8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 {{.*}} // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::col_major); }); cgh.parallel_for( @@ -106,21 +118,27 @@ int main() { // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::row_major); // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.s8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.s8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call i32 @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.s8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call i32 @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.s8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 {{.*}} // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::row_major); }); cgh.parallel_for( @@ -136,21 +154,27 @@ int main() { // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::col_major); // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.s8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.s8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call i32 @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.s8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call i32 @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.s8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 {{.*}} // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::col_major); }); cgh.parallel_for( @@ -166,21 +190,27 @@ int main() { // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::row_major); // CHECK: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.s8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.s8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.s8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.s8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 {{.*}} // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::row_major); }); cgh.parallel_for( @@ -196,21 +226,27 @@ int main() { // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::col_major); // CHECK: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.s8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.s8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.s8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.s8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 {{.*}} // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::col_major); }); }); diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index 43bcc84440690..9ae46e6256bed 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -69,14 +69,19 @@ int main() { //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0i32(i32* %{{.*}}, i32 8) //CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0(ptr %{{.*}}, i32 8) - joint_matrix_load(sg, sub_a, accA.get_pointer(), K); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + K); //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.row.stride.tf32.p0i32(i32* %{{.*}}, i32 16) //CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.row.stride.tf32.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), N); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + N); //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %{{.*}}, i32 16) //CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), N, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + N, layout::row_major); // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}} // Round a, b to tf32 @@ -92,8 +97,9 @@ int main() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} - joint_matrix_store(sg, sub_c, accD.get_pointer(), N, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + N, layout::row_major); }); }); @@ -118,14 +124,19 @@ int main() { //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p0i32(i32* %{{.*}}, i32 8) //CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p0(ptr %{{.*}}, i32 8) - joint_matrix_load(sg, sub_a, accA.get_pointer(), K); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + K); //CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32.p0i32(i32* %{{.*}}, i32 16) //CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), N); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + N); //CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* {{.*}}, i32 {{.*}}) //CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1(ptr addrspace(1) {{.*}}, i32 {{.*}}) - joint_matrix_load(sg, sub_c, accC.get_pointer(), N, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + N, layout::col_major); // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}} // Round a, b to tf32 @@ -141,8 +152,9 @@ int main() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), N, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + N, layout::col_major); }); }); diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp index 16f357fc1dbf4..3fb078ab306dd 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp @@ -46,21 +46,27 @@ int main() { // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::row_major); // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.u8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.u8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.u8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.u8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 {{.*}} // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::row_major); }); cgh.parallel_for( @@ -76,21 +82,27 @@ int main() { // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::col_major); // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.u8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.u8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.u8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.u8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 {{.*}} // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::col_major); }); cgh.parallel_for( @@ -106,21 +118,27 @@ int main() { // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::row_major); // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.u8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.u8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call i32 @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.u8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call i32 @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.u8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 {{.*}} // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::row_major); }); cgh.parallel_for( @@ -136,21 +154,27 @@ int main() { // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::col_major); // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.u8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.u8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call i32 @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.u8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call i32 @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.u8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 {{.*}} // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::col_major); }); cgh.parallel_for( @@ -166,21 +190,27 @@ int main() { // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::row_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::row_major); // CHECK: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.u8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.u8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.u8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.u8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 {{.*}} // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::row_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::row_major); }); cgh.parallel_for( @@ -196,21 +226,27 @@ int main() { // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 16) - joint_matrix_load(sg, sub_c, accC.get_pointer(), stride, - layout::col_major); + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + stride, layout::col_major); // CHECK: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.u8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.u8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.u8.p0i32(i32* %{{.*}}, i32 16) // CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.u8.p0(ptr %{{.*}}, i32 16) - joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + stride); // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 {{.*}} // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 16) // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) - joint_matrix_store(sg, sub_c, accD.get_pointer(), stride, - layout::col_major); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + stride, layout::col_major); }); }); diff --git a/sycl/test/esimd/esimd_verify.cpp b/sycl/test/esimd/esimd_verify.cpp index 1de016a96364e..b42eecd35b860 100644 --- a/sycl/test/esimd/esimd_verify.cpp +++ b/sycl/test/esimd/esimd_verify.cpp @@ -7,7 +7,7 @@ using namespace sycl; using namespace sycl::ext::intel::esimd; -// CHECK-DAG: error: function 'sycl::_V1::multi_ptr<{{.+}}> sycl::_V1::accessor<{{.+}}>::get_pointer<{{.+}}>() const' is not supported in ESIMD context +// CHECK-DAG: error: function 'int* sycl::_V1::accessor<{{.+}}>::get_pointer<{{.+}}>() const' is not supported in ESIMD context // CHECK-DAG: error: function '{{.+}} sycl::_V1::accessor<{{.+}}>::operator[]<{{.+}}>({{.+}}) const' is not supported in ESIMD context // CHECK-DAG: error: function '{{.+}}combine(int const&)' is not supported in ESIMD context diff --git a/sycl/test/extensions/fpga.cpp b/sycl/test/extensions/fpga.cpp index aaec1d67fe2b5..16b56c2ab5ee0 100644 --- a/sycl/test/extensions/fpga.cpp +++ b/sycl/test/extensions/fpga.cpp @@ -12,7 +12,7 @@ template struct ethernet_pipe_id { template -void lsu_body(sycl::multi_ptr input_ptr, +void lsu_body(sycl::multi_ptr input_ptr, sycl::multi_ptr output_ptr) { using PrefetchingLSU = sycl::ext::intel::lsu, @@ -99,7 +99,7 @@ int main() { auto *in_ptr = sycl::malloc_host(1, Queue.get_context()); Queue.submit([&](sycl::handler &cgh) { cgh.single_task([=]() { - sycl::host_ptr input_ptr(in_ptr); + sycl::host_ptr input_ptr(in_ptr); sycl::host_ptr output_ptr(out_ptr); intelfpga::lsu_body< int, sycl::access::address_space::ext_intel_global_host_space>( @@ -112,7 +112,7 @@ int main() { auto *in_ptr = sycl::malloc_device(1, Queue); Queue.submit([&](sycl::handler &cgh) { cgh.single_task([=]() { - sycl::ext::intel::device_ptr input_ptr(in_ptr); + sycl::ext::intel::device_ptr input_ptr(in_ptr); sycl::ext::intel::device_ptr output_ptr(out_ptr); intelfpga::lsu_body< int, sycl::access::address_space::ext_intel_global_device_space>( @@ -129,8 +129,8 @@ int main() { auto input_accessor = input_buffer.get_access(cgh); cgh.single_task([=]() { - auto input_ptr = input_accessor.get_pointer(); - auto output_ptr = output_accessor.get_pointer(); + auto input_ptr = sycl::global_ptr(input_accessor); + auto output_ptr = sycl::global_ptr(output_accessor); intelfpga::lsu_body<>(input_ptr, output_ptr); }); }); diff --git a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp index 541f4f75a4c71..f3e7d441a8097 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp @@ -72,25 +72,30 @@ void matrix_multiply(big_matrix &C, // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K / TK; k += 1) { // joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 2) * (N * 2) + - sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test/matrix/legacy/matrix-bf16-test.cpp b/sycl/test/matrix/legacy/matrix-bf16-test.cpp index 447e880afecbd..ce068fa7eb3b5 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test.cpp @@ -71,25 +71,30 @@ void matrix_multiply(big_matrix &C, // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K / TK; k += 1) { // joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 2) * (N * 2) + - sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp index b87b63e29cc98..de5980206dfc9 100644 --- a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp @@ -70,25 +70,30 @@ void matrix_multiply(big_matrix &C, // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K / TK; k += 1) { // joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 2) * (N * 2) + - sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp index 01efab63246d0..562e1b440de0b 100644 --- a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp @@ -71,29 +71,34 @@ void matrix_multiply(big_matrix &C, // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_data_c = sub_c.get_wi_data(); for (int i = 0; i < wi_data_c.length(); i++) { wi_data_c[i] *= 2; } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp index ed6408f74076a..f8c76ee75129e 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp @@ -71,25 +71,30 @@ void matrix_multiply(big_matrix &C, // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test/matrix/legacy/matrix-int8-test.cpp b/sycl/test/matrix/legacy/matrix-int8-test.cpp index a0c2edb62c2f1..6efc0e89b0a57 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test.cpp @@ -78,19 +78,23 @@ void matrix_multiply(big_matrix &C, joint_matrix_fill(sg, sub_c, 0); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K, matrix_layout::row_major); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp index da034be7a8fc7..2f7424ac33525 100644 --- a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp @@ -158,10 +158,11 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { ext::intel::experimental::matrix::layout::packed> sub_b; - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (global_idx * (TK / 4) * N) + - sg_starty / SG_SZ * TN * 4, - N); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (global_idx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4, + N); int32_t sum_local_cols[N] = {0}; // 4 local cols, N total // sub_b has 32x16 elements, 32 elements per WI, 4 per WI per row diff --git a/sycl/test/matrix/matrix-bfloat16-test.cpp b/sycl/test/matrix/matrix-bfloat16-test.cpp index 1a0178165971b..2e0e309081464 100644 --- a/sycl/test/matrix/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test.cpp @@ -72,25 +72,30 @@ void matrix_multiply(big_matrix &C, sub_b; joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { // joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 2) * (N * 2) + - sg_starty / SG_SZ * TN * 2, - N * 2); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, + N * 2); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index 0174d60866958..3205e4c346ba6 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -76,19 +76,23 @@ void matrix_multiply(big_matrix &C, // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_data_c = @@ -96,10 +100,11 @@ void matrix_multiply(big_matrix &C, for (int i = 0; i < wi_data_c.length(); i++) { wi_data_c[i] *= 2; } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index de8721bca3b09..63866c19f89fa 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -84,19 +84,23 @@ void matrix_multiply(big_matrix &C, joint_matrix_fill(sg, sub_c, 0); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, K); // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 4) * (N * 4) + - sg_starty / SG_SZ * TN * 4, - N * 4); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp index bfea2e4c9a698..d6affb4067003 100644 --- a/sycl/test/matrix/matrix-tf32-test.cpp +++ b/sycl/test/matrix/matrix-tf32-test.cpp @@ -66,17 +66,23 @@ void matrix_multiply(big_matrix &C, layout::row_major> sub_b; joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); joint_matrix_fill(sg, sub_a, 42); for (int k = 0; k < K; k += TK) { joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k, K); + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k, + K); joint_matrix_load( sg, sub_b, - accB.get_pointer() + (k) * (N) + sg_starty / SG_SZ * TN, N); + accB.template get_multi_ptr() + + (k) * (N) + sg_starty / SG_SZ * TN, + N); // If no rounding to tf32 function is called, joint_matrix_mad // function will work on truncated floats. joint_matrix_apply(sg, sub_a, @@ -88,10 +94,11 @@ void matrix_multiply(big_matrix &C, } sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, layout::row_major); + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/sycl/test/multi_ptr/ctad.cpp b/sycl/test/multi_ptr/ctad.cpp index bbf80fb1bf08a..792447bab3649 100644 --- a/sycl/test/multi_ptr/ctad.cpp +++ b/sycl/test/multi_ptr/ctad.cpp @@ -37,4 +37,9 @@ int main() { static_assert(std::is_same::value); static_assert(std::is_same::value); static_assert(std::is_same::value); + + globlMPtr non_const_multi_ptr; + using constTypeMPtr = sycl::multi_ptr; + auto constTypeMultiPtr = constTypeMPtr(non_const_multi_ptr); } diff --git a/sycl/test/warnings/sycl_2020_deprecations.cpp b/sycl/test/warnings/sycl_2020_deprecations.cpp index d8475b1cddf78..38d65481628eb 100644 --- a/sycl/test/warnings/sycl_2020_deprecations.cpp +++ b/sycl/test/warnings/sycl_2020_deprecations.cpp @@ -170,7 +170,7 @@ int main() { // expected-error@+2{{no member named 'ONEAPI' in namespace 'sycl'}} // expected-error@+2{{no member named 'ONEAPI' in namespace 'sycl'}} sycl::ext::oneapi::atomic_fence(sycl::ONEAPI::memory_order::relaxed, - sycl::ONEAPI::memory_scope::work_group); + sycl::ONEAPI::memory_scope::work_group); // expected-error@+1{{no member named 'INTEL' in namespace 'sycl'}} auto SL = sycl::INTEL::source_language::opencl_c;