2222#endif
2323
2424#ifdef __SYCL_DEVICE_ONLY__
25+
26+ #ifdef __SYCL_EXT_ONEAPI_MATRIX_USE__
27+ #define JOINT_MATRIX_INTEL (T, R, C, L, S, U ) \
28+ __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U>
29+ #else
30+ #define JOINT_MATRIX_INTEL (T, R, C, L, S, U ) \
31+ __spv::__spirv_JointMatrixINTEL<T, R, C, L, S>
32+ #endif // __SYCL_EXT_ONEAPI_MATRIX_USE__
33+
2534template <typename T, std::size_t R, std::size_t C,
2635 __spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
2736 __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
2837 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
29- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *
38+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T, R, C, L, S, U) *
3039__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
3140 __spv::MatrixLayout Layout = L,
3241 __spv::Scope::Flag Sc = S, int MemOperand = 0 );
@@ -36,7 +45,7 @@ template <typename T, std::size_t R, std::size_t C,
3645 __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
3746 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
3847extern SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL (
39- T *Ptr, __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *Object,
48+ T *Ptr, JOINT_MATRIX_INTEL( T, R, C, L, S, U) *Object,
4049 std::size_t Stride, __spv::MatrixLayout Layout = L,
4150 __spv::Scope::Flag Sc = S, int MemOperand = 0);
4251
@@ -48,11 +57,11 @@ template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
4857 __spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
4958 __spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
5059 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
51- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T2, M, N, LC, S, UC> *
60+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T2, M, N, LC, S, UC) *
5261__spirv_JointMatrixMadINTEL(
53- __spv::__spirv_JointMatrixINTEL< T1, M, K, LA, S, UA> *A,
54- __spv::__spirv_JointMatrixINTEL< T1, K, N, LB, S, UB> *B,
55- __spv::__spirv_JointMatrixINTEL< T2, M, N, LC, S, UC> *C,
62+ JOINT_MATRIX_INTEL ( T1, M, K, LA, S, UA) *A,
63+ JOINT_MATRIX_INTEL( T1, K, N, LB, S, UB) *B,
64+ JOINT_MATRIX_INTEL( T2, M, N, LC, S, UC) *C,
5665 __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
5766
5867template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
@@ -63,11 +72,11 @@ template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
6372 __spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
6473 __spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
6574 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
66- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3 , M, N, LC, S, UC> *
75+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL (T2 , M, N, LC, S, UC) *
6776__spirv_JointMatrixUUMadINTEL(
68- __spv::__spirv_JointMatrixINTEL< T1, M, K, LA, S, UA> *A,
69- __spv::__spirv_JointMatrixINTEL< T2, K, N, LB, S, UB> *B,
70- __spv::__spirv_JointMatrixINTEL< T3, M, N, LC, S, UC> *C,
77+ JOINT_MATRIX_INTEL ( T1, M, K, LA, S, UA) *A,
78+ JOINT_MATRIX_INTEL( T2, K, N, LB, S, UB) *B,
79+ JOINT_MATRIX_INTEL( T3, M, N, LC, S, UC) *C,
7180 __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
7281
7382template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
@@ -78,11 +87,11 @@ template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
7887 __spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
7988 __spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
8089 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
81- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T3, M, N, LC, S, UC> *
90+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T3, M, N, LC, S, UC) *
8291__spirv_JointMatrixUSMadINTEL(
83- __spv::__spirv_JointMatrixINTEL< T1, M, K, LA, S, UA> *A,
84- __spv::__spirv_JointMatrixINTEL< T2, K, N, LB, S, UB> *B,
85- __spv::__spirv_JointMatrixINTEL< T3, M, N, LC, S, UC> *C,
92+ JOINT_MATRIX_INTEL ( T1, M, K, LA, S, UA) *A,
93+ JOINT_MATRIX_INTEL( T2, K, N, LB, S, UB) *B,
94+ JOINT_MATRIX_INTEL( T3, M, N, LC, S, UC) *C,
8695 __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
8796
8897template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
@@ -93,38 +102,39 @@ template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
93102 __spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
94103 __spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
95104 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
96- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T3, M, N, LC, S, UC> *
105+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T3, M, N, LC, S, UC) *
97106__spirv_JointMatrixSUMadINTEL(
98- __spv::__spirv_JointMatrixINTEL< T1, M, K, LA, S, UA> *A,
99- __spv::__spirv_JointMatrixINTEL< T2, K, N, LB, S, UB> *B,
100- __spv::__spirv_JointMatrixINTEL< T3, M, N, LC, S, UC> *C,
107+ JOINT_MATRIX_INTEL ( T1, M, K, LA, S, UA) *A,
108+ JOINT_MATRIX_INTEL( T2, K, N, LB, S, UB) *B,
109+ JOINT_MATRIX_INTEL( T3, M, N, LC, S, UC) *C,
101110 __spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
102111
103112template <typename T, std::size_t R, std::size_t C,
104113 __spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
105114 __spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
106115 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
107- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *
116+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T, R, C, L, S, U) *
108117__spirv_CompositeConstruct(const T v);
109118
110119template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
111120 __spv::MatrixLayout L,
112121 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
113122extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL (
114- __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *);
123+ JOINT_MATRIX_INTEL ( T, R, C, L, S, U) *);
115124
116125template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
117126 __spv::MatrixLayout L,
118127 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
119128extern SYCL_EXTERNAL T __spirv_VectorExtractDynamic (
120- __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *, size_t i);
129+ JOINT_MATRIX_INTEL ( T, R, C, L, S, U) *, size_t i);
121130
122131template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
123132 __spv::MatrixLayout L,
124133 __spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
125- extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *
126- __spirv_VectorInsertDynamic (__spv::__spirv_JointMatrixINTEL< T, R, C, L, S, U> *,
134+ extern SYCL_EXTERNAL JOINT_MATRIX_INTEL ( T, R, C, L, S, U) *
135+ __spirv_VectorInsertDynamic(JOINT_MATRIX_INTEL( T, R, C, L, S, U) *,
127136 T val, size_t i);
137+ #undef JOINT_MATRIX_INTEL
128138
129139#ifndef __SPIRV_BUILTIN_DECLARATIONS__
130140#error \
0 commit comments