Skip to content

Commit d6f012f

Browse files
committed
enable MUL_MAT_ID mmvq support
1 parent 9c36294 commit d6f012f

File tree

2 files changed

+108
-20
lines changed

2 files changed

+108
-20
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 106 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -572,6 +572,7 @@ struct vk_device_struct {
572572
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
573573

574574
vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
575+
vk_pipeline pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_COUNT];
575576

576577
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
577578
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
@@ -3463,6 +3464,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
34633464
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
34643465
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 5, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
34653466

3467+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3468+
if (device->integer_dot_product) {
3469+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", mul_mat_vec_id_q4_0_q8_1_f32_len, mul_mat_vec_id_q4_0_q8_1_f32_data, "main", 5, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {device->subgroup_size, 1*rm_stdq_int}, 1, true);
3470+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", mul_mat_vec_id_q4_1_q8_1_f32_len, mul_mat_vec_id_q4_1_q8_1_f32_data, "main", 5, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {device->subgroup_size, 1*rm_stdq_int}, 1, true);
3471+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", mul_mat_vec_id_q5_0_q8_1_f32_len, mul_mat_vec_id_q5_0_q8_1_f32_data, "main", 5, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {device->subgroup_size, 1*rm_stdq_int}, 1, true);
3472+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", mul_mat_vec_id_q5_1_q8_1_f32_len, mul_mat_vec_id_q5_1_q8_1_f32_data, "main", 5, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {device->subgroup_size, 1*rm_stdq_int}, 1, true);
3473+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", mul_mat_vec_id_q8_0_q8_1_f32_len, mul_mat_vec_id_q8_0_q8_1_f32_data, "main", 5, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {device->subgroup_size, 1*rm_stdq_int}, 1, true);
3474+
3475+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", mul_mat_vec_id_mxfp4_q8_1_f32_len, mul_mat_vec_id_mxfp4_q8_1_f32_data, "main", 5, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {device->subgroup_size, 2*rm_stdq_int}, 1, true);
3476+
3477+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", mul_mat_vec_id_q2_k_q8_1_f32_len, mul_mat_vec_id_q2_k_q8_1_f32_data, "main", 5, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {device->subgroup_size, 2*rm_kq_int}, 1, true);
3478+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", mul_mat_vec_id_q3_k_q8_1_f32_len, mul_mat_vec_id_q3_k_q8_1_f32_data, "main", 5, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {device->subgroup_size, 1*rm_kq_int}, 1, true);
3479+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", mul_mat_vec_id_q4_k_q8_1_f32_len, mul_mat_vec_id_q4_k_q8_1_f32_data, "main", 5, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {device->subgroup_size, 1*rm_kq_int}, 1, true);
3480+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", mul_mat_vec_id_q5_k_q8_1_f32_len, mul_mat_vec_id_q5_k_q8_1_f32_data, "main", 5, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {device->subgroup_size, 1*rm_kq_int}, 1, true);
3481+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", mul_mat_vec_id_q6_k_q8_1_f32_len, mul_mat_vec_id_q6_k_q8_1_f32_data, "main", 5, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {device->subgroup_size, 1*rm_kq_int}, 1, true);
3482+
}
3483+
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
3484+
34663485
// dequant shaders
34673486
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
34683487
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -5317,6 +5336,28 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
53175336

53185337
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
53195338
VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()");
5339+
5340+
if (b_type == GGML_TYPE_Q8_1) {
5341+
switch (a_type) {
5342+
case GGML_TYPE_Q4_0:
5343+
case GGML_TYPE_Q4_1:
5344+
case GGML_TYPE_Q5_0:
5345+
case GGML_TYPE_Q5_1:
5346+
case GGML_TYPE_Q8_0:
5347+
case GGML_TYPE_MXFP4:
5348+
case GGML_TYPE_Q2_K:
5349+
case GGML_TYPE_Q3_K:
5350+
case GGML_TYPE_Q4_K:
5351+
case GGML_TYPE_Q5_K:
5352+
case GGML_TYPE_Q6_K:
5353+
break;
5354+
default:
5355+
return nullptr;
5356+
}
5357+
5358+
return ctx->device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[a_type];
5359+
}
5360+
53205361
GGML_ASSERT(b_type == GGML_TYPE_F32);
53215362

53225363
switch (a_type) {
@@ -6497,6 +6538,11 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
64976538
return false;
64986539
}
64996540

6541+
// General issue with q3_k and q6_k
6542+
if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {
6543+
return false;
6544+
}
6545+
65006546
// MMVQ is generally good for batches
65016547
if (n > 1) {
65026548
return true;
@@ -6506,6 +6552,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
65066552
case VK_VENDOR_ID_NVIDIA:
65076553
switch (src0_type) {
65086554
case GGML_TYPE_Q8_0:
6555+
case GGML_TYPE_MXFP4:
65096556
return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
65106557
default:
65116558
return true;
@@ -7479,12 +7526,41 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
74797526
const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
74807527

74817528
const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
7529+
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type);
7530+
7531+
vk_pipeline to_fp16_vk_0 = nullptr;
7532+
vk_pipeline to_fp16_vk_1 = nullptr;
7533+
if (x_non_contig) {
7534+
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
7535+
}
7536+
if (y_non_contig) {
7537+
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
7538+
} else {
7539+
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
7540+
}
7541+
7542+
// Check for mmq first
7543+
vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, GGML_TYPE_Q8_1) : nullptr;
7544+
vk_pipeline to_q8_1 = nullptr;
7545+
7546+
if (dmmv == nullptr) {
7547+
// Fall back to f16 dequant mul mat
7548+
dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type);
7549+
quantize_y = false;
7550+
}
7551+
7552+
if (quantize_y) {
7553+
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
7554+
}
74827555

74837556
const bool qx_needs_dequant = x_non_contig;
7484-
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
7557+
const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
74857558

74867559
// Not implemented
74877560
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
7561+
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
7562+
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
7563+
GGML_ASSERT(dmmv != nullptr);
74887564

74897565
const uint64_t x_ne = ne01 * ne00;
74907566
const uint64_t y_ne = ne11 * ne10;
@@ -7493,28 +7569,16 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
74937569
const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
74947570
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
74957571
const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
7496-
const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
7572+
const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
74977573
const uint64_t ids_sz = nbi2;
74987574
const uint64_t d_sz = sizeof(float) * d_ne;
74997575

7500-
vk_pipeline to_fp16_vk_0 = nullptr;
7501-
vk_pipeline to_fp16_vk_1 = nullptr;
7502-
if (x_non_contig) {
7503-
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
7504-
}
7505-
if (y_non_contig) {
7506-
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
7507-
} else {
7508-
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
7509-
}
7510-
vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type);
7511-
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
7512-
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
7513-
GGML_ASSERT(dmmv != nullptr);
7514-
75157576
if (dryrun) {
75167577
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
7517-
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
7578+
uint64_t y_sz_upd = y_sz * ne12 * ne13;
7579+
if (quantize_y) {
7580+
y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
7581+
}
75187582
if (
75197583
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
75207584
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
@@ -7523,7 +7587,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
75237587
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
75247588
ctx->prealloc_size_x = x_sz_upd;
75257589
}
7526-
if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
7590+
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
75277591
ctx->prealloc_size_y = y_sz_upd;
75287592
}
75297593

@@ -7534,6 +7598,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
75347598
if (qy_needs_dequant) {
75357599
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
75367600
}
7601+
if (quantize_y) {
7602+
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
7603+
}
75377604
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
75387605
return;
75397606
}
@@ -7581,6 +7648,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
75817648
}
75827649
if (qy_needs_dequant) {
75837650
d_Y = ctx->prealloc_y;
7651+
} else if (quantize_y) {
7652+
d_Y = ctx->prealloc_y;
7653+
GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
75847654
} else {
75857655
d_Y = d_Qy;
75867656
y_buf_offset = qy_buf_offset;
@@ -7609,6 +7679,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
76097679
ctx->prealloc_y_last_tensor_used = src1;
76107680
}
76117681
}
7682+
if (quantize_y) {
7683+
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
7684+
ctx->prealloc_y_last_tensor_used != src1) {
7685+
if (ctx->prealloc_y_need_sync) {
7686+
ggml_vk_sync_buffers(ctx, subctx);
7687+
}
7688+
ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
7689+
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
7690+
ctx->prealloc_y_last_tensor_used = src1;
7691+
}
7692+
}
76127693

76137694
uint32_t stride_batch_y = ne10*ne11;
76147695

@@ -7649,6 +7730,11 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
76497730
}
76507731
}
76517732

7733+
uint32_t y_sz_total = y_sz * ne12 * ne13;
7734+
if (quantize_y) {
7735+
y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
7736+
}
7737+
76527738
// compute
76537739
const vk_mat_vec_id_push_constants pc = {
76547740
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
@@ -7661,7 +7747,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
76617747
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
76627748
{
76637749
vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 },
7664-
vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 },
7750+
vk_subbuffer{ d_Y, y_buf_offset, y_sz_total },
76657751
vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23},
76667752
vk_subbuffer{ d_B, b_buf_offset, b_sz },
76677753
vk_subbuffer{ d_ids, ids_buf_offset, ids_sz },

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -668,6 +668,8 @@ void process_shaders() {
668668
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
669669
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
670670
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
671+
672+
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
671673
}
672674
#endif
673675

0 commit comments

Comments
 (0)