@@ -67,26 +67,18 @@ static constexpr __device__ int get_mmq_y_device() {
6767#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8 }
6868#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8 }
6969
// Diff hunk: the GET_MMQ_DP4A_TXS_BODY macro and its two wrappers are removed
// ('-' lines): mmq_get_dp4a_tile_x_sizes_host took mmq_y as a run-time argument,
// mmq_get_dp4a_tile_x_sizes_device<mmq_y> took it as a template parameter.
// Both are replaced by one constexpr __host__ __device__ function ('+' lines)
// that takes mmq_y as an ordinary argument.
70- #define GET_MMQ_DP4A_TXS_BODY \
71- return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : \
72- type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : \
73- type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 : \
74- type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 : \
75- type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : \
76- type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : \
77- type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : \
78- type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : \
79- type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : \
80- type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : \
81- tile_x_sizes{0 , 0 , 0 }
82-
83- static tile_x_sizes mmq_get_dp4a_tile_x_sizes_host (const ggml_type type, const int mmq_y) {
84- GET_MMQ_DP4A_TXS_BODY;
85- }
86-
87- template <int mmq_y>
88- static constexpr __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes_device (ggml_type type) {
89- GET_MMQ_DP4A_TXS_BODY;
// Selects the MMQ_DP4A_TXS_* initializer that matches `type`; a type with no
// entry in the chain yields tile_x_sizes{0, 0, 0}.
// The visible MMQ_DP4A_TXS_Q5_K / MMQ_DP4A_TXS_Q6_K macros expand to
// tile_x_sizes{...} expressions that read the identifier `mmq_y` from the
// enclosing scope, so the second parameter has to keep exactly that name.
// NOTE(review): the other MMQ_DP4A_TXS_* macros are outside this hunk and
// presumably follow the same pattern — confirm.
// The body is a single return of a ternary chain (no statements), and every
// call site in this diff uses the result to initialize a constexpr variable.
70+ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes (ggml_type type, int mmq_y) {
71+ return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
72+ type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
73+ type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
74+ type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
75+ type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
76+ type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
77+ type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
78+ type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
79+ type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
80+ type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
81+ tile_x_sizes{0 , 0 , 0 };
9082}
9183
9284#define MMQ_MMA_TILE_X_K_Q4_0 (1 *WARP_SIZE + WARP_SIZE/QI4_0 + 4 )
@@ -111,21 +103,18 @@ static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
111103static_assert (MMQ_MMA_TILE_X_K_Q5_K % 8 == 4 , " Wrong padding." );
112104static_assert (MMQ_MMA_TILE_X_K_Q6_K % 8 == 4 , " Wrong padding." );
113105
// Diff hunk: the MMQ_MMA_GET_TILE_X_K_BODY macro is removed ('-' lines) and its
// ternary chain is written directly inside mmq_get_mma_tile_x_k ('+' lines).
// The function's signature is unchanged by this hunk.
114- #define MMQ_MMA_GET_TILE_X_K_BODY \
115- return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 : \
116- type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 : \
117- type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 : \
118- type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 : \
119- type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : \
120- type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : \
121- type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : \
122- type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K : \
123- type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K : \
124- type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : \
125- 0
126-
// Returns the MMQ_MMA_TILE_X_K_* constant that matches `type`, or 0 when the
// type has no entry in the chain. Callable from host and device code; the
// visible caller (mmq_get_shmem) multiplies the result by mmq_y and
// sizeof(int) to size a shared-memory region.
127106static constexpr __host__ __device__ int mmq_get_mma_tile_x_k (ggml_type type) {
128- MMQ_MMA_GET_TILE_X_K_BODY;
107+ return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
108+ type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
109+ type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
110+ type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
111+ type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
112+ type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
113+ type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
114+ type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
115+ type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
116+ type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
117+ 0 ;
129118}
130119
131120#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
@@ -154,7 +143,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
154143 int * x_qs = (int *) x_tile;
155144 float * x_df = (float *) (x_qs + WARP_SIZE);
156145#else
157- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q4_0);
146+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q4_0, mmq_y );
158147 int * x_qs = (int *) x_tile;
159148 float * x_df = (float *) (x_qs + txs.qs );
160149#endif // INT8_MMA_AVAILABLE
@@ -204,7 +193,7 @@ template <int mmq_x, int mmq_y, int nwarps>
204193static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a (
205194 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
206195
207- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q4_0);
196+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q4_0, mmq_y );
208197 const int * x_qs = (const int *) x;
209198 const float * x_df = (const float *) x_qs + txs.qs ;
210199 const int * y_qs = (const int *) y + 4 ;
@@ -317,7 +306,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
317306 int * x_qs = (int *) x_tile;
318307 half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
319308#else
320- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q4_1);
309+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q4_1, mmq_y );
321310 int * x_qs = (int *) x_tile;
322311 half2 * x_dm = (half2 *) (x_qs + txs.qs );
323312#endif // INT8_MMA_AVAILABLE
@@ -367,7 +356,7 @@ template <int mmq_x, int mmq_y, int nwarps>
367356static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a (
368357 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
369358
370- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q4_1);
359+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q4_1, mmq_y );
371360 const int * x_qs = (const int *) x;
372361 const half2 * x_dm = (const half2 *) x_qs + txs.qs ;
373362 const int * y_qs = (const int *) y + 4 ;
@@ -479,7 +468,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
479468 int * x_qs = (int *) x_tile;
480469 float * x_df = (float *) (x_qs + WARP_SIZE*2 );
481470#else
482- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q5_0);
471+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q5_0, mmq_y );
483472 int * x_qs = (int *) x_tile;
484473 float * x_df = (float *) (x_qs + txs.qs );
485474#endif // INT8_MMA_AVAILABLE
@@ -548,7 +537,7 @@ template <int mmq_x, int mmq_y, int nwarps>
548537static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a (
549538 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
550539
551- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q5_0);
540+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q5_0, mmq_y );
552541 const int * x_qs = (const int *) x;
553542 const float * x_df = (const float *) x_qs + txs.qs ;
554543 const int * y_qs = (const int *) y + 4 ;
@@ -644,7 +633,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
644633 int * x_qs = (int *) x_tile;
645634 half2 * x_dm = (half2 *) (x_qs + 2 *WARP_SIZE);
646635#else
647- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q5_1);
636+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q5_1, mmq_y );
648637 int * x_qs = (int *) x_tile;
649638 half2 * x_dm = (half2 *) (x_qs + txs.qs );
650639#endif // INT8_MMA_AVAILABLE
@@ -711,7 +700,7 @@ template <int mmq_x, int mmq_y, int nwarps>
711700static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a (
712701 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
713702
714- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q5_1);
703+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q5_1, mmq_y );
715704 const int * x_qs = (const int *) x;
716705 const half2 * x_dm = (const half2 *) x_qs + txs.qs ;
717706 const int * y_qs = (const int *) y + 4 ;
@@ -808,7 +797,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
808797 int * x_qs = (int *) x_tile;
809798 float * x_df = (float *) (x_tile + WARP_SIZE);
810799#else
811- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q8_0);
800+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q8_0, mmq_y );
812801 int * x_qs = (int *) x_tile;
813802 float * x_df = (float *) (x_qs + txs.qs );
814803#endif // INT8_MMA_AVAILABLE
@@ -858,7 +847,7 @@ template <int mmq_x, int mmq_y, int nwarps>
858847static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a (
859848 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
860849
861- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q8_0);
850+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q8_0, mmq_y );
862851 const int * x_qs = (const int *) x;
863852 const float * x_df = (const float *) x_qs + txs.qs ;
864853 const int * y_qs = (const int *) y + 4 ;
@@ -954,7 +943,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
954943 int * x_qs = (int *) x_tile;
955944 half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
956945#else
957- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q2_K);
946+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q2_K, mmq_y );
958947 int * x_qs = (int *) x_tile;
959948 half2 * x_dm = (half2 *) (x_qs + txs.qs );
960949#endif // INT8_MMA_AVAILABLE
@@ -1013,7 +1002,7 @@ template <int mmq_x, int mmq_y, int nwarps>
10131002static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a (
10141003 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
10151004
1016- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q2_K);
1005+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q2_K, mmq_y );
10171006 const int * x_qs = (const int *) x;
10181007 const half2 * x_dm = (const half2 *) x_qs + txs.qs ;
10191008 const int * y_qs = (const int *) y + 4 ;
@@ -1135,7 +1124,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
11351124 float * x_df = (float *) (x_qs + WARP_SIZE*2 );
11361125 int * x_sc = (int *) (x_df + WARP_SIZE/QI3_K);
11371126#else
1138- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q3_K);
1127+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q3_K, mmq_y );
11391128 int * x_qs = (int *) x_tile;
11401129 float * x_df = (float *) (x_qs + txs.qs );
11411130 int * x_sc = (int *) (x_df + txs.dm );
@@ -1233,7 +1222,7 @@ template <int mmq_x, int mmq_y, int nwarps>
12331222static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a (
12341223 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
12351224
1236- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q3_K);
1225+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q3_K, mmq_y );
12371226 const int * x_qs = (const int *) x;
12381227 const float * x_df = (const float *) x_qs + txs.qs ;
12391228 const int * x_sc = (const int *) x_df + txs.dm ;
@@ -1361,7 +1350,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
13611350 half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
13621351 int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K);
13631352#else
1364- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q4_K);
1353+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q4_K, mmq_y );
13651354 int * x_qs = (int *) x_tile;
13661355 half2 * x_dm = (half2 *) (x_qs + txs.qs );
13671356 int * x_sc = (int *) (x_dm + txs.dm );
@@ -1437,7 +1426,7 @@ template <int mmq_x, int mmq_y, int nwarps>
14371426static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a (
14381427 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
14391428
1440- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q4_K);
1429+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q4_K, mmq_y );
14411430 const int * x_qs = (const int *) x;
14421431 const half2 * x_dm = (const half2 *) x_qs + txs.qs ;
14431432 const int * x_sc = (const int *) x_dm + txs.dm ;
@@ -1578,7 +1567,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
15781567 half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2 );
15791568 int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K);
15801569#else
1581- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q5_K);
1570+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q5_K, mmq_y );
15821571 int * x_qs = (int *) x_tile;
15831572 half2 * x_dm = (half2 *) (x_qs + txs.qs );
15841573 int * x_sc = (int *) (x_dm + txs.dm );
@@ -1668,7 +1657,7 @@ template <int mmq_x, int mmq_y, int nwarps>
16681657static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a (
16691658 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
16701659
1671- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q5_K);
1660+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q5_K, mmq_y );
16721661 const int * x_qs = (const int *) x;
16731662 const half2 * x_dm = (const half2 *) x_qs + txs.qs ;
16741663 const int * x_sc = (const int *) x_dm + txs.dm ;
@@ -1800,7 +1789,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
18001789 float * x_df = (float *) (x_qs + WARP_SIZE*2 );
18011790 int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
18021791#else
1803- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q6_K);
1792+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q6_K, mmq_y );
18041793 int * x_qs = (int *) x_tile;
18051794 float * x_df = (float *) (x_qs + txs.qs );
18061795 int * x_sc = (int *) (x_df + txs.dm );
@@ -1882,7 +1871,7 @@ template <int mmq_x, int mmq_y, int nwarps>
18821871static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a (
18831872 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
18841873
1885- constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_device<mmq_y> (GGML_TYPE_Q6_K);
1874+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_Q6_K, mmq_y );
18861875 const int * x_qs = (const int *) x;
18871876 const float * x_df = (const float *) x_qs + txs.qs ;
18881877 const int * x_sc = (const int *) x_df + txs.dm ;
@@ -2422,7 +2411,7 @@ struct mmq_args {
24222411
24232412template <ggml_type type>
24242413static int mmq_get_shmem (const int mmq_x, const int mmq_y, const int cc) {
2425- const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_host (type, mmq_y);
2414+ const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (type, mmq_y);
24262415 const int mmq_tile_x_k = mmq_get_mma_tile_x_k (type);
24272416 const int shmem_x = int8_mma_available (cc) ? mmq_y*mmq_tile_x_k*sizeof (int ) : txs.qs *sizeof (int ) + txs.dm *sizeof (half2) + txs.sc *sizeof (int );
24282417 const int shmem_y = mmq_x*sizeof (block_q8_1_mmq);
0 commit comments