
Commit 6883c26

vulkan: Optimize mul_mat_vec p021 and nc shaders.
These shaders are used in attention calculations, and when the KV cache grows large they start to dominate the run time. For the nc shader (which is called with a large 'k' dimension), use unrolling and vector loads. For the p021 shader (which is called with a large 'm' and small 'k' dimension), take advantage of grouped query attention to reuse loads from the A matrix for the whole group, and reduce the number of workgroups (tiny dispatches have too much overhead). Using subgroupAdd in the p021 shader also helps; use it conditionally, since not all devices support subgroup arithmetic in compute shaders.
1 parent 0071605 commit 6883c26
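
The key idea for the p021 path is easiest to see in isolation before reading the diffs. Below is a minimal host-side sketch (not part of the commit) of the grouped-query-attention bookkeeping that ggml-vulkan.cpp gains below: derive the Q-heads-per-KV-head ratio from the batch dimensions (ne02 and ne12 in ggml's naming), fall back to 1 unless it is an exact ratio in [1, 8], and shrink the z dimension of the dispatch by that factor. The helper name pick_gqa_ratio and the head counts are illustrative only.

// Sketch: how the GQA ratio shrinks the dispatch.
// pick_gqa_ratio is a hypothetical helper; the clamping rules mirror the diff below.
#include <cstdint>
#include <cstdio>

static uint32_t pick_gqa_ratio(int64_t ne02 /* K/V heads */, int64_t ne12 /* Q heads */) {
    uint32_t gqa_ratio = (uint32_t)(ne12 / ne02);
    // Only exploit GQA when ne12 is an exact multiple of ne02 and the ratio is at most 8.
    if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
        gqa_ratio = 1;
    }
    return gqa_ratio;
}

int main() {
    const int64_t ne02 = 8;   // K/V heads (illustrative)
    const int64_t ne12 = 32;  // Q heads (illustrative)
    const uint32_t gqa_ratio = pick_gqa_ratio(ne02, ne12);

    uint32_t workgroups_z = (uint32_t)ne12;
    if (gqa_ratio > 1) {
        workgroups_z /= gqa_ratio;  // one workgroup now covers a whole Q-head group
    }
    printf("gqa_ratio = %u, z workgroups = %u (was %u)\n", gqa_ratio, workgroups_z, (uint32_t)ne12);
    return 0;
}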

File tree: 4 files changed, +192 −40 lines changed


ggml/src/ggml-vulkan/ggml-vulkan.cpp

+31 −5
@@ -149,6 +149,7 @@ class vk_perf_logger;
 static void ggml_vk_destroy_buffer(vk_buffer& buf);
 
 static constexpr uint32_t mul_mat_vec_max_cols = 8;
+static constexpr uint32_t p021_max_gqa_ratio = 8;
 
 enum vk_device_architecture {
     OTHER,
@@ -231,6 +232,7 @@ struct vk_device_struct {
     bool uma;
     bool prefer_host_memory;
     bool float_controls_rte_fp16;
+    bool subgroup_add;
 
     bool subgroup_size_control;
     uint32_t subgroup_min_size;
@@ -277,7 +279,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
     vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
 
-    vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
+    vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
     vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
@@ -2265,7 +2267,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
+    for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
+        if (device->subgroup_add && device->subgroup_require_full_support) {
+            ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
+        } else {
+            ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
+        }
+    }
     ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -2471,13 +2479,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
     vk::PhysicalDeviceDriverProperties driver_props;
     vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
     vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
+    vk::PhysicalDeviceVulkan11Properties vk11_props;
     vk::PhysicalDeviceVulkan12Properties vk12_props;
     vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
 
     props2.pNext = &props3;
     props3.pNext = &subgroup_props;
     subgroup_props.pNext = &driver_props;
-    driver_props.pNext = &vk12_props;
+    driver_props.pNext = &vk11_props;
+    vk11_props.pNext = &vk12_props;
 
     VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
 
@@ -2541,6 +2551,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
     }
     device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
 
+    device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+                           (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
+
     const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
 
     device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -4627,9 +4640,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t d_sz = sizeof(float) * d_ne;
 
+    // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02;
+    if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
+        gqa_ratio = 1;
+    }
+
     if (dryrun) {
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
+        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
         return;
     }
 
@@ -4653,8 +4672,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
 
     // compute
     const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
+
+    uint32_t workgroups_z = (uint32_t)ne12;
+    // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
+    if (gqa_ratio > 1) {
+        workgroups_z /= gqa_ratio;
+    }
+
     ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
+    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z });
 }
 
 static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
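
As a side note on the capability check above: device->subgroup_add only requires that VkPhysicalDeviceVulkan11Properties reports arithmetic subgroup operations for the compute stage. A standalone sketch of that probe is below; it is not part of the commit, and the instance setup (API version 1.2, no layers or extensions) is the minimum assumed to make the query legal.

// Sketch: probe arithmetic subgroup support the same way the backend change does.
#include <vulkan/vulkan.hpp>
#include <cstdio>

int main() {
    vk::ApplicationInfo app_info("subgroup-probe", 1, nullptr, 0, VK_API_VERSION_1_2);
    vk::InstanceCreateInfo instance_ci({}, &app_info);
    vk::Instance instance = vk::createInstance(instance_ci);

    for (vk::PhysicalDevice dev : instance.enumeratePhysicalDevices()) {
        // Chain the Vulkan 1.1 properties into the query, as the diff above does.
        vk::PhysicalDeviceProperties2 props2;
        vk::PhysicalDeviceVulkan11Properties vk11_props;
        props2.pNext = &vk11_props;
        dev.getProperties2(&props2);

        const bool subgroup_add =
            (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
            (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);

        printf("%s: subgroupAdd in compute %s\n",
               props2.properties.deviceName.data(),
               subgroup_add ? "supported" : "not supported");
    }

    instance.destroy();
    return 0;
}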

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp

+55 −11
@@ -12,6 +12,9 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
 
+layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
+layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
+
 layout (push_constant) uniform parameter
 {
     uint ncols_x;
@@ -37,25 +40,66 @@ void main() {
 
     const uint idst = channel*nrows_dst + row_dst;
 
-    tmp[tid] = 0.0f;
+    FLOAT_TYPE temp = 0.0f;
 
-    for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
-        const uint col_x = col_x0 + tid;
+    // Detect alignment for vector loads
+    bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0;
 
-        if (col_x >= p.ncols_x) {
-            break;
-        }
+    for (uint col_x0 = 0; col_x0 < p.ncols_x;) {
+
+        // Unroll 2x and do vec4 loads if aligned
+        const uint unroll_count = 2;
+        if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) {
+            [[unroll]] for (uint i = 0; i < unroll_count; ++i) {
+                const uint col_x = col_x0 + 4*tid;
+
+                const uint row_y = col_x;
+
+                const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+                const uint iy = channel*nrows_y + row_y;
+
+                const vec4 av4 = vec4(data_a_v4[ix / 4]);
+                const vec4 bv4 = vec4(data_b_v4[iy / 4]);
+
+                temp += dot(av4, bv4);
+
+                col_x0 += 4*BLOCK_SIZE;
+            }
+        // do vec4 loads if aligned
+        } else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
+            const uint col_x = col_x0 + 4*tid;
 
-        const uint row_y = col_x;
+            const uint row_y = col_x;
 
-        const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
-        const uint iy = channel*nrows_y + row_y;
+            const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+            const uint iy = channel*nrows_y + row_y;
 
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+            const vec4 av4 = vec4(data_a_v4[ix / 4]);
+            const vec4 bv4 = vec4(data_b_v4[iy / 4]);
 
-        tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
+            temp += dot(av4, bv4);
+
+            col_x0 += 4*BLOCK_SIZE;
+        } else {
+            const uint col_x = col_x0 + tid;
+            if (col_x >= p.ncols_x) {
+                break;
+            }
+
+            const uint row_y = col_x;
+
+            const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+            const uint iy = channel*nrows_y + row_y;
+
+            const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+
+            temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp);
+            col_x0 += BLOCK_SIZE;
+        }
     }
 
+    tmp[tid] = temp;
+
     // sum up partial sums and write back result
     barrier();
     [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
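
The restructured nc loop above may be easier to follow as plain scalar code. The sketch below models one invocation's walk over ncols_x with the same three cases (2x-unrolled vec4 path, single vec4 path, scalar tail). BLOCK_SIZE, the flat vectors, and the dot_row helper are simplified stand-ins for the shader's workgroup, strided buffers, and shared-memory reduction; none of these names come from the commit.

// Sketch: host-side model of the widened-stride loop used by mul_mat_vec_nc.comp.
#include <cstdio>
#include <vector>

constexpr unsigned BLOCK_SIZE = 32;

float dot_row(const std::vector<float>& a, const std::vector<float>& b,
              unsigned ncols, unsigned tid) {
    const bool is_aligned = (ncols % 4) == 0;  // the shader also checks the strides
    float temp = 0.0f;
    for (unsigned col0 = 0; col0 < ncols;) {
        if (is_aligned && col0 + 2 * 4 * BLOCK_SIZE <= ncols) {
            // unrolled 2x, 4 elements per "vec4 load"
            for (unsigned i = 0; i < 2; ++i) {
                const unsigned col = col0 + 4 * tid;
                for (unsigned j = 0; j < 4; ++j) temp += a[col + j] * b[col + j];
                col0 += 4 * BLOCK_SIZE;
            }
        } else if (is_aligned && col0 + 4 * BLOCK_SIZE <= ncols) {
            // single vec4 pass
            const unsigned col = col0 + 4 * tid;
            for (unsigned j = 0; j < 4; ++j) temp += a[col + j] * b[col + j];
            col0 += 4 * BLOCK_SIZE;
        } else {
            // scalar tail
            const unsigned col = col0 + tid;
            if (col < ncols) temp += a[col] * b[col];
            col0 += BLOCK_SIZE;
        }
    }
    return temp;  // one partial sum; the shader reduces these across the workgroup
}

int main() {
    const unsigned ncols = 512;
    std::vector<float> a(ncols, 1.0f), b(ncols, 2.0f);
    float total = 0.0f;
    for (unsigned tid = 0; tid < BLOCK_SIZE; ++tid) total += dot_row(a, b, ncols, tid);
    printf("dot = %f (expect %f)\n", total, 2.0f * ncols);
    return 0;
}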

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp

+103 −22
@@ -2,16 +2,25 @@
 
 #extension GL_EXT_control_flow_attributes : enable
 #extension GL_EXT_shader_16bit_storage : require
+#if USE_SUBGROUP_ADD
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#endif
 
-#define BLOCK_SIZE 32
 #define FLOAT_TYPE float
 
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
 
+layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
+layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
+
+layout(constant_id = 0) const int BLOCK_SIZE = 32;
+// gqa_ratio is in the range [1,8]
+layout(constant_id = 1) const uint gqa_ratio = 1;
+
 layout (push_constant) uniform parameter
 {
     uint ncols_x;
@@ -22,52 +31,124 @@ layout (push_constant) uniform parameter
     uint d_offset;
 } p;
 
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
+#if !USE_SUBGROUP_ADD
+shared FLOAT_TYPE tmp[8][BLOCK_SIZE];
+#endif
 
 void main() {
     const uint tid = gl_LocalInvocationID.x;
     const uint row_x = gl_GlobalInvocationID.y;
-    const uint channel = gl_GlobalInvocationID.z;
-    const uint channel_x = channel / (p.nchannels_y / p.nchannels_x);
+
+    uint channel, channel_x;
+
+    // When gqa_ratio > 1, each invocation does multiple rows.
+    // The row in the A matrix is starting from channel / gqa_ratio and the
+    // rows in the B matrix are [channel, channel+gqa_ratio).
+    // When gqa_ratio is 1, each invocation does one row.
+    if (gqa_ratio > 1) {
+        channel_x = gl_GlobalInvocationID.z;
+        channel = channel_x * gqa_ratio;
+    } else {
+        channel = gl_GlobalInvocationID.z;
+        channel_x = channel / (p.nchannels_y / p.nchannels_x);
+    }
 
     const uint nrows_y = p.ncols_x;
     const uint nrows_dst = p.nrows_x;
     const uint row_dst = row_x;
 
-    tmp[tid] = FLOAT_TYPE(0.0f);
+    FLOAT_TYPE temp[8];
+    [[unroll]] for (uint i = 0; i < 8; ++i) {
+        temp[i] = FLOAT_TYPE(0.0f);
+    }
+
+    // Detect alignment for vector loads
+    bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0;
 
     for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
-        const uint col_x = col_x0 + tid;
 
-        if (col_x >= p.ncols_x) {
-            break;
-        }
+        // Use vec4 loads if aligned
+        if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
 
-        // x is transposed and permuted
-        const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+            uint col_x = col_x0 + 4*tid;
+            const uint row_y = col_x;
 
-        const uint row_y = col_x;
+            // x is transposed and permuted
+            const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
+            const vec4 av4 = vec4(data_a_v4[ix / 4]);
 
-        // y is not transposed but permuted
-        const uint iy = channel*nrows_y + row_y;
+            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+                // y is not transposed but permuted
+                const uint iy = (channel + c)*nrows_y + row_y;
 
-        tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
-    }
+                vec4 bv4 = data_b_v4[iy / 4];
+                temp[c] += dot(av4, bv4);
+            }
+
+            col_x0 += 3*BLOCK_SIZE;
+        } else {
+            const uint col_x = col_x0 + tid;
+
+            if (col_x >= p.ncols_x) {
+                break;
+            }
 
-    // dst is not transposed and not permuted
-    const uint idst = channel*nrows_dst + row_dst;
+            // x is transposed and permuted
+            const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
+            const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
 
+            const uint row_y = col_x;
+
+            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+                // y is not transposed but permuted
+                const uint iy = (channel + c)*nrows_y + row_y;
+
+                temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]);
+            }
+        }
+    }
+
+#if USE_SUBGROUP_ADD
+    // reduce vec4 at a time
+    vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]);
+    t = subgroupAdd(t);
+    temp[0] = t[0];
+    temp[1] = t[1];
+    temp[2] = t[2];
+    temp[3] = t[3];
+    if (gqa_ratio > 4) {
+        t = vec4(temp[4], temp[5], temp[6], temp[7]);
+        t = subgroupAdd(t);
+        temp[4] = t[0];
+        temp[5] = t[1];
+        temp[6] = t[2];
+        temp[7] = t[3];
+    }
+#else
+    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+        tmp[c][tid] = temp[c];
+    }
     // sum up partial sums and write back result
     barrier();
     [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
         if (tid < s) {
-            tmp[tid] += tmp[tid + s];
+            [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+                temp[c] += tmp[c][tid + s];
+                tmp[c][tid] = temp[c];
+            }
         }
         barrier();
     }
+    [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+        temp[c] = tmp[c][tid];
+    }
+#endif
 
     if (tid == 0) {
-        dst[idst] = tmp[0];
+        [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
+            // dst is not transposed and not permuted
+            const uint idst = (channel + c)*nrows_dst + row_dst;
+            dst[idst] = temp[c];
+        }
     }
 }
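
The A-matrix reuse that the commit message describes comes down to the inner gqa_ratio loops above: one value loaded from A feeds every Q head in the group. A scalar model of just that part is below; the head count, sizes, and flat row-major layout of b are made up for illustration and are not part of the commit.

// Sketch: reuse a single A load across all Q heads that share a K/V head.
#include <cstdio>
#include <vector>

int main() {
    const unsigned ncols = 16, gqa_ratio = 4;
    std::vector<float> a(ncols, 1.0f);              // one K/V row
    std::vector<float> b(gqa_ratio * ncols, 0.5f);  // gqa_ratio Q rows, row-major
    float temp[8] = {};                             // up to 8 heads, as in the shader

    for (unsigned col = 0; col < ncols; ++col) {
        const float xi = a[col];                    // single A load...
        for (unsigned c = 0; c < gqa_ratio; ++c) {  // ...reused for the whole group
            temp[c] += xi * b[c * ncols + col];
        }
    }
    for (unsigned c = 0; c < gqa_ratio; ++c) {
        printf("head %u: %f\n", c, temp[c]);
    }
    return 0;
}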

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

+3 −2
@@ -426,8 +426,9 @@ void process_shaders() {
         }
     }
 
-    string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
+    string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
+    string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
 
     // Norms
     string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
