Skip to content

Commit 94db33f

Browse files
committed
add q6_k mmvq
1 parent 2548b0d commit 94db33f

File tree

3 files changed

+48
-2
lines changed

3 files changed

+48
-2
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3428,6 +3428,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
34283428
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_q8_1_f32", arr_dmmv_q3_k_q8_1_f32_len[reduc], arr_dmmv_q3_k_q8_1_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
34293429
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
34303430
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
3431+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", 4, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
34313432
}
34323433
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
34333434
}
@@ -5169,6 +5170,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
51695170
case GGML_TYPE_Q3_K:
51705171
case GGML_TYPE_Q4_K:
51715172
case GGML_TYPE_Q5_K:
5173+
case GGML_TYPE_Q6_K:
51725174
break;
51735175
default:
51745176
return nullptr;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,47 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
293293
return FLOAT_TYPE(float(cache_b_ds.x) * float(dm_scale.x) * float(q_sum) - float(dm_scale.y) * float(cache_b_ds.y / 4));
294294
}
295295
#endif
296+
297+
#if defined(DATA_A_Q6_K)
298+
// 2-byte loads for Q6_K blocks (210 bytes)
299+
// Assembles eight signed 8-bit values from the `ql` and `qh` arrays of
// data_a_packed16[ib / 8] and returns them as two packed 32-bit integers
// (four values per component).
//
// Each value is built as (4 bits taken from `ql`) | (2 bits taken from `qh`,
// moved to bit positions 4-5), and then has 32 subtracted from it.
//
// ib:  index that is split into ib / 8 (element of data_a_packed16) and
//      ib % 8 (offset folded into iqs_k below).
// iqs: offset added to (ib % 8) * 8; also used directly in qh_idx.
//      NOTE(review): the only caller visible here passes iqs * 2 - confirm
//      the expected range against mul_mat_vecq.comp.
i32vec2 repack2(uint ib, uint iqs) {
300+
const uint ib_k = ib / 8;
301+
const uint iqs_k = (ib % 8) * 8 + iqs;
302+
303+
// ql_shift evaluates to 0 or 4, selecting which nibble of each `ql` byte survives the 0x0F0F mask.
const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
304+
const uint ql_shift = ((iqs_k % 32) / 16) * 4;
305+
306+
// qh_shift evaluates to 0, 2, 4 or 6, selecting which bit pair of each `qh` byte survives the 0x0303 mask.
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
307+
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
308+
309+
// Four consecutive 16-bit elements are read from `ql` (ql_idx * 2 + 0..3) and from `qh` (qh_idx * 2 + 0..3).
// Each 16-bit element is handled as two bytes at once: mask, combine, unpack8 to i8vec2, subtract 32.
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
310+
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
311+
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
312+
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
313+
const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) |
314+
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
315+
const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) |
316+
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
317+
318+
// .x packs vals00 and vals01, .y packs vals10 and vals11.
return i32vec2(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)),
319+
pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)));
320+
}
321+
322+
// Returns the product of data_a[ib / 8].d and one entry of data_a[ib / 8].scales,
// both converted to float.
//
// ib, iqs: combined into iqs_k = (ib % 8) * 8 + iqs, the same way repack2 does;
//          the scale entry read is scales[iqs_k / 4].
float get_d_scale(uint ib, uint iqs) {
323+
const uint ib_k = ib / 8;
324+
const uint iqs_k = (ib % 8) * 8 + iqs;
325+
return float(data_a[ib_k].d) * float(data_a[ib_k].scales[iqs_k / 4]);
326+
}
327+
328+
// Q6_K variant of mmvq_dot_product (compiled only when DATA_A_Q6_K is defined).
//
// Repacks eight values for (ib_a, iqs * 2) via repack2, accumulates two
// dotPacked4x8EXT results against cache_b_qs[0] and cache_b_qs[1], and returns
// cache_b_ds.x * d_scale * q_sum converted to FLOAT_TYPE.
//
// NOTE(review): cache_b_qs and cache_b_ds are declared outside this hunk;
// presumably they hold the cached q8_1 operand's packed values and scale -
// confirm against mul_mat_vecq.comp.
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
329+
int32_t q_sum = 0;
330+
331+
// Both helpers receive iqs * 2, so they derive the same iqs_k.
const i32vec2 qs_a = repack2(ib_a, iqs * 2);
332+
const float d_scale = get_d_scale(ib_a, iqs * 2);
333+
334+
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
335+
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
336+
337+
// Unlike the variant just above this hunk, only cache_b_ds.x is used here; cache_b_ds.y is not read.
return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum));
338+
}
339+
#endif

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ void process_shaders() {
664664

665665
// mul mat vec with integer dot product
666666
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
667-
if (is_legacy_quant(tname) || tname == "mxfp4" || tname == "q2_k" || tname == "q3_k" || tname == "q4_k" || tname == "q5_k") {
667+
if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) {
668668
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
669669
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
670670
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
@@ -1040,7 +1040,7 @@ void write_output_files() {
10401040

10411041
for (const std::string& btype : btypes) {
10421042
for (const auto& tname : type_names) {
1043-
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && tname != "q2_k" && tname != "q3_k" && tname != "q4_k" && tname != "q5_k") {
1043+
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) {
10441044
continue;
10451045
}
10461046
hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";

0 commit comments

Comments
 (0)