@@ -369,6 +369,10 @@ struct vk_device_struct {
369369 bool subgroup_add;
370370 bool subgroup_shuffle;
371371
372+ bool atomic_float_add;
373+ bool add_rms_fusion;
374+ uint32_t atomic_binding_alignment;
375+
372376 bool integer_dot_product;
373377
374378 bool subgroup_size_control;
@@ -448,6 +452,8 @@ struct vk_device_struct {
448452 vk_pipeline pipeline_mul_norepeat[2][2][2];
449453 vk_pipeline pipeline_div[2][2][2];
450454 vk_pipeline pipeline_div_norepeat[2][2][2];
455+ vk_pipeline pipeline_add_rms[2][2][2];
456+ vk_pipeline pipeline_add_rms_norepeat[2][2][2];
451457
452458 vk_pipeline pipeline_add_id_f32;
453459
@@ -1144,6 +1150,12 @@ class vk_perf_logger {
11441150 timings[name].push_back(time);
11451151 return;
11461152 }
1153+ if (node->op == GGML_OP_RMS_NORM) {
1154+ std::string name = ggml_op_name(node->op);
1155+ name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
1156+ timings[name].push_back(time);
1157+ return;
1158+ }
11471159 timings[ggml_op_name(node->op)].push_back(time);
11481160 }
11491161 private:
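
The new branch gives RMS_NORM shape-keyed timing buckets, mirroring the branch just above it: rather than lumping every norm under one op name, each distinct tensor shape accumulates its own samples. A standalone sketch of the key construction (illustrative values; time_ns is a stand-in for the measured duration):

    std::map<std::string, std::vector<uint64_t>> timings;
    const int64_t ne[4] = {4096, 1, 1, 1};
    std::string name = "RMS_NORM";
    name += "(" + std::to_string(ne[0]) + "," + std::to_string(ne[1]) + "," +
            std::to_string(ne[2]) + "," + std::to_string(ne[3]) + ")";
    timings[name].push_back(time_ns);   // bucket key: "RMS_NORM(4096,1,1,1)"

This makes it easy to compare the fused single-row norms introduced below against larger-batch norms in the same run.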
@@ -1158,10 +1170,13 @@ struct ggml_backend_vk_context {
11581170
11591171 size_t semaphore_idx, event_idx;
11601172 ggml_vk_garbage_collector gc;
1161- size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
1162- vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
1173+    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_atomic_add, prealloc_size_atomic_add_offset;
1174+    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_atomic_add;
11631175 vk::Fence fence, almost_ready_fence;
11641176 bool almost_ready_fence_pending {};
1177+ // Set before op_add and unset after op_rms_norm to indicate that the add should
1178+ // use atomics to accumulate the square of the vector components
1179+ bool do_add_rms_atomic;
11651180
11661181 vk_buffer buffer_pool[MAX_VK_BUFFERS];
11671182
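
The two new size fields act as a simple bump allocator over prealloc_atomic_add: during the dryrun pass prealloc_size_atomic_add grows by one aligned slot per fused add+rms_norm pair, and during execution prealloc_size_atomic_add_offset advances one slot each time an rms_norm consumes its partial sum (see the hunks at ggml_vk_rms_norm and ggml_vk_build_graph below). A minimal standalone sketch, with 16 as an assumed example value for device->atomic_binding_alignment:

    const uint32_t alignment = 16;  // stand-in for device->atomic_binding_alignment
    size_t size = 0, offset = 0;

    // dryrun: two fused pairs discovered while walking the graph
    size += alignment;              // pair 0 owns bytes [0, 16)
    size += alignment;              // pair 1 owns bytes [16, 32)

    // execution: each rms_norm reads its slot, then moves past it
    offset += alignment;            // the next fused add binds at offset 16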
@@ -2924,8 +2939,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
29242939
29252940 ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
29262941 ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2927-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
2928-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
2942+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
2943+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
29292944 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
29302945 ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
29312946
@@ -2995,20 +3010,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
29953010 };
29963011
29973012 bool rte = device->float_controls_rte_fp16;
2998- #define CREATE_BINARY(name, namemod, spec) \
3013+#define CREATE_BINARY(name, namemod, spec, bindings) \
29993014 for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
30003015 ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
30013016 #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
3002- "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
3003-
3004- CREATE_BINARY(add, , {0})
3005- CREATE_BINARY(add, _norepeat, {1})
3006- CREATE_BINARY(sub, , {0})
3007- CREATE_BINARY(sub, _norepeat, {1})
3008- CREATE_BINARY(mul, , {0})
3009- CREATE_BINARY(mul, _norepeat, {1})
3010- CREATE_BINARY(div, , {0})
3011- CREATE_BINARY(div, _norepeat, {1})
3017+ "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
3018+
3019+ CREATE_BINARY(add, , {0}, 4)
3020+ CREATE_BINARY(add, _norepeat, {1}, 4)
3021+ CREATE_BINARY(sub, , {0}, 3)
3022+ CREATE_BINARY(sub, _norepeat, {1}, 3)
3023+ CREATE_BINARY(mul, , {0}, 3)
3024+ CREATE_BINARY(mul, _norepeat, {1}, 3)
3025+ CREATE_BINARY(div, , {0}, 3)
3026+ CREATE_BINARY(div, _norepeat, {1}, 3)
3027+ CREATE_BINARY(add_rms, , {0}, 4)
3028+ CREATE_BINARY(add_rms, _norepeat, {1}, 4)
30123029#undef CREATE_BINARY
30133030
30143031 ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
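
Since CREATE_BINARY now threads a binding count through to pipeline creation, one expansion written out may help. CREATE_BINARY(add, _norepeat, {1}, 4) becomes, approximately:

    for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1})
        ggml_vk_create_pipeline(device, device->pipeline_add_norepeat[s0][s1][d],
            "add" + get_suffix(s0, s1, d) + "_norepeat",
            add_len[s0][s1][d][rte], add_data[s0][s1][d][rte],
            "main", 4, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);

Only add and add_rms request the fourth binding (the atomic scratch buffer); sub, mul, and div keep the original three.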
@@ -3281,6 +3298,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
32813298 device->coopmat_support = false;
32823299 device->integer_dot_product = false;
32833300 bool bfloat16_support = false;
3301+ bool atomic_float_support = false;
32843302
32853303 for (const auto& properties : ext_props) {
32863304 if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -3320,6 +3338,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
33203338 !getenv("GGML_VK_DISABLE_BFLOAT16")) {
33213339 bfloat16_support = true;
33223340#endif
3341+ } else if (strcmp("VK_EXT_shader_atomic_float", properties.extensionName) == 0) {
3342+ atomic_float_support = true;
33233343 }
33243344 }
33253345
@@ -3536,6 +3556,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
35363556 device_extensions.push_back("VK_KHR_shader_integer_dot_product");
35373557 }
35383558
3559+ VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {};
3560+ atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;
3561+ if (atomic_float_support) {
3562+ last_struct->pNext = (VkBaseOutStructure *)&atomic_float_features;
3563+ last_struct = (VkBaseOutStructure *)&atomic_float_features;
3564+ device_extensions.push_back("VK_EXT_shader_atomic_float");
3565+ }
3566+
35393567 vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
35403568
35413569 device->fp16 = device->fp16 && vk12_features.shaderFloat16;
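
The feature query follows the standard Vulkan features2 pattern: chain the EXT struct into pNext before vkGetPhysicalDeviceFeatures2, then read the capability bit the driver filled in. A condensed standalone version (the real code appends to an existing last_struct chain shared with other optional features):

    VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {};
    atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;

    VkPhysicalDeviceFeatures2 features2 {};
    features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
    features2.pNext = &atomic_float_features;

    vkGetPhysicalDeviceFeatures2(physical_device, &features2);

    // VK_EXT_shader_atomic_float advertises several float atomics; only
    // buffer float32 add is needed for this fusion.
    bool ok = atomic_float_features.shaderBufferFloat32AtomicAdd == VK_TRUE;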
@@ -3547,6 +3575,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
35473575#endif
35483576
35493577 device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
3578+ device->atomic_float_add = atomic_float_features.shaderBufferFloat32AtomicAdd;
35503579
35513580 if (device->subgroup_size_control) {
35523581 device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
@@ -3861,6 +3890,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
38613890
38623891 device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
38633892
3893+ device->add_rms_fusion = !device->disable_fusion &&
3894+ device->subgroup_add &&
3895+ device->atomic_float_add;
3896+ device->atomic_binding_alignment =
3897+ std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
3898+
38643899 return device;
38653900 }
38663901
@@ -6892,8 +6927,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
68926927 switch (op) {
68936928 case GGML_OP_ADD:
68946929 {
6895- auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
6896- return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
6930+ if (ctx->do_add_rms_atomic) {
6931+ auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
6932+ return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
6933+ } else {
6934+ auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
6935+ return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
6936+ }
68976937 }
68986938 case GGML_OP_SUB:
68996939 {
@@ -7494,7 +7534,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
74947534 }
74957535 } break;
74967536 case GGML_OP_RMS_NORM:
7497- elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
7537+ if (ctx->do_add_rms_atomic) {
7538+ // Run one element per thread, 128 threads per workgroup
7539+ elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
7540+ } else {
7541+ elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
7542+ }
74987543 break;
74997544
75007545 case GGML_OP_SUM:
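
The rms_norm pipelines above are created with wg_denoms of {1, 1, 1}, so in the fused path the x element count is the workgroup count directly: CEIL_DIV(ne00, 128) workgroups of 128 threads, one thread per element of the single row. Two worked cases:

    #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
    static_assert(CEIL_DIV(4096, 128) == 32, "exact fit: 32 x 128 = 4096 threads");
    static_assert(CEIL_DIV(4100, 128) == 33, "a 33rd workgroup covers the 4-element tail");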
@@ -7642,7 +7687,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
76427687 }
76437688 }
76447689
7645- if (op == GGML_OP_GLU) {
7690+ if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
7691+ vk_buffer d_A = ctx->prealloc_atomic_add ? ctx->prealloc_atomic_add : d_X;
7692+ size_t a_buf_offset = ctx->prealloc_atomic_add ? ctx->prealloc_size_atomic_add_offset : 0;
7693+ ggml_vk_sync_buffers(subctx);
7694+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
7695+ { vk_subbuffer{ d_X, x_buf_offset, x_sz },
7696+ vk_subbuffer{ d_Y, y_buf_offset, y_sz },
7697+ vk_subbuffer{ d_D, d_buf_offset, d_sz },
7698+ vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE },
7699+ }, pc, elements);
7700+ } else if (op == GGML_OP_GLU) {
76467701 // Empty src1 is possible in glu, but the shader needs a buffer
76477702 vk_subbuffer subbuf_y;
76487703 if (use_src1) {
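
ADD and RMS_NORM now always dispatch with a fourth buffer bound, matching the 4-binding layout their pipelines request. When no fusion is active there is no scratch buffer, so d_X is rebound at offset 0 as a placeholder; presumably the shader leaves it untouched when the push-constant flag (the do_add_rms_atomic value written in ggml_vk_add below) is 0. Annotated:

    // binding 3 must be backed by something valid in every dispatch:
    //   fused   -> the atomic scratch buffer, at this pair's aligned slot
    //   unfused -> the X buffer again, as a do-nothing placeholder
    vk_buffer d_A = ctx->prealloc_atomic_add ? ctx->prealloc_atomic_add : d_X;
    size_t a_buf_offset = ctx->prealloc_atomic_add ? ctx->prealloc_size_atomic_add_offset : 0;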
@@ -7750,7 +7805,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
77507805 (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
77517806 (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
77527807 0,
7753-        0.0f, 0.0f, 0,
7808+        0.0f, 0.0f, ctx->do_add_rms_atomic,
77547809 }, dryrun);
77557810}
77567811
@@ -8213,8 +8268,13 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
82138268 (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
82148269 (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
82158270 0,
8216-        op_params[0], 0.0f, 0,
8271+        op_params[0], 0.0f, ctx->do_add_rms_atomic,
82178272 }, dryrun);
8273+
8274+ if (ctx->do_add_rms_atomic) {
8275+ ctx->prealloc_size_atomic_add_offset += ctx->device->atomic_binding_alignment;
8276+ ctx->do_add_rms_atomic = false;
8277+ }
82188278}
82198279
82208280static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9492,6 +9552,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
94929552 }
94939553 ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
94949554 }
9555+ if (ctx->prealloc_atomic_add == nullptr || (ctx->prealloc_size_atomic_add > 0 && ctx->prealloc_atomic_add->size < ctx->prealloc_size_atomic_add)) {
9556+        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(atomic_add_size: " << ctx->prealloc_size_atomic_add << ")");
9557+ // Resize buffer
9558+ if (ctx->prealloc_atomic_add != nullptr) {
9559+ ggml_vk_destroy_buffer(ctx->prealloc_atomic_add);
9560+ }
9561+ ctx->prealloc_atomic_add = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_atomic_add);
9562+ }
94959563}
94969564
94979565static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
@@ -9547,10 +9615,21 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
95479615 return false;
95489616 }
95499617 break;
9618+ case GGML_OP_ADD:
9619+ if (node_idx + 1 < cgraph->n_nodes &&
9620+ cgraph->nodes[node_idx + 1]->op == GGML_OP_RMS_NORM &&
9621+ cgraph->nodes[node_idx + 1]->src[0] == cgraph->nodes[node_idx] &&
9622+ ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 &&
9623+ ctx->device->add_rms_fusion) {
9624+ if (dryrun) {
9625+ ctx->prealloc_size_atomic_add += ctx->device->atomic_binding_alignment;
9626+ }
9627+ ctx->do_add_rms_atomic = true;
9628+ }
9629+ break;
95509630 case GGML_OP_REPEAT:
95519631 case GGML_OP_REPEAT_BACK:
95529632 case GGML_OP_GET_ROWS:
9553- case GGML_OP_ADD:
95549633 case GGML_OP_ADD_ID:
95559634 case GGML_OP_ACC:
95569635 case GGML_OP_SUB:
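
Pulled out of the switch, the test that arms the fusion reads as follows (a standalone restatement, not a helper that exists in the file):

    static bool arm_add_rms_fusion(const ggml_cgraph * g, int i, const vk_device & dev) {
        return dev->add_rms_fusion &&                      // atomics + subgroup add available, fusion not disabled
               i + 1 < g->n_nodes &&
               g->nodes[i + 1]->op == GGML_OP_RMS_NORM &&  // next node is a norm...
               g->nodes[i + 1]->src[0] == g->nodes[i] &&   // ...that consumes this add's result
               ggml_nrows(g->nodes[i + 1]) == 1;           // single-row norms only
    }

During the dryrun the same site also reserves one aligned slot in the scratch buffer, so sizing and execution stay in lockstep.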
@@ -9667,6 +9746,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
96679746 // do the only thing needed for the dryrun.
96689747 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
96699748 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
9749+ if (node->op == GGML_OP_RMS_NORM) {
9750+ ctx->do_add_rms_atomic = false;
9751+ }
96709752 return false;
96719753 }
96729754 default:
@@ -10581,6 +10663,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1058110663 vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
1058210664 }
1058310665
10666+ ctx->prealloc_size_atomic_add = 0;
10667+ ctx->prealloc_size_atomic_add_offset = 0;
10668+ ctx->do_add_rms_atomic = false;
10669+
1058410670 uint64_t total_mat_mul_bytes = 0;
1058510671 for (int i = 0; i < cgraph->n_nodes; i++) {
1058610672 if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
@@ -10641,6 +10727,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1064110727 compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
1064210728 }
1064310729
10730+ if (ctx->prealloc_size_atomic_add) {
10731+ if (ctx->compute_ctx.expired()) {
10732+ compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
10733+ ctx->compute_ctx = compute_ctx;
10734+ ggml_vk_ctx_begin(ctx->device, compute_ctx);
10735+ } else {
10736+ compute_ctx = ctx->compute_ctx.lock();
10737+ }
10738+ // initialize atomic sums to zero.
10739+ ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_atomic_add, 0, 0, ctx->prealloc_size_atomic_add);
10740+ }
10741+
1064410742 // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
1064510743 // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
1064610744 // (and scaled down based on model size, so smaller models submit earlier).
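
The memset above clears the scratch buffer once per graph, before any node executes; combined with the forward-only offset, each fused pair can assume its slot starts at zero and is never reused. A host-side picture of one slot's lifetime (illustrative; row stands for the add's output, n and eps for the norm's parameters, and the exact shader arithmetic is an assumption):

    float slot = 0.0f;                            // memset at graph start
    for (float x : row) slot += x * x;            // GGML_OP_ADD: atomicAdd of squared outputs
    float inv_rms = 1.0f / sqrtf(slot / n + eps); // GGML_OP_RMS_NORM consumes the total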