From 94361fd5a8fa0f0ee6d747a707b2a29efad6116e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sat, 14 Jun 2025 13:11:31 +0200 Subject: [PATCH 1/6] implement GLU for split up/gate --- ggml/include/ggml.h | 23 ++++++ ggml/src/ggml-cpu/ops.cpp | 150 ++++++++++++++++++++++++++++-------- ggml/src/ggml-cuda/unary.cu | 63 +++++++++------ ggml/src/ggml.c | 61 +++++++++++++-- src/llama-graph.cpp | 33 +++++--- 5 files changed, 258 insertions(+), 72 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 40ff1c187a831..3991d974f4fab 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1131,6 +1131,29 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // A: n columns, r rows, + // B: n columns, r rows, + GGML_API struct ggml_tensor * ggml_glu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_glu_op op); + + GGML_API struct ggml_tensor * ggml_reglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_geglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_swiglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + // normalize along rows GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 8c88bf2e7b880..5543addcbdc00 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3201,14 +3201,24 @@ static void ggml_compute_forward_reglu_f32( ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; GGML_ASSERT(ggml_is_contiguous_1(src0)); GGML_ASSERT(ggml_is_contiguous_1(dst)); + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0] / 2; + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; const int nr = ggml_nrows(src0); GGML_ASSERT(dst->ne[0] == nc); @@ -3224,10 +3234,15 @@ static void ggml_compute_forward_reglu_f32( const int ir1 = MIN(ir0 + dr, nr); for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_reglu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), - (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -3245,14 +3260,24 @@ static void ggml_compute_forward_reglu_f16( ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; GGML_ASSERT(ggml_is_contiguous_1(src0)); GGML_ASSERT(ggml_is_contiguous_1(dst)); + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0] / 2; + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; const int nr = ggml_nrows(src0); GGML_ASSERT(dst->ne[0] == nc); @@ -3268,10 +3293,15 @@ static void ggml_compute_forward_reglu_f16( const int ir1 = MIN(ir0 + dr, nr); for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_reglu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -3314,14 +3344,24 @@ static void ggml_compute_forward_geglu_f32( ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; GGML_ASSERT(ggml_is_contiguous_1(src0)); GGML_ASSERT(ggml_is_contiguous_1(dst)); + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0] / 2; + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; const int nr = ggml_nrows(src0); GGML_ASSERT(dst->ne[0] == nc); @@ -3337,10 +3377,15 @@ static void ggml_compute_forward_geglu_f32( const int ir1 = MIN(ir0 + dr, nr); for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_geglu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), - (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -3358,14 +3403,24 @@ static void ggml_compute_forward_geglu_f16( ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; GGML_ASSERT(ggml_is_contiguous_1(src0)); GGML_ASSERT(ggml_is_contiguous_1(dst)); + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0] / 2; + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; const int nr = ggml_nrows(src0); GGML_ASSERT(dst->ne[0] == nc); @@ -3381,10 +3436,15 @@ static void ggml_compute_forward_geglu_f16( const int ir1 = MIN(ir0 + dr, nr); for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_geglu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -3427,14 +3487,24 @@ static void ggml_compute_forward_swiglu_f32( ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; GGML_ASSERT(ggml_is_contiguous_1(src0)); GGML_ASSERT(ggml_is_contiguous_1(dst)); + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0] / 2; + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; const int nr = ggml_nrows(src0); GGML_ASSERT(dst->ne[0] == nc); @@ -3450,10 +3520,15 @@ static void ggml_compute_forward_swiglu_f32( const int ir1 = MIN(ir0 + dr, nr); for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_swiglu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), - (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -3471,14 +3546,24 @@ static void ggml_compute_forward_swiglu_f16( ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; GGML_ASSERT(ggml_is_contiguous_1(src0)); GGML_ASSERT(ggml_is_contiguous_1(dst)); + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0] / 2; + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; const int nr = ggml_nrows(src0); GGML_ASSERT(dst->ne[0] == nc); @@ -3494,10 +3579,15 @@ static void ggml_compute_forward_swiglu_f16( const int ir1 = MIN(ir0 + dr, nr); for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_swiglu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc)); + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); #ifndef NDEBUG for (int k = 0; k < nc; k++) { diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index c991c1d700174..ba3c0f13762b0 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -199,30 +199,36 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { /* gated ops */ template -static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o) { +static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) { const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; if (i >= k) { return; } - // perform base op on half of the row and multiply with gate in other half - const int64_t j = (i / n) * o + (i % n); - dst[i] = (T)(op((float)x[j]) * (float)g[j]); + // perform base op and multiply with gate (either offset in same tensor or a separate one) + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + + dst[i] = (T)(op((float)x[j0]) * (float)g[j1]); } template -static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o, cudaStream_t stream) { +static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) { const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE; - unary_gated_op_kernel<<>>(x, g, dst, k, n, o); + unary_gated_op_kernel<<>>(x, g, dst, k, n, o0, o1); } template void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - const void * src0_d = src0->data; + const ggml_tensor * src1 = dst->src[1]; + void * src0_d = src0->data; + void * src1_d = src1 ? src1->data : src0->data; + const int64_t src0_o = src0->nb[1]; + const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; void * dst_d = dst->data; - const int64_t nc = src0->ne[0] / 2; + const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2; cudaStream_t stream = ctx.stream(); GGML_ASSERT(ggml_is_contiguous_1(src0)); @@ -235,26 +241,35 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst GGML_ASSERT(dst->ne[0] == nc); GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0)); + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src1->nb[0] == ggml_element_size(src1)); + GGML_ASSERT(src1->ne[0] == nc); + GGML_ASSERT(src0->type == src1->type); + } + const int32_t swapped = ((const int32_t *) dst->op_params)[1]; if (src0->type == GGML_TYPE_F16) { - unary_gated_cuda( - (const half *)src0_d + (swapped ? nc : 0), - (const half *)src0_d + (swapped ? 0 : nc), - (half *)dst_d, - ggml_nelements(dst), - nc, - src0->nb[1] / sizeof(half), - stream); + half * src0_p = (half *) src0_d; + half * src1_p = (half *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + unary_gated_cuda(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), stream); } else { - unary_gated_cuda( - (const float *)src0_d + (swapped ? nc : 0), - (const float *)src0_d + (swapped ? 0 : nc), - (float *)dst_d, - ggml_nelements(dst), - nc, - src0->nb[1] / sizeof(float), - stream); + float * src0_p = (float *) src0_d; + float * src1_p = (float *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + unary_gated_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), stream); } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2ae4e511b543b..8972af5d5b9bb 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2644,37 +2644,68 @@ struct ggml_tensor * ggml_exp_inplace( // ggml_glu -struct ggml_tensor * ggml_glu( +static struct ggml_tensor * ggml_glu_impl( struct ggml_context * ctx, struct ggml_tensor * a, + struct ggml_tensor * b, enum ggml_glu_op op, bool swapped) { GGML_ASSERT(ggml_is_contiguous_1(a)); + if (b) { + GGML_ASSERT(ggml_is_contiguous_1(b)); + GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(a->type == b->type); + } + int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i]; - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0); + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0); ggml_set_op_params_i32(result, 0, (int32_t) op); ggml_set_op_params_i32(result, 1, (int32_t) swapped); result->op = GGML_OP_GLU; result->src[0] = a; + result->src[1] = b; return result; } +struct ggml_tensor * ggml_glu( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_glu_op op, + bool swapped) { + return ggml_glu_impl(ctx, a, NULL, op, swapped); +} + +struct ggml_tensor * ggml_glu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_glu_op op) { + return ggml_glu_impl(ctx, a, b, op, false); +} + // ggml_reglu struct ggml_tensor * ggml_reglu( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, false); + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false); } struct ggml_tensor * ggml_reglu_swapped( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, true); + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true); +} + +struct ggml_tensor * ggml_reglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false); } // ggml_geglu @@ -2682,13 +2713,20 @@ struct ggml_tensor * ggml_reglu_swapped( struct ggml_tensor * ggml_geglu( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, false); + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false); } struct ggml_tensor * ggml_geglu_swapped( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, true); + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true); +} + +struct ggml_tensor * ggml_geglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false); } // ggml_swiglu @@ -2696,13 +2734,20 @@ struct ggml_tensor * ggml_geglu_swapped( struct ggml_tensor * ggml_swiglu( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, false); + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false); } struct ggml_tensor * ggml_swiglu_swapped( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, true); + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true); +} + +struct ggml_tensor * ggml_swiglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false); } // ggml_norm diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 75420f277d92c..25d08296075a8 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -604,12 +604,20 @@ ggml_tensor * llm_graph_context::build_ffn( switch (type_op) { case LLM_FFN_SILU: - { + if (gate && type_gate == LLM_FFN_PAR) { + cur = ggml_swiglu_split(ctx0, cur, tmp); + cb(cur, "ffn_swiglu", il); + type_gate = LLM_FFN_SEQ; + } else { cur = ggml_silu(ctx0, cur); cb(cur, "ffn_silu", il); } break; case LLM_FFN_GELU: - { + if (gate && type_gate == LLM_FFN_PAR) { + cur = ggml_geglu_split(ctx0, cur, tmp); + cb(cur, "ffn_geglu", il); + type_gate = LLM_FFN_SEQ; + } else { cur = ggml_gelu(ctx0, cur); cb(cur, "ffn_gelu", il); if (act_scales != NULL) { @@ -618,7 +626,11 @@ ggml_tensor * llm_graph_context::build_ffn( } } break; case LLM_FFN_RELU: - { + if (gate && type_gate == LLM_FFN_PAR) { + cur = ggml_reglu_split(ctx0, cur, tmp); + cb(cur, "ffn_reglu", il); + type_gate = LLM_FFN_SEQ; + } else { cur = ggml_relu(ctx0, cur); cb(cur, "ffn_relu", il); } break; @@ -774,12 +786,18 @@ ggml_tensor * llm_graph_context::build_moe_ffn( switch (type_op) { case LLM_FFN_SILU: - { + if (gate_exps) { + cur = ggml_swiglu_split(ctx0, cur, up); + cb(cur, "ffn_moe_swiglu", il); + } else { cur = ggml_silu(ctx0, cur); cb(cur, "ffn_moe_silu", il); } break; case LLM_FFN_GELU: - { + if (gate_exps) { + cur = ggml_geglu_split(ctx0, cur, up); + cb(cur, "ffn_moe_geglu", il); + } else { cur = ggml_gelu(ctx0, cur); cb(cur, "ffn_moe_gelu", il); } break; @@ -787,11 +805,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn( GGML_ABORT("fatal error"); } - if (gate_exps) { - cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens] - cb(cur, "ffn_moe_gate_par", il); - } - experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); From ac3194dc363b4f81dd309d1b8ede94ae678c830f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sat, 14 Jun 2025 16:09:19 +0200 Subject: [PATCH 2/6] add tests for ggml_glu_split --- tests/test-backend-ops.cpp | 57 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0278df1a98d66..a62eb883ba500 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1151,6 +1151,60 @@ struct test_glu : public test_case { } }; +struct test_glu_split : public test_case { + const ggml_glu_op op; + const ggml_type type; + const std::array ne_a; + int v; // view (1 : non-contiguous a) + + std::string vars() override { + return VARS_TO_STR3(type, ne_a, v); + } + + test_glu_split(ggml_glu_op op, + ggml_type type = GGML_TYPE_F32, + std::array ne_a = {128, 2, 2, 2}, + int v = 0) + : op(op), type(type), ne_a(ne_a), v(v) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a; + ggml_tensor * b; + if (v & 1) { + auto ne = ne_a; ne[0] *= 3; + a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0); + ggml_set_name(a, "view_of_a"); + + b = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(b, "b"); + + b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0); + ggml_set_name(a, "view_of_b"); + } else { + a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_set_name(a, "a"); + + b = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_set_name(b, "b"); + } + + ggml_tensor * out = ggml_glu_split(ctx, a, b, op); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + // test extended range of values to check for NaNs in GELU + init_tensor_uniform(t, -150.f, 150.f); + } + } +}; + // GGML_OP_GET_ROWS struct test_get_rows : public test_case { const ggml_type type; @@ -3986,6 +4040,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped)); test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped)); } + + test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v)); + test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v)); } } } From be3f78c808cfc125168f353753b33abf0522c472 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sun, 15 Jun 2025 06:12:54 +0000 Subject: [PATCH 3/6] Vulkan: Implement glu_split logic and shader support --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 38 +++++++++++---- .../src/ggml-vulkan/vulkan-shaders/geglu.comp | 46 ++++--------------- .../ggml-vulkan/vulkan-shaders/glu_head.comp | 15 ++++++ .../ggml-vulkan/vulkan-shaders/glu_main.comp | 31 +++++++++++++ .../src/ggml-vulkan/vulkan-shaders/reglu.comp | 37 ++------------- .../ggml-vulkan/vulkan-shaders/swiglu.comp | 39 ++-------------- 6 files changed, 93 insertions(+), 113 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index ee59f3a59957e..4a347bc5efcd6 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -659,6 +659,11 @@ struct vk_op_push_constants { float param2; }; +struct vk_op_glu_push_constants { + uint32_t ne00; + uint32_t mode; // 0: default, 1: swapped, 2: split +}; + struct vk_op_unary_push_constants { uint32_t ne; uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; @@ -2733,8 +2738,8 @@ static void ggml_vk_load_shaders(vk_device& device) { #undef CREATE_UNARY #define CREATE_GLU(name) \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); CREATE_GLU(geglu) CREATE_GLU(reglu) @@ -6947,7 +6952,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } - if (op == GGML_OP_SOFT_MAX) { + if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) { // Empty src1 is possible in soft_max, but the shader needs a buffer vk_subbuffer subbuf_y; if (use_src1) { @@ -7539,12 +7544,23 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } -static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); +static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const bool swapped = (bool)dst->op_params[1]; + const bool split = src1 != nullptr; + + GGML_ASSERT(ggml_is_contiguous(src0)); + + if (!split) { + GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); + } else { + GGML_ASSERT(src0->ne[0] == src1->ne[0]); + GGML_ASSERT(src0->ne[0] == dst->ne[0]); + GGML_ASSERT(src0->type == src1->type); + } - const uint32_t swapped = (uint32_t)dst->op_params[1]; + const uint32_t mode = split ? 2 : (swapped ? 1 : 0); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], swapped, 0.0f, 0.0f }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], mode }, dryrun); } static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { @@ -9003,7 +9019,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_REGLU: case GGML_GLU_OP_SWIGLU: - ggml_vk_glu(ctx, compute_ctx, src0, node, dryrun); + ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun); break; default: return false; @@ -10725,7 +10741,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { GGML_ABORT("fatal error"); } } else if (tensor->op == GGML_OP_GLU) { - tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]); + if (src_clone[1] == nullptr) { + tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]); + } else { + tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]); + } } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { if (src1 == nullptr) { tensor_clone = ggml_dup(ggml_ctx, src_clone[0]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp index e58ac59d9a860..f4268ed24f44c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp @@ -1,43 +1,13 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "glu_head.comp" -#extension GL_EXT_control_flow_attributes : enable +const float GELU_COEF_A = 0.044715f; +const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - -layout (constant_id = 0) const uint BLOCK_SIZE = 32; - -void main() { - const float GELU_COEF_A = 0.044715f; - const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - - const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; - const uint col = gl_LocalInvocationID.x; - - const uint offset = p.KX / 2; - - const bool swapped = p.KY > 0; - - if (!swapped) { - for (uint i = col; i < offset; i += BLOCK_SIZE) { - const uint idx = row * p.KX + i; - - const float xi = float(data_a[idx]); - const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi); - data_d[row * offset + i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)) * float(data_a[idx + offset])); - } - } else { - for (uint i = col; i < offset; i += BLOCK_SIZE) { - const uint idx = row * p.KX + i; - - const float xi = float(data_a[idx + offset]); - const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi); - data_d[row * offset + i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)) * float(data_a[idx])); - } - } +float op(float a, float b) { + const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a); + return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b; } + +#include "glu_main.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp new file mode 100644 index 0000000000000..0d65baef38944 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp @@ -0,0 +1,15 @@ +#extension GL_EXT_shader_16bit_storage : require + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {A_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +layout (push_constant) uniform parameter +{ + uint ne00; + uint mode; +} p; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp new file mode 100644 index 0000000000000..24814240365d2 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp @@ -0,0 +1,31 @@ +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + + if (p.mode == 0) { + // Default + const uint offset = p.ne00 / 2; + + for (uint i = col; i < offset; i += BLOCK_SIZE) { + const uint idx = row * p.ne00 + i; + + data_d[row * offset + i] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); + } + } else if (p.mode == 1) { + // Swapped + const uint offset = p.ne00 / 2; + + for (uint i = col; i < offset; i += BLOCK_SIZE) { + const uint idx = row * p.ne00 + i; + + data_d[row * offset + i] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); + } + } else { + // Split + for (uint i = col; i < p.ne00; i += BLOCK_SIZE) { + const uint idx = row * p.ne00 + i; + + data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp index 034481a1f17ef..0073d8f766610 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp @@ -1,36 +1,9 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "glu_head.comp" -#extension GL_EXT_control_flow_attributes : enable - -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - -layout (constant_id = 0) const uint BLOCK_SIZE = 32; - -void main() { - const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; - const uint col = gl_LocalInvocationID.x; - - const uint offset = p.KX / 2; - - const bool swapped = p.KY > 0; - - if (!swapped) { - for (uint i = col; i < offset; i += BLOCK_SIZE) { - const uint idx = row * p.KX + i; - - data_d[row * offset + i] = D_TYPE(max(float(data_a[idx]), 0.0f) * float(data_a[idx + offset])); - } - } else { - for (uint i = col; i < offset; i += BLOCK_SIZE) { - const uint idx = row * p.KX + i; - - data_d[row * offset + i] = D_TYPE(max(float(data_a[idx + offset]), 0.0f) * float(data_a[idx])); - } - } +float op(float a, float b) { + return max(a, 0.0f) * b; } + +#include "glu_main.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp index e75c1d38aa1ea..a28e7c6cc8660 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp @@ -1,38 +1,9 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "glu_head.comp" -#extension GL_EXT_control_flow_attributes : enable - -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - -layout (constant_id = 0) const uint BLOCK_SIZE = 32; - -void main() { - const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; - const uint col = gl_LocalInvocationID.x; - - const uint offset = p.KX / 2; - - const bool swapped = p.KY > 0; - - if (!swapped) { - for (uint i = col; i < offset; i += BLOCK_SIZE) { - const uint idx = row * p.KX + i; - - const float xi = float(data_a[idx]); - data_d[row * offset + i] = D_TYPE(xi / (1.0f + exp(-xi)) * float(data_a[idx + offset])); - } - } else { - for (uint i = col; i < offset; i += BLOCK_SIZE) { - const uint idx = row * p.KX + i; - - const float xi = float(data_a[idx + offset]); - data_d[row * offset + i] = D_TYPE(xi / (1.0f + exp(-xi)) * float(data_a[idx])); - } - } +float op(float a, float b) { + return a / (1.0f + exp(-a)) * b; } + +#include "glu_main.comp" From 06362b0d43eafe29ddbf98fd5e52028730b271ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sun, 15 Jun 2025 08:54:42 +0200 Subject: [PATCH 4/6] add split to logging [no ci] --- tests/test-backend-ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a62eb883ba500..757924ac01d70 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1158,7 +1158,7 @@ struct test_glu_split : public test_case { int v; // view (1 : non-contiguous a) std::string vars() override { - return VARS_TO_STR3(type, ne_a, v); + return VARS_TO_STR3(type, ne_a, v) + ",split"; } test_glu_split(ggml_glu_op op, From 42c2870f264c1ca0d82fe6b41d2749f6913b01a0 Mon Sep 17 00:00:00 2001 From: Akarshan Date: Sun, 15 Jun 2025 21:31:42 +0530 Subject: [PATCH 5/6] SYCL: refactor element_size ops and add split up and gate support to gated kernels --- ggml/src/ggml-sycl/element_wise.cpp | 1738 ++++++++------------------- ggml/src/ggml-sycl/element_wise.hpp | 17 +- 2 files changed, 502 insertions(+), 1253 deletions(-) diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 7e6b48db7002b..769c75be87678 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -2,14 +2,20 @@ #include "ggml-sycl/presets.hpp" #include "ggml.h" #include "element_wise.hpp" -#include -#include + +// --- Helper Macros for Kernel Indexing --- +#define SYCL_GLOBAL_ID_LOOP(K, ITEM) \ + for (auto i = ITEM.get_global_id(0); i < (size_t)K; i += ITEM.get_global_range(0)) + +#define SYCL_LOCAL_ID_CALC(ITEM, IDX) \ + (ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX)) + +// --- Original Kernels (non-_sycl) - Modified to use indexing macros and cast literals --- static void acc_f32(const float * x, const float * y, float * dst, const int ne, const int ne10, const int ne11, const int ne12, - const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); + const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) { + const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0); if (i >= ne) { return; } @@ -25,72 +31,59 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne, } template -static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { - for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { +static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { dst[i] = x[i] > static_cast(0.f) ? static_cast(1.f) : ((x[i] < static_cast(0.f) ? static_cast(-1.f) : static_cast(0.f))); } } template -static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { - for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { +static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { dst[i] = sycl::fabs(x[i]); } } template -static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { - for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { +static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { dst[i] = (x[i] > static_cast(0.f)) ? x[i] : sycl::expm1(x[i]); } } template static void gelu(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { + const sycl::nd_item<1> &item_ct1) { const T GELU_COEF_A = static_cast(0.044715f); const T SQRT_2_OVER_PI = static_cast(0.79788456080286535587989211986876f); - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = static_cast(0.5f) * x[i] * + (static_cast(1.0f) + + sycl::tanh(SQRT_2_OVER_PI * x[i] * (static_cast(1.0f) + GELU_COEF_A * x[i] * x[i]))); } - - float xi = x[i]; - dst[i] = static_cast(0.5f) * xi * - (static_cast(1.0f) + - sycl::tanh(SQRT_2_OVER_PI * xi * (static_cast(1.0f) + GELU_COEF_A * xi * xi))); } template static void silu(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = x[i] / (static_cast(1.0f) + sycl::native::exp(-x[i])); } - dst[i] = x[i] / (static_cast(1.0f) + sycl::native::exp(-x[i])); } template static void gelu_quick(const T *x, T *dst, int k, - const sycl::nd_item<3> &item_ct1) { - const float GELU_QUICK_COEF = -1.702f; - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + const T GELU_QUICK_COEF_LOCAL = static_cast(-1.702f); + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = x[i] * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x[i]))); } - dst[i] = x[i] * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i]))); } template -static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { +static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { const T SQRT_2_INV = static_cast(0.70710678118654752440084436210484f); - for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { auto x_i = x[i]; dst[i] = static_cast(0.5f) * x_i * (static_cast(1.0f) + sycl::erf(x_i * SQRT_2_INV)); } @@ -98,174 +91,121 @@ static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> & template static void tanh(const T *x, T *dst, int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = sycl::tanh((x[i])); } - dst[i] = sycl::tanh((x[i])); } template static void relu(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = sycl::fmax((x[i]), static_cast(0)); } - dst[i] = sycl::fmax((x[i]), static_cast(0)); } template static void sigmoid(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(-x[i])); } - dst[i] = 1.0f / (static_cast(1.0f) + sycl::native::exp(-x[i])); } template static void sqrt(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = sycl::sqrt(x[i]); } - dst[i] = sycl::sqrt(x[i]); } template static void sin(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = sycl::sin(x[i]); } - dst[i] = sycl::sin(x[i]); } template static void cos(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = sycl::cos(x[i]); } - dst[i] = sycl::cos(x[i]); } template static void hardsigmoid(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); } - dst[i] = sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); } template static void hardswish(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = x[i] * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); } - dst[i] = x[i] * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); } template static void exp(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = sycl::exp(x[i]); } - dst[i] = sycl::exp(x[i]); } template static void log(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - T xi = x[i]; - if (xi <= 0) { - dst[i] = neg_infinity(); - } else { - dst[i] = sycl::log(xi); + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + T xi = x[i]; + if (xi <= static_cast(0)) { + dst[i] = neg_infinity(); + } else { + dst[i] = sycl::log(xi); + } } } template static void neg(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = -x[i]; } - dst[i] = -x[i]; } template static void step(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = (x[i] > static_cast(0.0f)) ? static_cast(1.0f) : static_cast(0.0f); } - dst[i] = x[i] > static_cast(0.0f); } template static void leaky_relu(const T *x, T *dst, const int k, const float negative_slope, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + T neg_slope_T = static_cast(negative_slope); + dst[i] = sycl::fmax((x[i]), static_cast(0)) + + sycl::fmin((x[i]), static_cast(0.0f)) * neg_slope_T; } - dst[i] = sycl::fmax((x[i]), static_cast(0)) + - sycl::fmin((x[i]), static_cast(0.0f)) * negative_slope; } template static void sqr(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = x[i] * x[i]; } - dst[i] = x[i] * x[i]; } template @@ -284,10 +224,10 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01, int i12 = (index / (ne10 * ne11)) % ne12; int i13 = (index / (ne10 * ne11 * ne12)) % ne13; - int i00 = i10 / sf0; - int i01 = i11 / sf1; - int i02 = i12 / sf2; - int i03 = i13 / sf3; + int i00 = static_cast(i10 / sf0); + int i01 = static_cast(i11 / sf1); + int i02 = static_cast(i12 / sf2); + int i03 = static_cast(i13 / sf3); dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); } @@ -295,8 +235,7 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01, template static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02, const sycl::nd_item<3> &item_ct1) { - int nidx = item_ct1.get_local_id(2) + - item_ct1.get_group(2) * item_ct1.get_local_range(2); + int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2); if (nidx >= ne0) { return; } @@ -313,337 +252,72 @@ static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne } } - template static void clamp(const T * x, T * dst, const float min, const float max, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = x[i] < static_cast(min) ? static_cast(min) : (x[i] > static_cast(max) ? static_cast(max) : x[i]); } - - dst[i] = x[i] < static_cast(min) ? static_cast(min) : (x[i] > static_cast(max) ? static_cast(max) : x[i]); } -// Fused GLU kernels +// Fused GLU kernels (unchanged logic) template -static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) { - for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) { - const int64_t j = ((i / n) * o) + (i % n); - const T x_val = x[j]; - const T gelu_val = x_val * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x_val))); - - dst[i] = gelu_val * g[j]; +static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) { + const T GELU_QUICK_COEF_LOCAL = static_cast(-1.702f); + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + const T x_val = x[j0]; + const T gelu_val = x_val * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x_val))); + + dst[i] = gelu_val * g[j1]; } } template -static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) { - for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) { - const int64_t j = ((i / n) * o) + (i % n); - dst[i] = sycl::max((x[j]), static_cast(0)) * g[j]; +static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + dst[i] = sycl::max((x[j0]), static_cast(0)) * g[j1]; } } template -static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) { - for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) { - const int64_t j = ((i / n) * o) + (i % n); - dst[i] = (x[j] / (static_cast(1) + sycl::native::exp(-x[j]))) * g[j]; +static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + dst[i] = (x[j0] / (static_cast(1) + sycl::native::exp(-x[j0]))) * g[j1]; } } +// --- Generic SYCL Kernel Launchers --- +namespace ggml_sycl_detail { +// acc_f32_sycl remains specific static void acc_f32_sycl(const float *x, const float *y, float *dst, const int n_elements, const int ne10, const int ne11, const int ne12, const int nb1, const int nb2, const int offset, queue_ptr stream) { - int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; + int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE); stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + sycl::nd_range<1>(sycl::range<1>(num_blocks) * + sycl::range<1>(SYCL_ACC_BLOCK_SIZE), + sycl::range<1>(SYCL_ACC_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, item_ct1); }); } -template -static void gelu_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - gelu(x, dst, k, item_ct1); - }); -} - -template -static void silu_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - silu(x, dst, k, item_ct1); - }); -} - -template -static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) { - // hard code for now - const int num_blocks = ceil_div(k, 256); - stream->parallel_for( - sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - sgn(x, dst, k, item_ct1); - }); -} - -template -static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) { - // hard code for now - const int num_blocks = ceil_div(k, 256); - stream->parallel_for( - sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - abs_op(x, dst, k, item_ct1); - }); -} - - -template -static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) { - // hard code for now - const int num_blocks = ceil_div(k, 256); - stream->parallel_for( - sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { - elu_op(x, dst, k, item_ct1); - }); -} - -template -static void gelu_quick_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - gelu_quick(x, dst, k, item_ct1); - }); -} - - -template -static void gelu_erf_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - gelu_erf(x, dst, k, item_ct1); - }); -} - -template -static void tanh_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - tanh(x, dst, k, item_ct1); - }); -} - -template -static void relu_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - relu(x, dst, k, item_ct1); - }); -} - -template -static void hardsigmoid_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - hardsigmoid(x, dst, k, item_ct1); - }); -} - -template -static void hardswish_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - hardswish(x, dst, k, item_ct1); - }); -} - -template -static void exp_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - exp(x, dst, k, item_ct1); - }); -} - -template -static void log_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - log(x, dst, k, item_ct1); - }); -} - -template -static void neg_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - neg(x, dst, k, item_ct1); - }); -} - -template -static void step_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - step(x, dst, k, item_ct1); - }); -} - -template -static void sigmoid_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - sigmoid(x, dst, k, item_ct1); - }); -} - -template -static void sqrt_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - sqrt(x, dst, k, item_ct1); - }); -} - -template -static void sin_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - sin(x, dst, k, item_ct1); - }); -} - -template -static void cos_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - cos(x, dst, k, item_ct1); - }); -} - -template -static void leaky_relu_sycl(const T *x, T *dst, const int k, - const float negative_slope, - queue_ptr stream) { - const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - leaky_relu(x, dst, k, negative_slope, item_ct1); - }); -} - -template -static void sqr_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - sqr(x, dst, k, item_ct1); - }); -} - +// upscale_sycl remains specific template static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int ne13, const float sf0, const float sf1, const float sf2, const float sf3, queue_ptr stream) { int dst_size = ne10 * ne11 * ne12 * ne13; - int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE); sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); stream->parallel_for( sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), @@ -652,11 +326,12 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, }); } +// pad_sycl remains specific template static void pad_sycl(const T *x, T *dst, const int ne00, const int ne01, const int ne02, const int ne0, const int ne1, const int ne2, queue_ptr stream) { - int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; + int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE); sycl::range<3> gridDim(ne2, ne1, num_blocks); stream->parallel_for( sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), @@ -666,52 +341,13 @@ static void pad_sycl(const T *x, T *dst, const int ne00, }); } -template -static void clamp_sycl(const T *x, T *dst, const float min, - const float max, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - clamp(x, dst, min, max, k, item_ct1); - }); -} - -template -static void geglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) { - const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); - main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { - gated_op_fused_geglu(x, g, dst, k, n, o, item_ct1); - }); -} - -template -static void reglu_sycl(const T * x, const T* g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) { - const uint32_t num_blocks = ceil_div(k, SYCL_RELU_BLOCK_SIZE); - main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { - gated_op_fused_reglu(x, g, dst, k, n, o, item_ct1); - }); -} - -template -static void swiglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) { - const uint32_t num_blocks = ceil_div(k, SYCL_SILU_BLOCK_SIZE); - main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { - gated_op_fused_swiglu(x, g, dst, k, n, o, item_ct1); - }); -} - -inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +// Common dispatcher for 1-input, 1-output element-wise ops, handling type switching. +// KernelInvoker is a lambda that takes (const T* src, T* dst, int k, queue_ptr stream, Args...) +template +inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { #if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - #else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); @@ -724,14 +360,14 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) case GGML_TYPE_F16: { auto data_pts = cast_data(dst); - sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward(args)...); break; } #endif case GGML_TYPE_F32: { auto data_pts = cast_data(dst); - sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward(args)...); break; } default: @@ -739,11 +375,12 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } } -inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +// Dispatcher for fused GLU ops, handling specific input pointer setup and type switching. +template +inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { #if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - #else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); @@ -751,19 +388,66 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;; + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_is_contiguous_1(dst->src[0])); + GGML_ASSERT(ggml_is_contiguous(dst)); + const int32_t swapped = ((const int32_t *) dst->op_params)[1]; + void * src0_d = src0->data; + void * src1_d = src1 ? src1->data : src0->data; + const int64_t src0_o = src0->nb[1]; + const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + void * dst_d = dst->data; + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src1->nb[0] == ggml_element_size(src1)); + GGML_ASSERT(src1->ne[0] == nc); + GGML_ASSERT(src0->type == src1->type); + } switch (dst->type) { #if defined (GGML_SYCL_F16) case GGML_TYPE_F16: { - auto data_pts = cast_data(dst); - abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + sycl::half * src0_p = (sycl::half *) src0_d; + sycl::half * src1_p = (sycl::half *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + kernel_invoker(src0_p, + src1_p, + (sycl::half *) dst_d, + ggml_nelements(dst), + nc, + src0_o / sizeof(sycl::half), + src1_o / sizeof(sycl::half), + main_stream, + std::forward(args)...); break; } #endif case GGML_TYPE_F32: { - auto data_pts = cast_data(dst); - abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + float * src0_p = (float *) src0_d; + float * src1_p = (float *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + kernel_invoker(src0_p, + src1_p, + (float *) dst_d, + ggml_nelements(dst), + nc, + src0_o / sizeof(float), + src1_o / sizeof(float), + main_stream, + std::forward(args)...); break; } default: @@ -771,32 +455,42 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } } - -inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +// Dispatcher for upscale +template +inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { #if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - #else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); #endif GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0]; + const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1]; + const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2]; + const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3]; switch (dst->type) { #if defined (GGML_SYCL_F16) case GGML_TYPE_F16: { auto data_pts = cast_data(dst); - elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2], + (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3, + main_stream, std::forward(args)...); break; } #endif case GGML_TYPE_F32: { auto data_pts = cast_data(dst); - elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2], + (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3, + main_stream, std::forward(args)...); break; } default: @@ -804,7 +498,9 @@ inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } } -inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +// Dispatcher for pad +template +inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { #if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); @@ -813,6 +509,7 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst GGML_ASSERT(dst->type == GGML_TYPE_F32); #endif GGML_ASSERT(dst->src[0]->type == dst->type); + GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); switch (dst->type) { @@ -820,14 +517,16 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst case GGML_TYPE_F16: { auto data_pts = cast_data(dst); - silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0], + (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward(args)...); break; } #endif case GGML_TYPE_F32: { auto data_pts = cast_data(dst); - silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0], + (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward(args)...); break; } default: @@ -835,655 +534,321 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst } } +} // namespace ggml_sycl_detail + + +// --- Backend Operation Functions (ggml_sycl_op_...) --- + +inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + sgn(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + abs_op(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + elu_op(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE), + sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + silu(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + gelu(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + gelu_quick(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + gelu_erf(src, dst_ptr, k_elements, item_ct1); + }); + }); } - inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE), + sycl::range<1>(SYCL_TANH_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + tanh(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE), + sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + relu(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE), + sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + hardsigmoid(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE), + sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + hardswish(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE), + sycl::range<1>(SYCL_EXP_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + exp(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); // Using EXP block size + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE), + sycl::range<1>(SYCL_EXP_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + log(src, dst_ptr, k_elements, item_ct1); + }); + }); } -inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } +inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE), + sycl::range<1>(SYCL_NEG_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + neg(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE), + sycl::range<1>(SYCL_NEG_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + step(src, dst_ptr, k_elements, item_ct1); + }); + }); } -inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif +inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE), + sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + sigmoid(src, dst_ptr, k_elements, item_ct1); + }); + }); +} - GGML_ASSERT(dst->src[0]->type == dst->type); +inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SQRT_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE), + sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + sqrt(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), + sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + sin(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); // Using SIN block size + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), + sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + cos(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { float negative_slope; memcpy(&negative_slope, dst->op_params, sizeof(float)); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float slope) { + const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE), + sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + leaky_relu(src, dst_ptr, k_elements, slope, item_ct1); + }); + }, negative_slope); } inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - #if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SQR_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE), + sycl::range<1>(SYCL_SQR_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + sqr(src, dst_ptr, k_elements, item_ct1); + }); + }); } inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0]; - const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1]; - const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2]; - const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3]; - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], - dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, - main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], - dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, - main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst, + [](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03, + int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3, + queue_ptr stream) { + ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream); + }); } inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0], - dst->ne[1], dst->ne[2], main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0], - dst->ne[1], dst->ne[2], main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst, + [](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, + queue_ptr stream) { + ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream); + }); } inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined(GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - float min; - float max; - memcpy(&min, dst->op_params, sizeof(float)); - memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); - - switch (dst->type) { -#if defined(GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + float min_val; + float max_val; + memcpy(&min_val, dst->op_params, sizeof(float)); + memcpy(&max_val, (float *) dst->op_params + 1, sizeof(float)); + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float min_arg, float max_arg) { + const int num_blocks = ceil_div(k_elements, SYCL_CLAMP_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE), + sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1); + }); + }, min_val, max_val); } inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -1499,156 +864,43 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused int offset = dst->op_params[3] / 4; // offset in bytes - acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream); + ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream); } inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const int64_t nc = dst->src[0]->ne[0] / 2; - GGML_ASSERT(dst->ne[0] == nc); - GGML_ASSERT(ggml_is_contiguous_1(dst->src[0])); - GGML_ASSERT(ggml_is_contiguous(dst)); - const int32_t swapped = ((const int32_t *) dst->op_params)[1]; - const void * src0_d = dst->src[0]->data; - void * dst_d = dst->data; - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - geglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0), - (const sycl::half *)src0_d + (swapped ? 0 : nc), - (sycl::half *) dst_d, - ggml_nelements(dst), - nc, - dst->src[0]->nb[1] / sizeof(sycl::half), - main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - geglu_sycl((const float *) src0_d + (swapped ? nc : 0), - (const float *)src0_d + (swapped ? 0 : nc), - (float *) dst_d, - ggml_nelements(dst), - nc, - dst->src[0]->nb[1] / sizeof(float), - main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst, + [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { + const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); + main_stream->parallel_for( + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); + }); + }); } inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const int64_t nc = dst->src[0]->ne[0] / 2; - GGML_ASSERT(dst->ne[0] == nc); - GGML_ASSERT(ggml_is_contiguous_1(dst->src[0])); - GGML_ASSERT(ggml_is_contiguous(dst)); - const int32_t swapped = ((const int32_t *) dst->op_params)[1]; - const void * src0_d = dst->src[0]->data; - void * dst_d = dst->data; - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - reglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0), - (const sycl::half *)src0_d + (swapped ? 0 : nc), - (sycl::half *) dst_d, - ggml_nelements(dst), - nc, - dst->src[0]->nb[1] / sizeof(sycl::half), - main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - reglu_sycl((const float *) src0_d + (swapped ? nc : 0), - (const float *)src0_d + (swapped ? 0 : nc), - (float *) dst_d, - ggml_nelements(dst), - nc, - dst->src[0]->nb[1] / sizeof(float), - main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst, + [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { + const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu + main_stream->parallel_for( + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); + }); + }); } inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const int64_t nc = dst->src[0]->ne[0] / 2; - GGML_ASSERT(dst->ne[0] == nc); - GGML_ASSERT(ggml_is_contiguous_1(dst->src[0])); - GGML_ASSERT(ggml_is_contiguous(dst)); - const int32_t swapped = ((const int32_t *) dst->op_params)[1]; - const void * src0_d = dst->src[0]->data; - void * dst_d = dst->data; - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - swiglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0), - (const sycl::half *)src0_d + (swapped ? 0 : nc), - (sycl::half *) dst_d, - ggml_nelements(dst), - nc, - dst->src[0]->nb[1] / sizeof(sycl::half), - main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - swiglu_sycl((const float *) src0_d + (swapped ? nc : 0), - (const float *)src0_d + (swapped ? 0 : nc), - (float *) dst_d, - ggml_nelements(dst), - nc, - dst->src[0]->nb[1] / sizeof(float), - main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst, + [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { + const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu + main_stream->parallel_for( + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); + }); + }); } + void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_sqrt(ctx, dst); @@ -1788,5 +1040,3 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_swiglu(ctx, dst); } - - diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index f530c9c1e1bdd..86068b10129ec 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -3,24 +3,24 @@ #include "common.hpp" #include "ggml.h" -#include +#include // For std::numeric_limits template T neg_infinity() { return -std::numeric_limits::infinity(); } -template +template struct typed_data { - const T * src; - T * dst; + const T_Src * src; + T_Dst * dst; }; -template -typed_data cast_data(ggml_tensor * dst) { +template +typed_data cast_data(ggml_tensor * dst) { return { - /* .src = */ static_cast(dst->src[0]->data), - /* .dst = */ static_cast(dst->data) + /* .src = */ static_cast(dst->src[0]->data), + /* .dst = */ static_cast(dst->data) }; } @@ -82,4 +82,3 @@ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); #endif // GGML_SYCL_ELEMENTWISE_HPP - From 832efa9ea12c796f63f7748b206389856676d2b0 Mon Sep 17 00:00:00 2001 From: Akarshan Date: Tue, 17 Jun 2025 11:53:46 +0530 Subject: [PATCH 6/6] SYCL: switch GEGLU to use tanh approximation --- ggml/src/ggml-sycl/element_wise.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 769c75be87678..828cea1aa0086 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -260,15 +260,18 @@ static void clamp(const T * x, T * dst, const float min, const float max, const } } -// Fused GLU kernels (unchanged logic) template static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) { - const T GELU_QUICK_COEF_LOCAL = static_cast(-1.702f); + const T GELU_COEF_A = static_cast(0.044715f); + const T SQRT_2_OVER_PI = static_cast(0.79788456080286535587989211986876f); SYCL_GLOBAL_ID_LOOP(k, item_ct1) { const int64_t j0 = (i / n) * o0 + (i % n); const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); const T x_val = x[j0]; - const T gelu_val = x_val * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x_val))); + + const T x_cubed_term = static_cast(1.0f) + GELU_COEF_A * x_val * x_val; + const T tanh_input = SQRT_2_OVER_PI * x_val * x_cubed_term; + const T gelu_val = static_cast(0.5f) * x_val * (static_cast(1.0f) + sycl::tanh(tanh_input)); dst[i] = gelu_val * g[j1]; }