@@ -149,6 +149,7 @@ class vk_perf_logger;
149
149
static void ggml_vk_destroy_buffer(vk_buffer& buf);
150
150
151
151
static constexpr uint32_t mul_mat_vec_max_cols = 8;
152
+ static constexpr uint32_t p021_max_gqa_ratio = 8; // Size of the pipeline_mul_mat_vec_p021_f16_f32 array; entry i is created with specialization constant i + 1 and selected via gqa_ratio - 1.
152
153
153
154
enum vk_device_architecture {
154
155
OTHER,
@@ -231,6 +232,7 @@ struct vk_device_struct {
231
232
bool uma;
232
233
bool prefer_host_memory;
233
234
bool float_controls_rte_fp16;
235
+ bool subgroup_add;
234
236
235
237
bool subgroup_size_control;
236
238
uint32_t subgroup_min_size;
@@ -277,7 +279,7 @@ struct vk_device_struct {
277
279
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
278
280
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
279
281
280
- vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
282
+ vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio] ;
281
283
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
282
284
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
283
285
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
@@ -2265,7 +2267,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
2265
2267
2266
2268
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
2267
2269
2268
- ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 , " mul_mat_vec_p021_f16_f32" , mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, " main" , 3 , 6 * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
2270
+ for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
2271
+ if (device->subgroup_add && device->subgroup_require_full_support) {
2272
+ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
2273
+ } else {
2274
+ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
2275
+ }
2276
+ }
2269
2277
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
2270
2278
2271
2279
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -2471,13 +2479,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
2471
2479
vk::PhysicalDeviceDriverProperties driver_props;
2472
2480
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
2473
2481
vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2482
+ vk::PhysicalDeviceVulkan11Properties vk11_props;
2474
2483
vk::PhysicalDeviceVulkan12Properties vk12_props;
2475
2484
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
2476
2485
2477
2486
props2.pNext = &props3;
2478
2487
props3.pNext = &subgroup_props;
2479
2488
subgroup_props.pNext = &driver_props;
2480
- driver_props.pNext = &vk12_props;
2489
+ driver_props.pNext = &vk11_props;
2490
+ vk11_props.pNext = &vk12_props;
2481
2491
2482
2492
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
2483
2493
@@ -2541,6 +2551,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
2541
2551
}
2542
2552
device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
2543
2553
2554
+ device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
2555
+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
2556
+
2544
2557
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
2545
2558
2546
2559
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -4627,9 +4640,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
4627
4640
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
4628
4641
const uint64_t d_sz = sizeof(float) * d_ne;
4629
4642
4643
+ // With grouped query attention there are > 1 Q matrices per K, V matrix. gqa_ratio - 1 indexes the pipeline array, so ratios that are zero, exceed p021_max_gqa_ratio, or do not divide ne12 evenly fall back to 1.
4644
+ uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02;
4645
+ if (gqa_ratio > p021_max_gqa_ratio || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
4646
+ gqa_ratio = 1;
4647
+ }
4648
+
4630
4649
if (dryrun) {
4631
4650
// Request descriptor sets
4632
- ggml_pipeline_request_descriptor_sets (ctx->device , ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 , 1 );
4651
+ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1] , 1);
4633
4652
return;
4634
4653
}
4635
4654
@@ -4653,8 +4672,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
4653
4672
4654
4673
// compute
4655
4674
const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
4675
+
4676
+ uint32_t workgroups_z = (uint32_t)ne12;
4677
+ // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
4678
+ if (gqa_ratio > 1) {
4679
+ workgroups_z /= gqa_ratio;
4680
+ }
4681
+
4656
4682
ggml_vk_sync_buffers(subctx);
4657
- ggml_vk_dispatch_pipeline (ctx, subctx, ctx->device ->pipeline_mul_mat_vec_p021_f16_f32 , { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof (uint32_t ), &pc, { 1 , (uint32_t )ne01, ( uint32_t )ne12 });
4683
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1] , { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z });
4658
4684
}
4659
4685
4660
4686
static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
0 commit comments