Skip to content

Commit 2548b0d

Browse files
committed
add q4_k and q5_k mmvq
1 parent bcd47ab commit 2548b0d

File tree

5 files changed

+56
-7
lines changed

5 files changed

+56
-7
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3426,6 +3426,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
34263426

34273427
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_q8_1_f32", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
34283428
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_q8_1_f32", arr_dmmv_q3_k_q8_1_f32_len[reduc], arr_dmmv_q3_k_q8_1_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
3429+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
3430+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
34293431
}
34303432
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
34313433
}
@@ -5165,6 +5167,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
51655167
case GGML_TYPE_MXFP4:
51665168
case GGML_TYPE_Q2_K:
51675169
case GGML_TYPE_Q3_K:
5170+
case GGML_TYPE_Q4_K:
5171+
case GGML_TYPE_Q5_K:
51685172
break;
51695173
default:
51705174
return nullptr;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
4040

4141
uint ibi = first_row*p.ncols;
4242
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
43-
const uint a_block_idx = (ibi + col)/32 + a_offset;
43+
const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset;
4444
ibi += p.ncols;
4545

4646
temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx);
@@ -52,7 +52,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
5252
const uint tid = gl_LocalInvocationID.x;
5353

5454
get_offsets(a_offset, b_offset, d_offset);
55-
a_offset /= QUANT_K;
55+
a_offset /= QUANT_K_Q8_1;
5656
b_offset /= QUANT_K_Q8_1;
5757

5858
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,53 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
243243

244244
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
245245
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
246-
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
247-
return FLOAT_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
246+
i32vec2 repack2(uint ib, uint iqs) {
247+
const uint ib_k = ib / 8;
248+
const uint iqs_k = (ib % 8) * 8 + iqs;
249+
250+
const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
251+
const uint qs_shift = ((iqs_k % 16) / 8) * 4;
252+
253+
#if defined(DATA_A_Q4_K)
254+
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
255+
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
256+
257+
return i32vec2(vals0, vals1);
258+
#else // defined(DATA_A_Q5_K)
259+
const uint qh_idx = iqs;
260+
const uint qh_shift = iqs_k / 8;
261+
262+
return i32vec2(((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F) |
263+
(((data_a_packed32[ib_k].qh[qh_idx ] >> qh_shift) & 0x01010101) << 4),
264+
((data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F) |
265+
(((data_a_packed32[ib_k].qh[qh_idx + 1] >> qh_shift) & 0x01010101) << 4));
266+
#endif
267+
}
268+
269+
vec2 get_dm_scale(uint ib, uint iqs) {
270+
const uint ib_k = ib / 8;
271+
const uint iqs_k = (ib % 8) * 8 + iqs;
272+
const uint is = iqs_k / 8;
273+
u8vec2 scale_dm;
274+
if (is < 4) {
275+
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
276+
} else {
277+
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
278+
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
279+
}
280+
281+
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
282+
}
283+
284+
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
285+
int32_t q_sum = 0;
286+
287+
const i32vec2 qs_a = repack2(ib_a, iqs * 2);
288+
const vec2 dm_scale = get_dm_scale(ib_a, iqs * 2);
289+
290+
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
291+
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
292+
293+
return FLOAT_TYPE(float(cache_b_ds.x) * float(dm_scale.x) * float(q_sum) - float(dm_scale.y) * float(cache_b_ds.y / 4));
248294
}
249295
#endif

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,6 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
322322
(((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
323323
#endif
324324

325-
326325
if (iqs == 0) {
327326
// Scale index
328327
const uint is = iqs_k / 8;

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ void process_shaders() {
664664

665665
// mul mat vec with integer dot product
666666
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
667-
if (is_legacy_quant(tname) || tname == "mxfp4" || tname == "q2_k" || tname == "q3_k") {
667+
if (is_legacy_quant(tname) || tname == "mxfp4" || tname == "q2_k" || tname == "q3_k" || tname == "q4_k" || tname == "q5_k") {
668668
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
669669
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
670670
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
@@ -1040,7 +1040,7 @@ void write_output_files() {
10401040

10411041
for (const std::string& btype : btypes) {
10421042
for (const auto& tname : type_names) {
1043-
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && tname != "q2_k" && tname != "q3_k") {
1043+
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && tname != "q2_k" && tname != "q3_k" && tname != "q4_k" && tname != "q5_k") {
10441044
continue;
10451045
}
10461046
hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";

0 commit comments

Comments
 (0)