Skip to content

Commit 480d6d6

Browse files
committed
iq1_m: adapt to CUDA refactoring
1 parent 3d9c21f commit 480d6d6

File tree

3 files changed

+105
-0
lines changed

3 files changed

+105
-0
lines changed

ggml-cuda/convert.cu

+46
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,42 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
501501

502502
}
503503

504+
// Bit-reinterpretation helper: the iq1_m per-super-block fp16 scale is stored
// scattered as the top nibble of each of the four 16-bit words in x[i].scales;
// after reassembling those nibbles into u16, f16 reads them as a half.
typedef union {
    half f16;
    uint16_t u16;
} iq1m_scale_t;
508+
509+
// Dequantizes one IQ1_M super-block (QK_K = 256 values) per thread block.
// Expected launch config (see dequantize_row_iq1_m_cuda below): gridDim.x =
// number of super-blocks, blockDim.x = 32; each thread writes 8 consecutive
// outputs (32 threads * 8 = 256).
template<typename dst_t>
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;               // super-block index
    const block_iq1_m * x = (const block_iq1_m *) vx;

    const int tid = threadIdx.x;
#if QK_K == 256
    const int il = tid/8; // 0...3  group-of-8 within a 32-value sub-block
    const int ib = tid%8; // 0...7  32-value sub-block within the super-block
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint16_t * sc = (const uint16_t *)x[i].scales;
    // Reassemble the fp16 super-block scale from the top nibble of each of the
    // four 16-bit scale words.
    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    // 3-bit per-16-value sub-scale, packed four-per-word in the low 12 bits of sc[].
    const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
    const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
    // Bit 3 of the qh nibble for this group selects the sign of the shift
    // applied to every grid value.
    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
    // 11-bit grid index: 8 bits from qs, low 3 bits of the qh nibble. Each
    // grid entry packs 8 values, 4 bits apiece, split into two 4-value words.
    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
    grid32[0] &= 0x0f0f0f0f;
    for (int j = 0; j < 8; ++j) {
        y[j] = d * (q[j] + delta);
    }
#else
    assert(false);
#endif

}
538+
539+
504540
template<typename dst_t>
505541
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
506542

@@ -658,6 +694,12 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k,
658694
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
659695
}
660696

697+
// Host-side launcher: dequantizes k IQ1_M-quantized values into y on the given
// stream. One thread block of 32 threads handles one super-block of QK_K
// values; k is expected to be a multiple of QK_K.
template<typename dst_t>
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_superblocks = k / QK_K;
    dequantize_block_iq1_m<<<num_superblocks, 32, 0, stream>>>(vx, y);
}
702+
661703
template<typename dst_t>
662704
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
663705
const int nb = (k + QK_K - 1) / QK_K;
@@ -724,6 +766,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
724766
return dequantize_row_iq3_xxs_cuda;
725767
case GGML_TYPE_IQ1_S:
726768
return dequantize_row_iq1_s_cuda;
769+
case GGML_TYPE_IQ1_M:
770+
return dequantize_row_iq1_m_cuda;
727771
case GGML_TYPE_IQ4_NL:
728772
return dequantize_row_iq4_nl_cuda;
729773
case GGML_TYPE_IQ4_XS:
@@ -769,6 +813,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
769813
return dequantize_row_iq3_xxs_cuda;
770814
case GGML_TYPE_IQ1_S:
771815
return dequantize_row_iq1_s_cuda;
816+
case GGML_TYPE_IQ1_M:
817+
return dequantize_row_iq1_m_cuda;
772818
case GGML_TYPE_IQ4_NL:
773819
return dequantize_row_iq4_nl_cuda;
774820
case GGML_TYPE_IQ4_XS:

ggml-cuda/mmvq.cu

+11
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,14 @@ static void mul_mat_vec_iq1_s_q8_1_cuda(
282282
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
283283
}
284284

285+
// Host-side launcher for the iq1_m x q8_1 matrix-vector product; forwards to
// the generic mul_mat_vec_q_cuda wrapper instantiated with the iq1_m dot
// product kernel (vec_dot_iq1_m_q8_1 in vecdotq.cuh).
// NOTE(review): reuses QI1_S as the quants-per-int template argument —
// presumably no separate QI1_M constant exists yet; confirm this matches the
// iq1_m block layout.
static void mul_mat_vec_iq1_m_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
292+
285293
static void mul_mat_vec_iq4_nl_q8_1_cuda(
286294
const void * vx, const void * vy, float * dst,
287295
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
@@ -373,6 +381,9 @@ void ggml_cuda_op_mul_mat_vec_q(
373381
case GGML_TYPE_IQ1_S:
374382
mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
375383
break;
384+
case GGML_TYPE_IQ1_M:
385+
mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
386+
break;
376387
case GGML_TYPE_IQ4_NL:
377388
mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
378389
break;

ggml-cuda/vecdotq.cuh

+48
Original file line numberDiff line numberDiff line change
@@ -1164,6 +1164,54 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
11641164
#endif
11651165
}
11661166

1167+
// Dot product of one 32-value IQ1_M sub-block with the matching q8_1 block,
// used as the per-sub-block kernel by mul_mat_vec_q_cuda. iqs selects the
// 32-value sub-block (ib32) within the super-block.
static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if QK_K == 256
    const block_iq1_m * bq1 = (const block_iq1_m *) vbq;

    const int ib32 = iqs;
    // Separate accumulators per 16-value half, since each half has its own
    // 3-bit sub-scale (applied at the end).
    int sumi[2] = {0, 0};
    float sumf[2] = {0.f, 0.f};
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    const int * q8 = (const int *)bq8_1[ib32].qs;
    for (int l = 0; l < 4; ++l) {
        // 11-bit grid index: 8 bits from qs, low 3 bits of the qh nibble for
        // this group of 8 (two nibbles per qh byte).
        const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 7) << 8)));
        int grid0 = grid[0] & 0x0f0f0f0f;
        int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
        sumi[l/2] = __dp4a(q8[2*l+1], grid1, __dp4a(q8[2*l+0], grid0, sumi[l/2]));
        // Bit 3 of the same qh nibble selects the sign of the per-group shift.
        const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA;
        const int sumy = __dp4a(q8[2*l+1], 0x01010101, __dp4a(q8[2*l+0], 0x01010101, 0));
        sumf[l/2] += delta*sumy;
    }
#else
    const int8_t * q8 = bq8_1[ib32].qs;
    for (int l = 0; l < 4; ++l) {
        // Fix: use the iq1_m nibble layout of qh, matching the dp4a path above
        // and the delta computation below. The previous code read
        // bq1->qh[ib32] >> 3*l, which is the iq1_s (3 bits per entry) layout
        // and indexed the wrong grid entries on non-dp4a devices.
        const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 7) << 8)));
        int sumy = 0;
        for (int j = 0; j < 4; ++j) {
            sumi[l/2] += q8[j] * (grid[j] & 0xf) + q8[j+4] * (grid[j] >> 4);
            sumy += q8[j] + q8[j+4];
        }
        const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA;
        sumf[l/2] += delta*sumy;
        q8 += 8;
    }
#endif
    // Reassemble the fp16 super-block scale from the top nibble of each of the
    // four 16-bit scale words (same packing as dequantize_block_iq1_m).
    typedef union {
        half f16;
        uint16_t u16;
    } iq1m_scale_t;
    iq1m_scale_t scale;
    const uint16_t * sc = (const uint16_t *)bq1->scales;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    const float d = (float)scale.f16 * __low2float (bq8_1[ib32].ds);
    // Apply the 3-bit sub-scale (2*s+1) of each 16-value half and the combined
    // fp16 * q8_1 scale.
    return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));
#else
    assert(false);
    return 0.f;
#endif
}
1214+
11671215
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
11681216
static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values,
11691217
int & val1, int & val2) {

0 commit comments

Comments
 (0)