@@ -420,9 +420,9 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_clamp;
     cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
               kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
-    cl_kernel kernel_norm;
+    cl_kernel kernel_norm, kernel_norm_mul_add;
     cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
-    cl_kernel kernel_group_norm;
+    cl_kernel kernel_group_norm, kernel_group_norm_mul_add;
     cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
     cl_kernel kernel_soft_max, kernel_soft_max_4;
     cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
@@ -1161,7 +1161,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_norm =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_norm         = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_norm_mul_add = clCreateKernel(backend_ctx->program_norm, "kernel_norm_mul_add", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -1487,7 +1488,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_group_norm =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_group_norm         = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_group_norm_mul_add = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm_mul_add", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -2498,12 +2500,47 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
         if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
             return false;
         }
+    } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
+        const ggml_tensor *norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+        const ggml_tensor *add = cgraph->nodes[node_idx+2];
+        const ggml_tensor *w = mul->src[0] == norm ? mul->src[1] : mul->src[0];
+        const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
+
+        // norm fusion only supports F32
+        if (norm->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
+            return false;
+        }
+
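+        // the fused norm kernel requires the row size (ne00) to be a multiple of 4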
+        if (norm->src[0]->ne[0] % 4 != 0) {
+            return false;
+        }
+
+        if (!ggml_is_contiguous(norm->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
+            return false;
+        }
+    } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_GROUP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
+        const ggml_tensor *gn = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+        const ggml_tensor *add = cgraph->nodes[node_idx+2];
+        const ggml_tensor *w = mul->src[0] == gn ? mul->src[1] : mul->src[0];
+        const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
+
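+        // group_norm fusion only supports F32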
+        if (gn->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        if (!ggml_is_contiguous(gn->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
+            return false;
+        }
     }
 
     return true;
 }
 
 static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
+static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
+static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
 
 static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -2520,6 +2557,16 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
             continue;
         }
 
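+        // fuse NORM/GROUP_NORM + MUL + ADD into a single kernel launch when possible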
+        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
+            ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+            i += 2;
+            continue;
+        }
+        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_GROUP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
+            ggml_opencl_op_group_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+            i += 2;
+            continue;
+        }
         if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
             ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
             i++;
@@ -5039,6 +5086,140 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
 
+static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
+    GGML_ASSERT(norm_tensor && mul_tensor && add_tensor);
+
+    const ggml_tensor * src0 = norm_tensor->src[0];
+    const ggml_tensor * src1 = mul_tensor->src[0] == norm_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
+    const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
+    const ggml_tensor * dst = add_tensor;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offset2 = extra2->offset + src2->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
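+    // epsilon comes from the NORM node's op_params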
+    float eps;
+    memcpy(&eps, norm_tensor->op_params, sizeof(float));
+
+    const int ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
+    const cl_ulong nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
+    const int ne10 = src1->ne[0], ne11 = src1->ne[1], ne12 = src1->ne[2], ne13 = src1->ne[3];
+    const cl_ulong nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];
+    const int ne20 = src2->ne[0], ne21 = src2->ne[1], ne22 = src2->ne[2], ne23 = src2->ne[3];
+    const cl_ulong nb21 = src2->nb[1], nb22 = src2->nb[2], nb23 = src2->nb[3];
+    const cl_ulong nbd1 = dst->nb[1], nbd2 = dst->nb[2], nbd3 = dst->nb[3];
+
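+    // subgroup size used by the fused kernel: 64 on Adreno, 32 on Intel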
+    size_t sgs;
+    if (backend_ctx->gpu_family == ADRENO) sgs = 64;
+    else if (backend_ctx->gpu_family == INTEL) sgs = 32;
+    else GGML_ASSERT(false && "Unsupported GPU");
+
+    cl_kernel kernel = backend_ctx->kernel_norm_mul_add;
+
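+    // grow the workgroup size in powers of two, bounded by ne00/4 and the kernel's max workgroup size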
+    int nth = sgs;
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    while (nth < ne00/4 && nth < max_workgroup_size) nth *= 2;
+    nth = MIN(nth, max_workgroup_size);
+    nth = MIN(nth, ne00/4);
+
+    size_t gws[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t lws[] = {(size_t)nth, 1, 1};
+    size_t num_subgroups = (nth + sgs - 1) / sgs;
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne20));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne21));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne22));
+    CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne23));
+    CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb21));
+    CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb22));
+    CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb23));
+    CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nbd1));
+    CL_CHECK(clSetKernelArg(kernel, 30, sizeof(cl_ulong), &nbd2));
+    CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_ulong), &nbd3));
+    CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float),    &eps));
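+    // arg 33 is local scratch memory: one cl_float2 per subgroup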
+    CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_float2) * num_subgroups, NULL));
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, gws, lws, dst);
+}
+
+static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
+    GGML_ASSERT(gn_tensor && mul_tensor && add_tensor);
+
+    const ggml_tensor * src0 = gn_tensor->src[0];
+    const ggml_tensor * src1 = mul_tensor->src[0] == gn_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
+    const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
+    const ggml_tensor * dst = add_tensor;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offset2 = extra2->offset + src2->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
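+    // op_params holds the group count followed by eps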
+    int groups;
+    float eps;
+    memcpy(&groups, gn_tensor->op_params, sizeof(int));
+    memcpy(&eps, (char *)gn_tensor->op_params + sizeof(int), sizeof(float));
+
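+    // launch one workgroup per group, with the workgroup size capped at the kernel's maximum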
+    cl_kernel kernel = backend_ctx->kernel_group_norm_mul_add;
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    int ne = ggml_nelements(src0);
+    int group_size = ne / groups;
+
+    size_t lws[] = { (size_t)MIN(max_workgroup_size, group_size) };
+    size_t gws[] = { (size_t)groups * lws[0] };
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &group_size));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float),    &eps));
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst);
+}
+
 static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);