From 0d71ffaafb8907ffa9244771851623511bd18453 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Fri, 23 May 2025 16:46:08 +0100 Subject: [PATCH 1/7] [WIP]: fuse q8 quantization and reorder --- ggml/src/ggml-sycl/common.hpp | 17 +------ ggml/src/ggml-sycl/ggml-sycl.cpp | 81 +++++++++++++++++++++++++++----- ggml/src/ggml-sycl/mmvq.cpp | 11 ++++- ggml/src/ggml-sycl/vecdotq.hpp | 30 ++++++------ 4 files changed, 94 insertions(+), 45 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 60909dde7d087..618422099082e 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -280,22 +280,7 @@ void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector str inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) { optimize_feature opt; - opt.reorder = - (arch == syclex::architecture::intel_gpu_dg1 || - arch == syclex::architecture::intel_gpu_acm_g10 || - arch == syclex::architecture::intel_gpu_acm_g11 || - arch == syclex::architecture::intel_gpu_acm_g12 || - arch == syclex::architecture::intel_gpu_pvc || - arch == syclex::architecture::intel_gpu_pvc_vg || - arch == syclex::architecture::intel_gpu_mtl_u || - arch == syclex::architecture::intel_gpu_mtl_s || - arch == syclex::architecture::intel_gpu_mtl_h || - arch == syclex::architecture::intel_gpu_arl_u || - arch == syclex::architecture::intel_gpu_arl_s || - arch == syclex::architecture::intel_gpu_arl_h || - arch == syclex::architecture::intel_gpu_bmg_g21 || - arch == syclex::architecture::intel_gpu_lnl_m - ); + opt.reorder = true; return opt; } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 271f54e5773d9..0e40b6464c8ed 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -33,6 +33,7 @@ #include #include "ggml-sycl.h" +#include "common.hpp" #include "ggml-impl.h" #include "ggml-backend-impl.h" @@ -44,6 +45,7 @@ #include "ggml-sycl/sycl_hw.hpp" #include "ggml-sycl/getrows.hpp" #include "ggml.h" +#include "presets.hpp" static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; @@ -1412,6 +1414,45 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, reinterpret_cast(y[ib].ds.y()) = sum; } +template +static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, int8_t * quant_ptr, + sycl::half2 * ds_ptr, const sycl::nd_item<1> & it) { + auto subgroup_id = it.get_group(0); + auto wi_id = it.get_local_id(0); + + sycl::vec wi_f32_vals; + sycl::vec quantized_values; + + auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id; + wi_f32_vals = *reinterpret_cast*>(x + float_ptr_offset); + + float sum = 0.0f; + float amax = 0.0f; + +#pragma unroll(ElementsPerWI) + for (int i = 0; i < ElementsPerWI; i++) { + sum += wi_f32_vals[i]; + amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i])); + quantized_values[i] = 0; + } + sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus()); + amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum()); + const float d = amax == 0 ? 
1 : amax / 127; + +#pragma unroll(ElementsPerWI) + for (int i = 0; i < ElementsPerWI; i++) { + quantized_values[i] = sycl::round(wi_f32_vals[i] / d); + } + + *reinterpret_cast*>(quant_ptr + subgroup_id * QK8_1) = quantized_values; + auto my_val = *reinterpret_cast*>(quant_ptr + subgroup_id * QK8_1); + ds_ptr[subgroup_id] = sycl::half2(sycl::half(d), sycl::half(sum)); + auto ds_values = ds_ptr[subgroup_id]; + float sum_value = ds_values[0]; + float d_value = ds_values[1]; + //sycl::ext::oneapi::experimental::printf("%d %d %f %f \n", static_cast(subgroup_id), my_val[0], sum_value, d_value); +} + static void mul_mat_p021_f16_f32( const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y, @@ -1699,20 +1740,33 @@ static void pool2d_nchw_kernel( static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx, const int ky, const int kx_padded, queue_ptr stream) { - const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE; - const sycl::range<3> num_blocks(1, ky, block_num_x); - int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE; - static_assert(QK8_1 % WARP_SIZE == 0); - const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE); - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); + std::cout << "Hey I am here" << std::endl; + if (g_ggml_sycl_disable_optimize == 0) { + std::cout << "hey here I am tada" << std::endl; + auto local_range = std::size_t(WARP_SIZE); + auto num_quant_blocks = ky * (kx / QK8_1); + auto global_range = num_quant_blocks * local_range; + // since we reorder in the same pointer. + auto quant_block_ptr = (int8_t *) vy; + auto ds_ptr = (sycl::half2 *) ((char *) (vy) + num_quant_blocks * QK8_1 * sizeof(int8_t)); + stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), + [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + quantize_and_reorder_q8_1(x, quant_block_ptr, ds_ptr, it); + }); + } else { + const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE; + const sycl::range<3> num_blocks(1, ky, block_num_x); + int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE; + static_assert(QK8_1 % WARP_SIZE == 0); + const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE); + { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - stream->parallel_for( - sycl::nd_range<3>(num_blocks * block_size, block_size), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - quantize_q8_1(x, vy, kx, kx_padded, item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + quantize_q8_1(x, vy, kx, kx_padded, item_ct1); + }); + } } } @@ -2422,6 +2476,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten if (src1_on_device && src1_is_contiguous) { quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream); + //dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), (ggml_nelements / QK8_1) * sizeof(block_q8_1)); /* DPCT1010:90: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. 
You need to diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 23eeb74da0d84..96ae2fea861f4 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -1,4 +1,5 @@ #include "mmvq.hpp" +#include #include "ggml.h" #include "common.hpp" @@ -40,13 +41,15 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r // Y block index that aligns with ibx const int iby = i * block_type::block_to_q8_1_ratio(); + const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1; + sycl::half2 q8_1_ds_ptr = *(sycl::half2*)((char*)vy + ncols + iby * sizeof(sycl::half2)); #pragma unroll for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) { // x block quant index when casting the quants to int const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup); - partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks); + partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs, nblocks); } } @@ -540,6 +543,9 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, GGML_ASSERT(ncols % QK4_0 == 0); const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); constexpr size_t num_subgroups = 16; + // std::cout << "Hey I am in " << __func__ << std::endl; + // std::cout << "nrows: " << nrows << " ncols: " << ncols << std::endl; + // std::cout << "=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=" << std::endl; GGML_ASSERT(block_num_y % num_subgroups == 0); const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); @@ -1024,6 +1030,7 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens // nrows_dst == nrows of the matrix that the kernel writes into for (int i = 0; i < src1_ncols; i++) { + // std::cout << "Hey I am launching a kernel ! 
" << std::endl;; const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs; const char * src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset; float * dst_dd_i_bs = dst_dd_i + i * dst->ne[0]; @@ -1101,6 +1108,8 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens GGML_ABORT("fatal error"); } } + //std::cout << row_low << " " << row_high << " " << src1_ncols << std::endl; + //std::cout << "=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=" << std::endl; GGML_UNUSED(src1); GGML_UNUSED(dst); GGML_UNUSED(src1_ddf_i); diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index ed3699313466b..848d93939ade4 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -285,7 +285,7 @@ template <> struct reorder_vec_dot_q_sycl { } __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, - const block_q8_1 * __restrict__ bq8_1, const int & iqs, int /* nblocks */) { + const int8_t* q8_1_quant_ptr, const sycl::half2& q8_1_ds, const int & iqs, int /* nblocks */) { const uint8_t * bq4_0 = static_cast(vbq) + ibx_offset; const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset)); int v[q4_0_traits::vdr_mmvq]; @@ -295,11 +295,11 @@ template <> struct reorder_vec_dot_q_sycl { for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) { v[i] = get_int_from_uint8(bq4_0, iqs + i); - u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + q4_0_traits::qi); + u[2 * i + 0] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i + q4_0_traits::qi); } - return vec_dot_q4_0_q8_1_impl(v, u, d, bq8_1->ds); + return vec_dot_q4_0_q8_1_impl(v, u, d, q8_1_ds); }; }; @@ -347,20 +347,20 @@ template <> struct reorder_vec_dot_q_sycl { using q4_k_traits = typename q4_k_block::traits; float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, - const block_q8_1 * __restrict__ bq8_1, const int & iqs, int nblocks) { - const int ib = ibx_offset / (QK_K / 2); + const int8_t* q8_1_quant_ptr, const sycl::half2& q8_1_ds, const int & iqs, int nblocks) { + // const int ib = ibx_offset / (QK_K / 2); - const uint8_t * base = static_cast(vbq); - const uint8_t * qs = base + ibx_offset; - const int total_qs_bytes = nblocks * (QK_K / 2); - const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE; - const ggml_half2 * dms = reinterpret_cast(base + d_offset); + // const uint8_t * base = static_cast(vbq); + // const uint8_t * qs = base + ibx_offset; + // const int total_qs_bytes = nblocks * (QK_K / 2); + // const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE; + // const ggml_half2 * dms = reinterpret_cast(base + d_offset); - const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); - const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); - const uint16_t * scales = (const uint16_t *) scs; + // const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); + // const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); + // const uint16_t * scales = (const uint16_t *) scs; - return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs); + // return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs); } }; From 6096ff80a03529fe9b6af6840e59ffc25e2d174b Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Mon, 26 May 2025 16:12:34 +0100 Subject: [PATCH 2/7] wip2: fuse q8 
quantization and reorder --- ggml/src/ggml-sycl/ggml-sycl.cpp | 79 +++++++++++++++++++++----------- ggml/src/ggml-sycl/mmvq.cpp | 8 +--- ggml/src/ggml-sycl/vecdotq.hpp | 22 ++++----- 3 files changed, 63 insertions(+), 46 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 0e40b6464c8ed..195ce73c076c7 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1415,16 +1415,42 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, } template -static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, int8_t * quant_ptr, - sycl::half2 * ds_ptr, const sycl::nd_item<1> & it) { +static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor, + const int kx, const int kx_padded, const sycl::nd_item<1> & it) { + /* + quantize and reorders the resultant q8 tensor in a per row fashion + Each sub-group calculates one quant block + work_group_size = sub_group_size; + + |------------------------------ Matrix Pitch -------------------------| + |------- Matrix Width --------| + q_00 q_01 q_02 ..... q_0n-1 q_n ds00 ds01 ... ds0n/32 ... padding ... | + . . | + . . | + . . Matrix Height + . . | + . . | + q_n0 q_n1 q_n2 ..... q_nn-1 q_n dsn0 dsn1 ... dsnn/32 ... padding ... | + */ + auto subgroup_id = it.get_group(0); auto wi_id = it.get_local_id(0); + const int num_blocks_per_row = kx / QK8_1; + auto row = subgroup_id / num_blocks_per_row; + auto col = subgroup_id % num_blocks_per_row; + + auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1); + auto col_offset = QK8_1 * col + wi_id * ElementsPerWI; + + auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset); + auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2)); + sycl::vec wi_f32_vals; sycl::vec quantized_values; auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id; - wi_f32_vals = *reinterpret_cast*>(x + float_ptr_offset); + wi_f32_vals = *reinterpret_cast *>(x + float_ptr_offset); float sum = 0.0f; float amax = 0.0f; @@ -1435,22 +1461,21 @@ static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i])); quantized_values[i] = 0; } - sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus()); - amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum()); - const float d = amax == 0 ? 1 : amax / 127; + sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus()); + amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum()); + float d = amax == 0 ? 1 : amax / 127; #pragma unroll(ElementsPerWI) for (int i = 0; i < ElementsPerWI; i++) { quantized_values[i] = sycl::round(wi_f32_vals[i] / d); } - *reinterpret_cast*>(quant_ptr + subgroup_id * QK8_1) = quantized_values; - auto my_val = *reinterpret_cast*>(quant_ptr + subgroup_id * QK8_1); - ds_ptr[subgroup_id] = sycl::half2(sycl::half(d), sycl::half(sum)); - auto ds_values = ds_ptr[subgroup_id]; - float sum_value = ds_values[0]; - float d_value = ds_values[1]; - //sycl::ext::oneapi::experimental::printf("%d %d %f %f \n", static_cast(subgroup_id), my_val[0], sum_value, d_value); + d = amax == 0 ? 
0 : d; + + *reinterpret_cast *>(quant_ptr) = quantized_values; + if (wi_id == 0) { + *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum)); + } } static void mul_mat_p021_f16_f32( @@ -1737,24 +1762,18 @@ static void pool2d_nchw_kernel( o_ptr[cur_oh * ow + cur_ow] = res; } -static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx, - const int ky, const int kx_padded, - queue_ptr stream) { - std::cout << "Hey I am here" << std::endl; - if (g_ggml_sycl_disable_optimize == 0) { - std::cout << "hey here I am tada" << std::endl; +static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded, + bool reorder_q8_tensor, queue_ptr stream) { + if (reorder_q8_tensor) { auto local_range = std::size_t(WARP_SIZE); auto num_quant_blocks = ky * (kx / QK8_1); auto global_range = num_quant_blocks * local_range; - // since we reorder in the same pointer. - auto quant_block_ptr = (int8_t *) vy; - auto ds_ptr = (sycl::half2 *) ((char *) (vy) + num_quant_blocks * QK8_1 * sizeof(int8_t)); stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), - [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - quantize_and_reorder_q8_1(x, quant_block_ptr, ds_ptr, it); - }); + [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + quantize_and_reorder_q8_1(x, vy, kx, kx_padded, it); + }); } else { - const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE; + const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE; const sycl::range<3> num_blocks(1, ky, block_num_x); int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE; static_assert(QK8_1 % WARP_SIZE == 0); @@ -2475,7 +2494,11 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs); if (src1_on_device && src1_is_contiguous) { - quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream); + bool reorder_q8_tensor = false; + if (src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder) { + reorder_q8_tensor = true; + } + quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream); //dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), (ggml_nelements / QK8_1) * sizeof(block_q8_1)); /* DPCT1010:90: SYCL uses exceptions to report errors and does not @@ -2580,7 +2603,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten } if (convert_src1_to_q8_1 && !src1_is_contiguous) { - quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream); + quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream); /* DPCT1010:92: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. 
You diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 96ae2fea861f4..2930a0a262df1 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -30,8 +30,6 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r static_assert(blocks_per_subgroup > 0); static_assert(block_elements_per_subgroup > 0); - const block_q8_1 * y = (const block_q8_1 *) vy; - float partial_sum = 0.0f; for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) { const int ibx = row * blocks_per_row + i; // x block index @@ -42,7 +40,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r // Y block index that aligns with ibx const int iby = i * block_type::block_to_q8_1_ratio(); const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1; - sycl::half2 q8_1_ds_ptr = *(sycl::half2*)((char*)vy + ncols + iby * sizeof(sycl::half2)); + sycl::half2 q8_1_ds_ptr = *(const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2)); #pragma unroll for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) { @@ -543,9 +541,6 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, GGML_ASSERT(ncols % QK4_0 == 0); const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); constexpr size_t num_subgroups = 16; - // std::cout << "Hey I am in " << __func__ << std::endl; - // std::cout << "nrows: " << nrows << " ncols: " << ncols << std::endl; - // std::cout << "=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=" << std::endl; GGML_ASSERT(block_num_y % num_subgroups == 0); const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); @@ -1030,7 +1025,6 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens // nrows_dst == nrows of the matrix that the kernel writes into for (int i = 0; i < src1_ncols; i++) { - // std::cout << "Hey I am launching a kernel ! 
" << std::endl;; const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs; const char * src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset; float * dst_dd_i_bs = dst_dd_i + i * dst->ne[0]; diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 848d93939ade4..331b573607e4c 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -291,8 +291,8 @@ template <> struct reorder_vec_dot_q_sycl { int v[q4_0_traits::vdr_mmvq]; int u[2 * q4_0_traits::vdr_mmvq]; -#pragma unroll +#pragma unroll for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) { v[i] = get_int_from_uint8(bq4_0, iqs + i); u[2 * i + 0] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i); @@ -348,19 +348,19 @@ template <> struct reorder_vec_dot_q_sycl { float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, const int8_t* q8_1_quant_ptr, const sycl::half2& q8_1_ds, const int & iqs, int nblocks) { - // const int ib = ibx_offset / (QK_K / 2); + const int ib = ibx_offset / (QK_K / 2); - // const uint8_t * base = static_cast(vbq); - // const uint8_t * qs = base + ibx_offset; - // const int total_qs_bytes = nblocks * (QK_K / 2); - // const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE; - // const ggml_half2 * dms = reinterpret_cast(base + d_offset); + const uint8_t * base = static_cast(vbq); + const uint8_t * qs = base + ibx_offset; + const int total_qs_bytes = nblocks * (QK_K / 2); + const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE; + const ggml_half2 * dms = reinterpret_cast(base + d_offset); - // const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); - // const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); - // const uint16_t * scales = (const uint16_t *) scs; + const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); + const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); + const uint16_t * scales = (const uint16_t *) scs; - // return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs); + return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs); } }; From acd80eca635921907fe0dca4d119c90f1ff051e9 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 27 May 2025 10:29:41 +0100 Subject: [PATCH 3/7] working q8 reorder commit --- ggml/src/ggml-sycl/ggml-sycl.cpp | 15 ++---------- ggml/src/ggml-sycl/mmvq.cpp | 2 +- ggml/src/ggml-sycl/vecdotq.hpp | 39 ++++++++++++++++++++++++++++---- 3 files changed, 38 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 195ce73c076c7..1abc530dffb22 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1418,19 +1418,8 @@ template static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor, const int kx, const int kx_padded, const sycl::nd_item<1> & it) { /* - quantize and reorders the resultant q8 tensor in a per row fashion - Each sub-group calculates one quant block - work_group_size = sub_group_size; - - |------------------------------ Matrix Pitch -------------------------| - |------- Matrix Width --------| - q_00 q_01 q_02 ..... q_0n-1 q_n ds00 ds01 ... ds0n/32 ... padding ... | - . . | - . . | - . . Matrix Height - . . | - . . | - q_n0 q_n1 q_n2 ..... q_nn-1 q_n dsn0 dsn1 ... dsnn/32 ... padding ... | + Quantizes and reorders the resultant q8 tensor in a per row fashion + Each sub-group calculates one quant block. i.e. 
QK8_1 quant values and the d and sum values */ auto subgroup_id = it.get_group(0); diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 2930a0a262df1..d579b1f995553 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -40,7 +40,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r // Y block index that aligns with ibx const int iby = i * block_type::block_to_q8_1_ratio(); const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1; - sycl::half2 q8_1_ds_ptr = *(const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2)); + const sycl::half2* q8_1_ds_ptr = (const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2)); #pragma unroll for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) { diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 331b573607e4c..fa258e4d4d106 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -285,7 +285,7 @@ template <> struct reorder_vec_dot_q_sycl { } __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, - const int8_t* q8_1_quant_ptr, const sycl::half2& q8_1_ds, const int & iqs, int /* nblocks */) { + const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int /* nblocks */) { const uint8_t * bq4_0 = static_cast(vbq) + ibx_offset; const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset)); int v[q4_0_traits::vdr_mmvq]; @@ -299,7 +299,7 @@ template <> struct reorder_vec_dot_q_sycl { u[2 * i + 1] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i + q4_0_traits::qi); } - return vec_dot_q4_0_q8_1_impl(v, u, d, q8_1_ds); + return vec_dot_q4_0_q8_1_impl(v, u, d, *q8_1_ds); }; }; @@ -347,7 +347,7 @@ template <> struct reorder_vec_dot_q_sycl { using q4_k_traits = typename q4_k_block::traits; float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, - const int8_t* q8_1_quant_ptr, const sycl::half2& q8_1_ds, const int & iqs, int nblocks) { + const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int nblocks) { const int ib = ibx_offset / (QK_K / 2); const uint8_t * base = static_cast(vbq); @@ -360,7 +360,38 @@ template <> struct reorder_vec_dot_q_sycl { const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); const uint16_t * scales = (const uint16_t *) scs; - return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs); + int v[2]; + int u[2 * QR4_K]; + float d8[QR4_K]; + + v[0] = q4[0]; + v[1] = q4[4]; + + uint16_t aux[2]; + const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2; + if (j < 2) { + aux[0] = scales[j + 0] & 0x3f3f; + aux[1] = scales[j + 2] & 0x3f3f; + } else { + aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2); + aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2); + } + + const uint8_t * sc = (const uint8_t *) aux; + const uint8_t * m = sc + 2; + + for (int i = 0; i < QR4_K; ++i) { + const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1; + sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i); + + d8[i] = ds_values[0]; + + const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4); + u[2 * i + 0] = q8[0]; + u[2 * i + 1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, *dms, d8); } }; From 03bd1a6cd9cef8ca9748f1ae08c121141f9554af Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 27 May 2025 10:31:53 +0100 Subject: [PATCH 
4/7] restored common.hpp --- ggml/src/ggml-sycl/common.hpp | 123 ++++++++++++++++++++++++++++++++-- 1 file changed, 117 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 618422099082e..15ee9dc69d149 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -13,8 +13,10 @@ #ifndef GGML_SYCL_COMMON_HPP #define GGML_SYCL_COMMON_HPP +#include #include #include +#include #include "dpct/helper.hpp" #include "ggml-sycl.h" @@ -44,11 +46,20 @@ extern int g_ggml_sycl_debug; extern int g_ggml_sycl_disable_optimize; extern int g_ggml_sycl_prioritize_dmmv; -#define GGML_SYCL_DEBUG(...) \ - do { \ - if (g_ggml_sycl_debug) \ - fprintf(stderr, __VA_ARGS__); \ - } while (0) +#if defined(__clang__) && __has_builtin(__builtin_expect) +// Hint the optimizer to pipeline the more likely following instruction in branches +# define LIKELY(expr) __builtin_expect(expr, true) +# define UNLIKELY(expr) __builtin_expect(expr, false) +#else +# define LIKELY(expr) (expr) +# define UNLIKELY(expr) (expr) +#endif + +#define GGML_SYCL_DEBUG(...) \ + do { \ + if (UNLIKELY(g_ggml_sycl_debug)) \ + fprintf(stderr, __VA_ARGS__); \ + } while (0) #define CHECK_TRY_ERROR(expr) \ [&]() { \ @@ -280,7 +291,22 @@ void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector str inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) { optimize_feature opt; - opt.reorder = true; + opt.reorder = + (arch == syclex::architecture::intel_gpu_dg1 || + arch == syclex::architecture::intel_gpu_acm_g10 || + arch == syclex::architecture::intel_gpu_acm_g11 || + arch == syclex::architecture::intel_gpu_acm_g12 || + arch == syclex::architecture::intel_gpu_pvc || + arch == syclex::architecture::intel_gpu_pvc_vg || + arch == syclex::architecture::intel_gpu_mtl_u || + arch == syclex::architecture::intel_gpu_mtl_s || + arch == syclex::architecture::intel_gpu_mtl_h || + arch == syclex::architecture::intel_gpu_arl_u || + arch == syclex::architecture::intel_gpu_arl_s || + arch == syclex::architecture::intel_gpu_arl_h || + arch == syclex::architecture::intel_gpu_bmg_g21 || + arch == syclex::architecture::intel_gpu_lnl_m + ); return opt; } @@ -456,6 +482,19 @@ static __dpct_inline__ float warp_reduce_max(float x, return x; } +/* Helper for Computing the linear offset of a ggml_tensor given +per-dimension sizes, strides, and indices */ +template +__dpct_inline__ size_t calculate_offset(const std::array & strides, const std::array & indices) { + size_t offset = 0; +#pragma unroll + for (int i = 0; i < N; i++) { + auto index_i = indices[i]; + offset += strides[i] * index_i; + } + return offset; +} + // Helper for vec loading aligned data template inline sycl::vec vec_aligned_load(const Tp* aligned_ptr) { @@ -475,4 +514,76 @@ constexpr size_t ceil_div(const size_t m, const size_t n) { } bool gpu_has_xmx(sycl::device &dev); + +template void debug_print_array(const std::string & prefix, const T array[N]) { + if (LIKELY(!g_ggml_sycl_debug)) { + return; + } + std::stringstream ss; + ss << prefix << "=["; + for (std::size_t i = 0; i < N - 1; ++i) { + ss << array[i] << ", "; + } + if constexpr (N > 0) { + ss << array[N - 1]; + } + ss << "]"; + GGML_SYCL_DEBUG("%s", ss.str().c_str()); +} + +inline void debug_print_tensor(const std::string & prefix, const ggml_tensor * tensor, + const std::string & suffix = "") { + if (LIKELY(!g_ggml_sycl_debug)) { + return; + } + GGML_SYCL_DEBUG("%s=", prefix.c_str()); + if (tensor) { + GGML_SYCL_DEBUG("'%s':type=%s", 
tensor->name, ggml_type_name(tensor->type)); + debug_print_array(";ne", tensor->ne); + debug_print_array(";nb", tensor->nb); + if (!ggml_is_contiguous(tensor)) { + GGML_SYCL_DEBUG(";strided"); + } + if (ggml_is_permuted(tensor)) { + GGML_SYCL_DEBUG(";permuted"); + } + } else { + GGML_SYCL_DEBUG("nullptr"); + } + GGML_SYCL_DEBUG("%s", suffix.c_str()); +} + +// Use scope_op_debug_print to log operations coming from running a model +struct scope_op_debug_print { + // Use string_views to avoid the cost of creating a string and concatenating them + // string_views must be alive for as long as the object is alive + // scope_op_debug_print are used with string literals in practice which are stored in constant space so always accessible + scope_op_debug_print(const std::string_view & func, const std::string_view & func_suffix, const ggml_tensor * dst, + std::size_t num_src, const std::string_view & suffix = "") : + func(func), + func_suffix(func_suffix) { + if (LIKELY(!g_ggml_sycl_debug)) { + return; + } + GGML_SYCL_DEBUG("[SYCL][OP] call %s%s:", func.data(), func_suffix.data()); + debug_print_tensor(" dst", dst); + if (dst) { + for (std::size_t i = 0; i < num_src; ++i) { + debug_print_tensor("\tsrc" + std::to_string(i), dst->src[i]); + } + } + GGML_SYCL_DEBUG("%s\n", suffix.data()); + } + + scope_op_debug_print(const std::string_view & func, const ggml_tensor * dst, std::size_t num_src, + const std::string_view & suffix = "") : + scope_op_debug_print(func, "", dst, num_src, suffix) {} + + ~scope_op_debug_print() { GGML_SYCL_DEBUG("[SYCL][OP] call %s%s done\n", func.data(), func_suffix.data()); } + + private: + std::string_view func; + std::string_view func_suffix; +}; + #endif // GGML_SYCL_COMMON_HPP From ade12bf596d1cf4f30a1d1fa803989c5ba99011b Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 27 May 2025 11:53:41 +0100 Subject: [PATCH 5/7] remove debug prints --- ggml/src/ggml-sycl/mmvq.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index c9957522004e8..ec5f67f6283af 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -1104,8 +1104,6 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens GGML_ABORT("fatal error"); } } - //std::cout << row_low << " " << row_high << " " << src1_ncols << std::endl; - //std::cout << "=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=" << std::endl; GGML_UNUSED(src1); GGML_UNUSED(dst); GGML_UNUSED(src1_ddf_i); From 79eede6ce87ccb86e9098bbd436b24128931f628 Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Thu, 29 May 2025 10:46:19 +0100 Subject: [PATCH 6/7] remove unnecessary headers and remove trailing whitespace --- ggml/src/ggml-sycl/ggml-sycl.cpp | 2 -- ggml/src/ggml-sycl/mmvq.cpp | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 80bcab498b842..942832f1b9167 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -33,7 +33,6 @@ #include #include "ggml-sycl.h" -#include "common.hpp" #include "ggml-impl.h" #include "ggml-backend-impl.h" @@ -45,7 +44,6 @@ #include "ggml-sycl/sycl_hw.hpp" #include "ggml-sycl/getrows.hpp" #include "ggml.h" -#include "presets.hpp" static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index ec5f67f6283af..80c780b209998 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp 
@@ -1,5 +1,4 @@
 #include "mmvq.hpp"
-#include <iostream>
 #include "ggml.h"
 #include "common.hpp"
 
@@ -40,7 +39,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
         // Y block index that aligns with ibx
         const int iby = i * block_type::block_to_q8_1_ratio();
         const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
-        const sycl::half2* q8_1_ds_ptr = (const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2)); 
+        const sycl::half2* q8_1_ds_ptr = (const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2));
 
 #pragma unroll
         for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {

From 5f8bc74377ed92f2ed9b6fc6818272be9b3888a9 Mon Sep 17 00:00:00 2001
From: Atharva Dubey
Date: Thu, 29 May 2025 10:46:55 +0100
Subject: [PATCH 7/7] Update ggml/src/ggml-sycl/ggml-sycl.cpp
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Alberto Cabrera Pérez
---
 ggml/src/ggml-sycl/ggml-sycl.cpp | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 942832f1b9167..5840ef7f45c54 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -2506,10 +2506,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
         dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
 
         if (src1_on_device && src1_is_contiguous) {
-            bool reorder_q8_tensor = false;
-            if (src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder) {
-                reorder_q8_tensor = true;
-            }
+            bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder;
             scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst, /*num_src=*/2,
                                                  " : converting src1 to Q8_1");
             quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream);
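
For reference, a minimal standalone sketch of how the reordered Q8_1 buffer is addressed, matching the offsets computed in quantize_and_reorder_q8_1() (quant_ptr / ds_ptr) and read back in mul_mat_vec_q_reorder() (q8_1_quant_ptr / q8_1_ds_ptr): within each row, all QK8_1-byte quant blocks are packed first, followed by one half2 (d, sum) pair per block, and the per-row pitch stays (kx_padded / QK8_1) * sizeof(block_q8_1). This is an illustration only, not part of the patches; the helper name row_q8_1_offsets, the half2_pair stand-in for sycl::half2, and the hard-coded sizes assume ggml's QK8_1 == 32, a 36-byte block_q8_1, and kx being a multiple of QK8_1.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Stand-in for sycl::half2 holding the (d, sum) pair of one quant block.
struct half2_pair { uint16_t d, s; };

constexpr int         QK8_1            = 32;                                           // elements per quant block
constexpr std::size_t BLOCK_Q8_1_BYTES = QK8_1 * sizeof(int8_t) + sizeof(half2_pair);  // 32 quants + (d, sum) = 36 bytes

// Byte offsets of quant block `col` of row `row`, and of its (d, sum) pair,
// inside the reordered buffer: quants first, then all ds pairs, per row.
static void row_q8_1_offsets(int kx, int kx_padded, int row, int col,
                             std::size_t & quant_off, std::size_t & ds_off) {
    const std::size_t row_pitch = std::size_t(kx_padded / QK8_1) * BLOCK_Q8_1_BYTES;
    const std::size_t row_base  = std::size_t(row) * row_pitch;
    quant_off = row_base + std::size_t(col) * QK8_1;                                  // packed quant bytes
    ds_off    = row_base + std::size_t(kx) + std::size_t(col) * sizeof(half2_pair);   // ds region starts at +kx
}

int main() {
    std::size_t q = 0, ds = 0;
    row_q8_1_offsets(/*kx=*/4096, /*kx_padded=*/4096, /*row=*/1, /*col=*/3, q, ds);
    std::printf("quant block at byte %zu, ds pair at byte %zu\n", q, ds);  // 4704 and 8716
    return 0;
}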