@@ -1497,7 +1497,7 @@ class vk_perf_logger {
14971497 }
14981498 if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
14991499 const uint64_t m = node->src[0]->ne[1];
1500- const uint64_t n = node->ne[1];
1500+ const uint64_t n = (node->op == GGML_OP_MUL_MAT) ? node->ne[1] : node->ne[2];
15011501 const uint64_t k = node->src[1]->ne[0];
15021502 const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3];
15031503 std::string name = ggml_op_name(node->op);
@@ -6572,23 +6572,36 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
65726572 return true;
65736573 }
65746574
6575+ // Quantization overhead is not worth it for small k
65756576 switch (device->vendor_id) {
65766577 case VK_VENDOR_ID_NVIDIA:
6578+ if (k <= 4096) {
6579+ return false;
6580+ }
6581+
65776582 switch (src0_type) {
6578- case GGML_TYPE_Q8_0:
65796583 case GGML_TYPE_MXFP4:
6584+ case GGML_TYPE_Q8_0:
65806585 return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
65816586 default:
65826587 return true;
65836588 }
65846589 case VK_VENDOR_ID_AMD:
6590+ if (k < 2048) {
6591+ return false;
6592+ }
6593+
65856594 switch (src0_type) {
65866595 case GGML_TYPE_Q8_0:
65876596 return device->architecture == vk_device_architecture::AMD_GCN;
65886597 default:
65896598 return true;
65906599 }
65916600 case VK_VENDOR_ID_INTEL:
6601+ if (k < 2048) {
6602+ return false;
6603+ }
6604+
65926605 switch (src0_type) {
65936606 // From tests on A770 Linux, may need more tuning
65946607 case GGML_TYPE_Q4_0:
@@ -6602,7 +6615,6 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
66026615 }
66036616
66046617 GGML_UNUSED(m);
6605- GGML_UNUSED(k);
66066618}
66076619
66086620static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
@@ -7280,7 +7292,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
72807292
72817293 const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
72827294
7283- bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ne21 >= 8;
7295+ bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
72847296
72857297 // Check for mmq first
72867298 vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
@@ -7550,7 +7562,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
75507562 const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
75517563
75527564 const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
7553- bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type);
7565+ bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type);
75547566
75557567 vk_pipeline to_fp16_vk_0 = nullptr;
75567568 vk_pipeline to_fp16_vk_1 = nullptr;
0 commit comments