
Commit db6dae7

CUDA: optimize MMQ int8 tensor core performance
1 parent 52fc870 commit db6dae7

File tree

2 files changed: +888 -542 lines changed

ggml-cuda/mma.cuh

Lines changed: 56 additions & 0 deletions
@@ -20,6 +20,20 @@ struct mma_int_A_I16K4 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };

 struct mma_int_A_I16K8 {
@@ -42,6 +56,20 @@ struct mma_int_A_I16K8 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };

 struct mma_int_B_J8K4 {
@@ -64,6 +92,20 @@ struct mma_int_B_J8K4 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
+            : "+r"(x[0])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };

 struct mma_int_B_J8K8 {
@@ -86,6 +128,20 @@ struct mma_int_B_J8K8 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };

 struct mma_int_C_I16J8 {
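
For context, the INT8_MMA_AVAILABLE branches above all follow the same pattern: each lane of the warp computes the address of one 16-byte row of the tile in shared memory, and a single ldmatrix instruction distributes the data across the warp's registers in the fragment layout that the tensor core mma instructions expect. Below is a minimal standalone sketch of that pattern for the mma_int_A_I16K4 shape (I = 16 rows, K = 4 ints per row, ldmatrix.x2). It is not part of the commit: the kernel name, host harness and tile contents are invented for illustration; only the addressing and the inline asm mirror the load() added above.

// ldmatrix_demo.cu -- illustrative sketch, not from the commit (needs sm_75+ for the PTX path).
#include <cstdio>
#include <cuda_runtime.h>

static __global__ void ldmatrix_demo(const int * __restrict__ src, int * __restrict__ dst) {
    constexpr int I = 16;  // rows, as in mma_int_A_I16K4
    constexpr int K = 4;   // ints per row (16 bytes)

    __shared__ int tile[I*K];

    // Stage the tile: ldmatrix reads from shared memory.
    for (int i = threadIdx.x; i < I*K; i += blockDim.x) {
        tile[i] = src[i];
    }
    __syncthreads();

    int x[2];  // each lane ends up holding one int from each of the two 8x8 b16 sub-tiles

#if __CUDA_ARCH__ >= 750
    // Same addressing as mma_int_A_I16K4::load(): lane % I selects the row,
    // lane / I selects the half-row. Only lanes 0-15 actually supply addresses
    // to the .x2 variant; the addresses computed by the other lanes are ignored.
    const int * xs = tile + (threadIdx.x % I)*K + (threadIdx.x / I)*(K/2);
    asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
        : "=r"(x[0]), "=r"(x[1])
        : "l"(xs));
#else
    // No tensor cores: the commit instead falls back to per-element loads via
    // get_i()/get_k(); this sketch simply zeroes the registers.
    x[0] = 0;
    x[1] = 0;
#endif

    // Write the fragment out so the host can inspect which elements each lane got.
    dst[2*threadIdx.x + 0] = x[0];
    dst[2*threadIdx.x + 1] = x[1];
}

int main() {
    int src[64], dst[64];
    for (int i = 0; i < 64; ++i) {
        src[i] = i;
    }

    int *d_src, *d_dst;
    cudaMalloc((void **) &d_src, sizeof(src));
    cudaMalloc((void **) &d_dst, sizeof(dst));
    cudaMemcpy(d_src, src, sizeof(src), cudaMemcpyHostToDevice);

    ldmatrix_demo<<<1, 32>>>(d_src, d_dst);
    cudaMemcpy(dst, d_dst, sizeof(dst), cudaMemcpyDeviceToHost);

    for (int t = 0; t < 32; ++t) {
        printf("lane %2d: x[0]=%2d x[1]=%2d\n", t, dst[2*t], dst[2*t + 1]);
    }

    cudaFree(d_src);
    cudaFree(d_dst);
    return 0;
}

Note that for the B fragments the commit deliberately keeps the ldmatrix path disabled ("&& false" in the #if condition): as the in-code comment says, loading those tiles as plain 4-byte values turned out to be faster.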
