Skip to content

Commit

Permalink
Try to not use __spirv_Load/__spirv_Store
Browse files Browse the repository at this point in the history
Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
  • Loading branch information
MrSidims committed Apr 8, 2024
1 parent 5b8e9eb commit c035f44
Showing 1 changed file with 23 additions and 19 deletions.
42 changes: 23 additions & 19 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ class wi_element {
__spirv_AccessChain<storage_element_type, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&M.spvm, idx);
storage_element_type elem = __spirv_Load<T>(ExtractP);
storage_element_type elem = *ExtractP;
// storage_element_type elem = __spirv_Load<T>(ExtractP);
#endif // USE_COOP_MATRIX
return elem;
#else
Expand All @@ -169,7 +170,8 @@ class wi_element {
__spirv_AccessChain<storage_element_type, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&M.spvm, idx);
return __spirv_Load<T>(ExtractP) != static_cast<storage_element_type>(0);
return *ExtractP != static_cast<storage_element_type>(0);
// return __spirv_Load<T>(ExtractP) != static_cast<storage_element_type>(0);
#endif // USE_COOP_MATRIX
#else
throw runtime_error("joint matrix is not supported on host device.",
Expand All @@ -184,7 +186,8 @@ class wi_element {
M.spvm, static_cast<storage_element_type>(rhs), idx);
#else
T2 *InsertP = __spirv_AccessChain(&M.spvm, idx);
__spirv_Store(InsertP, static_cast<storage_element_type>(rhs));
*InsertP = static_cast<storage_element_type>(rhs);
// __spirv_Store(InsertP, static_cast<storage_element_type>(rhs));
#endif // USE_COOP_MATRIX
return *this;
#else
Expand All @@ -210,9 +213,13 @@ class wi_element {
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(
&rhs.M.spvm, rhs.idx);
T *InsertP = __spirv_AccessChain(&M.spvm, idx);
*InsertP = *ExtractP;
/*
T RhsVal = __spirv_Load(ExtractP);
T *InsertP = __spirv_AccessChain(&M.spvm, idx);
__spirv_Store(InsertP, RhsVal);
*/
#endif // USE_COOP_MATRIX
return *this;
#else
Expand Down Expand Up @@ -245,10 +252,8 @@ class wi_element {
spv_matrix_use_traits<Use>::value, \
spv_scope_traits<Group>::value>( \
&rhs.M.spvm, rhs.idx); \
T RhsVal = \
__spirv_Load(ExtractP) op static_cast<storage_element_type>(rhs); \
T *InsertP = __spirv_AccessChain(&M.spvm, idx); \
__spirv_Store(static_cast<storage_element_type>(InsertP), RhsVal); \
*InsertP = *ExtractP op static_cast<storage_element_type>(rhs); \
return *this; \
}
#endif // USE_COOP_MATRIX
Expand Down Expand Up @@ -315,7 +320,7 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&M.spvm, idx);
return __spirv_Load<sycl::ext::oneapi::bfloat16>(ExtractP);
return *ExtractP;
#endif // USE_COOP_MATRIX
#else
throw runtime_error("joint matrix is not supported on host device.",
Expand All @@ -338,8 +343,8 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&M.spvm, idx);
sycl::ext::oneapi::bfloat16 Elem =
__spirv_Load<sycl::ext::oneapi::bfloat16>(ExtractP);
sycl::ext::oneapi::bfloat16 Elem = *ExtractP;
// __spirv_Load<sycl::ext::oneapi::bfloat16>(ExtractP);
return sycl::fabs(static_cast<float>(Elem)) >=
std::numeric_limits<float>::epsilon();
#endif // USE_COOP_MATRIX
Expand Down Expand Up @@ -384,9 +389,11 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
spv_matrix_use_traits<Use>::value,
spv_scope_traits<Group>::value>(&rhs.M.spvm,
rhs.idx);
sycl::ext::oneapi::bfloat16 RhsVal = __spirv_Load(ExtractP);
sycl::ext::oneapi::bfloat16 *InsertP = __spirv_AccessChain(&M.spvm, idx);
__spirv_Store(InsertP, RhsVal);
*InsertP = *ExtractP;
/* sycl::ext::oneapi::bfloat16 RhsVal = __spirv_Load(ExtractP);
sycl::ext::oneapi::bfloat16 *InsertP = __spirv_AccessChain(&M.spvm, idx);
__spirv_Store(InsertP, RhsVal);*/
#endif // USE_COOP_MATRIX
return *this;
#else
Expand Down Expand Up @@ -417,9 +424,8 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
__spirv_AccessChain<sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
spv_matrix_use_traits<Use>::value, \
spv_scope_traits<Group>::value>(&M.spvm, idx); \
sycl::ext::oneapi::bfloat16 RhsVal = __spirv_Load(ExtractP) op rhs; \
sycl::ext::oneapi::bfloat16 *InsertP = __spirv_AccessChain(&M.spvm, idx); \
__spirv_Store(InsertP, RhsVal); \
*InsertP = *ExtractP op rhs; \
return *this; \
}
#endif // USE_COOP_MATRIX
Expand Down Expand Up @@ -471,7 +477,7 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
spv_matrix_use_traits<Use>::value, \
spv_scope_traits<Group>::value>(&lhs.M.spvm, \
lhs.idx); \
return __spirv_Load<sycl::ext::oneapi::bfloat16>(ExtractP) op rhs; \
return *ExtractP op rhs; \
} \
friend type operator op( \
const sycl::ext::oneapi::bfloat16 &lhs, \
Expand All @@ -482,7 +488,7 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
spv_matrix_use_traits<Use>::value, \
spv_scope_traits<Group>::value>(&rhs.M.spvm, \
rhs.idx); \
return __spirv_Load<sycl::ext::oneapi::bfloat16>(ExtractP) op lhs; \
return *ExtractP op lhs; \
}
#endif // USE_COOP_MATRIX
OP(sycl::ext::oneapi::bfloat16, +)
Expand Down Expand Up @@ -527,8 +533,7 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
spv_matrix_use_traits<Use>::value, \
spv_scope_traits<Group>::value>(&lhs.M.spvm, \
lhs.idx); \
return type{static_cast<float>(__spirv_Load<sycl::ext::oneapi::bfloat16>( \
ExtractP)) op static_cast<float>(rhs)}; \
return type{static_cast<float>(*ExtractP) op static_cast<float>(rhs)}; \
} \
friend type operator op( \
const sycl::ext::oneapi::bfloat16 &lhs, \
Expand All @@ -539,8 +544,7 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
spv_matrix_use_traits<Use>::value, \
spv_scope_traits<Group>::value>(&rhs.M.spvm, \
rhs.idx); \
return type{static_cast<float>(__spirv_Load<sycl::ext::oneapi::bfloat16>( \
ExtractP)) op static_cast<float>(lhs)}; \
return type{static_cast<float>(*ExtractP) op static_cast<float>(lhs)}; \
}
#endif // USE_COOP_MATRIX
OP(bool, ==)
Expand Down

0 comments on commit c035f44

Please sign in to comment.