Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
  • Loading branch information
MrSidims committed Aug 8, 2024
1 parent f2b7564 commit 5f83083
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 25 deletions.
6 changes: 3 additions & 3 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ template <typename Ts, typename T, std::size_t R, std::size_t C,
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
Ts val, size_t i);
#else // __SPIRV_USE_COOPERATIVE_MATRIX
#else // __SPIRV_USE_COOPERATIVE_MATRIX
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
Expand Down Expand Up @@ -262,10 +262,10 @@ extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
__spirv_JointMatrixGetElementCoordINTEL(
__spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> *, size_t i);


// AccessChain followed by load/store serves to extract/insert and element
// from/to the matrix
template <typename Ts, typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
template <typename Ts, typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL Ts *
__spirv_AccessChain(__spv::__spirv_CooperativeMatrixKHR<T, S, R, C, U> **,
Expand Down
71 changes: 49 additions & 22 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,15 @@ class wi_element {
rhs.idx),
idx);
#else
storage_element_type *ExtractP = __spirv_AccessChain<storage_element_type, T, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(
&rhs.M.spvm, rhs.idx);
storage_element_type *InsertP = __spirv_AccessChain<storage_element_type, T, NumRows, NumCols,
spv_matrix_use_traits<Use>::value, spv_scope_traits<Group>::value>(&M.spvm, idx);
storage_element_type *ExtractP =
__spirv_AccessChain<storage_element_type, T, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&rhs.M.spvm,
rhs.idx);
storage_element_type *InsertP =
__spirv_AccessChain<storage_element_type, T, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&M.spvm, idx);
*InsertP = *ExtractP;
#endif // __SPIRV_USE_COOPERATIVE_MATRIX
return *this;
Expand Down Expand Up @@ -244,11 +247,15 @@ class wi_element {
#else // __SPIRV_USE_COOPERATIVE_MATRIX
#define OP(op) \
template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
storage_element_type *ExtractP = __spirv_AccessChain<storage_element_type, T, NumRows, NumCols, \
spv_matrix_use_traits<Use>::value, \
spv_scope_traits<Group>::value>( \
&rhs.M.spvm, rhs.idx); \
storage_element_type *InsertP = __spirv_AccessChain<storage_element_type, T, NumRows, NumCols, spv_matrix_use_traits<Use>::value, spv_scope_traits<Group>::value>(&M.spvm, idx); \
storage_element_type *ExtractP = \
__spirv_AccessChain<storage_element_type, T, NumRows, NumCols, \
spv_matrix_use_traits<Use>::value, \
spv_scope_traits<Group>::value>(&rhs.M.spvm, \
rhs.idx); \
storage_element_type *InsertP = \
__spirv_AccessChain<storage_element_type, T, NumRows, NumCols, \
spv_matrix_use_traits<Use>::value, \
spv_scope_traits<Group>::value>(&M.spvm, idx); \
*InsertP = *ExtractP op static_cast<storage_element_type>(rhs); \
return *this; \
}
Expand Down Expand Up @@ -313,7 +320,8 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
spv_scope_traits<Group>::value>(M.spvm, idx);
#else
sycl::ext::oneapi::bfloat16 *ExtractP =
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, NumCols,
__spirv_AccessChain<sycl::ext::oneapi::bfloat16,
sycl::ext::oneapi::bfloat16, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&M.spvm, idx);
return *ExtractP;
Expand All @@ -336,7 +344,8 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
std::numeric_limits<float>::epsilon();
#else
sycl::ext::oneapi::bfloat16 *ExtractP =
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, NumCols,
__spirv_AccessChain<sycl::ext::oneapi::bfloat16,
sycl::ext::oneapi::bfloat16, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&M.spvm, idx);
sycl::ext::oneapi::bfloat16 Elem = *ExtractP;
Expand All @@ -354,7 +363,11 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
#ifndef __SPIRV_USE_COOPERATIVE_MATRIX
M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
#else
sycl::ext::oneapi::bfloat16 *InsertP = __spirv_AccessChain<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, NumCols, spv_matrix_use_traits<Use>::value, spv_scope_traits<Group>::value>(&M.spvm, idx);
sycl::ext::oneapi::bfloat16 *InsertP =
__spirv_AccessChain<sycl::ext::oneapi::bfloat16,
sycl::ext::oneapi::bfloat16, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&M.spvm, idx);
*InsertP = rhs;
#endif // __SPIRV_USE_COOPERATIVE_MATRIX
return *this;
Expand All @@ -381,11 +394,16 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
return *this;
#else
sycl::ext::oneapi::bfloat16 *ExtractP =
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, NumCols,
__spirv_AccessChain<sycl::ext::oneapi::bfloat16,
sycl::ext::oneapi::bfloat16, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&rhs.M.spvm,
rhs.idx);
sycl::ext::oneapi::bfloat16 *InsertP = __spirv_AccessChain<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, NumCols, spv_matrix_use_traits<Use>::value, spv_scope_traits<Group>::value>(&M.spvm, idx);
sycl::ext::oneapi::bfloat16 *InsertP =
__spirv_AccessChain<sycl::ext::oneapi::bfloat16,
sycl::ext::oneapi::bfloat16, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&M.spvm, idx);
*InsertP = *ExtractP;
return *this;
#endif // __SPIRV_USE_COOPERATIVE_MATRIX
Expand Down Expand Up @@ -414,10 +432,15 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
#define OP(opassign, op) \
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
sycl::ext::oneapi::bfloat16 *ExtractP = \
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
spv_matrix_use_traits<Use>::value, \
spv_scope_traits<Group>::value>(&M.spvm, idx); \
sycl::ext::oneapi::bfloat16 *InsertP = \
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
spv_matrix_use_traits<Use>::value, \
spv_scope_traits<Group>::value>(&M.spvm, idx); \
sycl::ext::oneapi::bfloat16 *InsertP = __spirv_AccessChain<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, NumCols, spv_matrix_use_traits<Use>::value, spv_scope_traits<Group>::value>(&M.spvm, idx); \
*InsertP = *ExtractP op rhs; \
return *this; \
}
Expand Down Expand Up @@ -466,7 +489,8 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
Layout, Group> &lhs, \
const sycl::ext::oneapi::bfloat16 &rhs) { \
sycl::ext::oneapi::bfloat16 *ExtractP = \
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
spv_matrix_use_traits<Use>::value, \
spv_scope_traits<Group>::value>(&lhs.M.spvm, \
lhs.idx); \
Expand All @@ -477,7 +501,8 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
Layout, Group> &rhs) { \
sycl::ext::oneapi::bfloat16 *ExtractP = \
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
spv_matrix_use_traits<Use>::value, \
spv_scope_traits<Group>::value>(&rhs.M.spvm, \
rhs.idx); \
Expand Down Expand Up @@ -522,7 +547,8 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
Layout, Group> &lhs, \
const sycl::ext::oneapi::bfloat16 &rhs) { \
sycl::ext::oneapi::bfloat16 *ExtractP = \
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
spv_matrix_use_traits<Use>::value, \
spv_scope_traits<Group>::value>(&lhs.M.spvm, \
lhs.idx); \
Expand All @@ -533,7 +559,8 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
Layout, Group> &rhs) { \
sycl::ext::oneapi::bfloat16 *ExtractP = \
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
spv_matrix_use_traits<Use>::value, \
spv_scope_traits<Group>::value>(&rhs.M.spvm, \
rhs.idx); \
Expand Down

0 comments on commit 5f83083

Please sign in to comment.