Skip to content

Commit 434c6d3

Browse files
committed
device tuning
1 parent 1185533 commit 434c6d3

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1498,7 +1498,7 @@ class vk_perf_logger {
14981498
}
14991499
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
15001500
const uint64_t m = node->src[0]->ne[1];
1501-
const uint64_t n = node->ne[1];
1501+
const uint64_t n = (node->op == GGML_OP_MUL_MAT) ? node->ne[1] : node->ne[2];
15021502
const uint64_t k = node->src[1]->ne[0];
15031503
const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3];
15041504
std::string name = ggml_op_name(node->op);
@@ -6581,23 +6581,36 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
65816581
return true;
65826582
}
65836583

6584+
// Quantization overhead is not worth it for small k
65846585
switch (device->vendor_id) {
65856586
case VK_VENDOR_ID_NVIDIA:
6587+
if (k <= 4096) {
6588+
return false;
6589+
}
6590+
65866591
switch (src0_type) {
6587-
case GGML_TYPE_Q8_0:
65886592
case GGML_TYPE_MXFP4:
6593+
case GGML_TYPE_Q8_0:
65896594
return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
65906595
default:
65916596
return true;
65926597
}
65936598
case VK_VENDOR_ID_AMD:
6599+
if (k < 2048) {
6600+
return false;
6601+
}
6602+
65946603
switch (src0_type) {
65956604
case GGML_TYPE_Q8_0:
65966605
return device->architecture == vk_device_architecture::AMD_GCN;
65976606
default:
65986607
return true;
65996608
}
66006609
case VK_VENDOR_ID_INTEL:
6610+
if (k < 2048) {
6611+
return false;
6612+
}
6613+
66016614
switch (src0_type) {
66026615
// From tests on A770 Linux, may need more tuning
66036616
case GGML_TYPE_Q4_0:
@@ -6611,7 +6624,6 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
66116624
}
66126625

66136626
GGML_UNUSED(m);
6614-
GGML_UNUSED(k);
66156627
}
66166628

66176629
static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@@ -7288,7 +7300,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
72887300

72897301
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
72907302

7291-
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ne21 >= 8;
7303+
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
72927304

72937305
// Check for mmq first
72947306
vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
@@ -7559,7 +7571,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
75597571
const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
75607572

75617573
const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
7562-
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type);
7574+
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type);
75637575

75647576
vk_pipeline to_fp16_vk_0 = nullptr;
75657577
vk_pipeline to_fp16_vk_1 = nullptr;

0 commit comments

Comments (0)