diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d450fe9a2f2f6..364c14ff2425a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -149,6 +149,7 @@ class vk_perf_logger; static void ggml_vk_destroy_buffer(vk_buffer& buf); static constexpr uint32_t mul_mat_vec_max_cols = 8; +static constexpr uint32_t p021_max_gqa_ratio = 8; enum vk_device_architecture { OTHER, @@ -231,6 +232,7 @@ struct vk_device_struct { bool uma; bool prefer_host_memory; bool float_controls_rte_fp16; + bool subgroup_add; bool subgroup_size_control; uint32_t subgroup_min_size; @@ -277,7 +279,7 @@ struct vk_device_struct { vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; - vk_pipeline pipeline_mul_mat_vec_p021_f16_f32; + vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; @@ -2265,7 +2267,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { + if (device->subgroup_add && device->subgroup_require_full_support) { + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true); + } else { + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true); + } + } ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -2471,13 +2479,15 @@ static vk_device ggml_vk_get_device(size_t idx) { vk::PhysicalDeviceDriverProperties driver_props; vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props; + vk::PhysicalDeviceVulkan11Properties vk11_props; vk::PhysicalDeviceVulkan12Properties vk12_props; vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; props2.pNext = &props3; props3.pNext = &subgroup_props; subgroup_props.pNext = &driver_props; - driver_props.pNext = &vk12_props; + driver_props.pNext = &vk11_props; + vk11_props.pNext = &vk12_props; VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props; @@ -2541,6 +2551,9 @@ static vk_device ggml_vk_get_device(size_t idx) { } device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; + device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; @@ -4627,9 +4640,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t d_sz = sizeof(float) * d_ne; + // With grouped query attention there are > 1 Q matrices per K, V matrix. + uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02; + if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) { + gqa_ratio = 1; + } + if (dryrun) { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1); + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1); return; } @@ -4653,8 +4672,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c // compute const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + + uint32_t workgroups_z = (uint32_t)ne12; + // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups + if (gqa_ratio > 1) { + workgroups_z /= gqa_ratio; + } + ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z }); } static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp index 1cc4996d393a2..48376637fb3e7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp @@ -12,6 +12,9 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; +layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + layout (push_constant) uniform parameter { uint ncols_x; @@ -37,25 +40,66 @@ void main() { const uint idst = channel*nrows_dst + row_dst; - tmp[tid] = 0.0f; + FLOAT_TYPE temp = 0.0f; - for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { - const uint col_x = col_x0 + tid; + // Detect alignment for vector loads + bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0; - if (col_x >= p.ncols_x) { - break; - } + for (uint col_x0 = 0; col_x0 < p.ncols_x;) { + + // Unroll 2x and do vec4 loads if aligned + const uint unroll_count = 2; + if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) { + [[unroll]] for (uint i = 0; i < unroll_count; ++i) { + const uint col_x = col_x0 + 4*tid; + + const uint row_y = col_x; + + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel*nrows_y + row_y; + + const vec4 av4 = vec4(data_a_v4[ix / 4]); + const vec4 bv4 = vec4(data_b_v4[iy / 4]); + + temp += dot(av4, bv4); + + col_x0 += 4*BLOCK_SIZE; + } + // do vec4 loads if aligned + } else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) { + const uint col_x = col_x0 + 4*tid; - const uint row_y = col_x; + const uint row_y = col_x; - const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; - const uint iy = channel*nrows_y + row_y; + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel*nrows_y + row_y; - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + const vec4 av4 = vec4(data_a_v4[ix / 4]); + const vec4 bv4 = vec4(data_b_v4[iy / 4]); - tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); + temp += dot(av4, bv4); + + col_x0 += 4*BLOCK_SIZE; + } else { + const uint col_x = col_x0 + tid; + if (col_x >= p.ncols_x) { + break; + } + + const uint row_y = col_x; + + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel*nrows_y + row_y; + + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp); + col_x0 += BLOCK_SIZE; + } } + tmp[tid] = temp; + // sum up partial sums and write back result barrier(); [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp index 9b443807d8781..7aa070eebdf72 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp @@ -2,16 +2,25 @@ #extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_shader_16bit_storage : require +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_arithmetic : enable +#endif -#define BLOCK_SIZE 32 #define FLOAT_TYPE float -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; +layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + +layout(constant_id = 0) const int BLOCK_SIZE = 32; +// gqa_ratio is in the range [1,8] +layout(constant_id = 1) const uint gqa_ratio = 1; + layout (push_constant) uniform parameter { uint ncols_x; @@ -22,52 +31,124 @@ layout (push_constant) uniform parameter uint d_offset; } p; -shared FLOAT_TYPE tmp[BLOCK_SIZE]; +#if !USE_SUBGROUP_ADD +shared FLOAT_TYPE tmp[8][BLOCK_SIZE]; +#endif void main() { const uint tid = gl_LocalInvocationID.x; const uint row_x = gl_GlobalInvocationID.y; - const uint channel = gl_GlobalInvocationID.z; - const uint channel_x = channel / (p.nchannels_y / p.nchannels_x); + + uint channel, channel_x; + + // When gqa_ratio > 1, each invocation does multiple rows. + // The row in the A matrix is starting from channel / gqa_ratio and the + // rows in the B matrix are [channel, channel+gqa_ratio). + // When gpa_ratio is 1, each invocation does one row. + if (gqa_ratio > 1) { + channel_x = gl_GlobalInvocationID.z; + channel = channel_x * gqa_ratio; + } else { + channel = gl_GlobalInvocationID.z; + channel_x = channel / (p.nchannels_y / p.nchannels_x);; + } const uint nrows_y = p.ncols_x; const uint nrows_dst = p.nrows_x; const uint row_dst = row_x; - tmp[tid] = FLOAT_TYPE(0.0f); + FLOAT_TYPE temp[8]; + [[unroll]] for (uint i = 0; i < 8; ++i) { + temp[i] = FLOAT_TYPE(0.0f); + } + + // Detect alignment for vector loads + bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0; for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { - const uint col_x = col_x0 + tid; - if (col_x >= p.ncols_x) { - break; - } + // Use vec4 loads if aligned + if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) { - // x is transposed and permuted - const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + uint col_x = col_x0 + 4*tid; + const uint row_y = col_x; - const uint row_y = col_x; + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const vec4 av4 = vec4(data_a_v4[ix / 4]); - // y is not transposed but permuted - const uint iy = channel*nrows_y + row_y; + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // y is not transposed but permuted + const uint iy = (channel + c)*nrows_y + row_y; - tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); - } + vec4 bv4 = data_b_v4[iy / 4]; + temp[c] += dot(av4, bv4); + } + + col_x0 += 3*BLOCK_SIZE; + } else { + const uint col_x = col_x0 + tid; + + if (col_x >= p.ncols_x) { + break; + } - // dst is not transposed and not permuted - const uint idst = channel*nrows_dst + row_dst; + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + const uint row_y = col_x; + + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // y is not transposed but permuted + const uint iy = (channel + c)*nrows_y + row_y; + + temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]); + } + } + } + +#if USE_SUBGROUP_ADD + // reduce vec4 at a time + vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]); + t = subgroupAdd(t); + temp[0] = t[0]; + temp[1] = t[1]; + temp[2] = t[2]; + temp[3] = t[3]; + if (gqa_ratio > 4) { + t = vec4(temp[4], temp[5], temp[6], temp[7]); + t = subgroupAdd(t); + temp[4] = t[0]; + temp[5] = t[1]; + temp[6] = t[2]; + temp[7] = t[3]; + } +#else + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + tmp[c][tid] = temp[c]; + } // sum up partial sums and write back result barrier(); [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { if (tid < s) { - tmp[tid] += tmp[tid + s]; + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + temp[c] += tmp[c][tid + s]; + tmp[c][tid] = temp[c]; + } } barrier(); } + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + temp[c] = tmp[c][tid]; + } +#endif if (tid == 0) { - dst[idst] = tmp[0]; + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // dst is not transposed and not permuted + const uint idst = (channel + c)*nrows_dst + row_dst; + dst[idst] = temp[c]; + } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index eb2ad63ff6bf0..d883eda045c68 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -426,8 +426,9 @@ void process_shaders() { } } - string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); + string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); // Norms string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d48cd21723155..9d7847d21b27f 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1964,9 +1964,10 @@ struct test_mul_mat : public test_case { const std::array bs; // dims 3 and 4 const std::array nr; // repeat in dims 3 and 4 const std::array per; // permutation of dimensions + const bool v; // whether a is a non-contiguous view std::string vars() override { - return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per); + return VARS_TO_STR9(type_a, type_b, m, n, k, bs, nr, per, v); } double max_nmse_err() override { @@ -1986,8 +1987,9 @@ struct test_mul_mat : public test_case { int64_t m = 32, int64_t n = 32, int64_t k = 32, std::array bs = {10, 10}, std::array nr = {2, 2}, - std::array per = {0, 1, 2, 3}) - : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {} + std::array per = {0, 1, 2, 3}, + bool v = false) + : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), v(v) {} ggml_tensor * build_graph(ggml_context * ctx) override { // C^T = A * B^T: (k, m) * (k, n) => (m, n) @@ -1997,6 +1999,7 @@ struct test_mul_mat : public test_case { const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3); if (npermuted > 0) { GGML_ASSERT(npermuted == 2); + GGML_ASSERT(!v); // not handled GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0); GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0); @@ -2020,7 +2023,13 @@ struct test_mul_mat : public test_case { ggml_set_name(a, "a_permuted"); ggml_set_name(b, "b_permuted"); } else { - a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]); + + if (v) { + a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0], bs[1]); + a = ggml_view_4d(ctx, a, k, m, bs[0], bs[1], a->nb[1], a->nb[2], a->nb[3], 0); + } else { + a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]); + } b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); if (!ggml_is_quantized(type_a)) { if (bs[1] == 1 && nr[1] == 1) { @@ -4176,6 +4185,17 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 45, 128, { 8, 1}, {4, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1})); + for (auto bs : {1,2,4,8}) { + for (auto nr : {1,4}) { + for (uint32_t m = 0; m < 2; ++m) { + for (uint32_t k = 0; k < 2; ++k) { + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true)); + } + } + } + } + // sycl backend will limit task global_range < MAX_INT // test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion) // however this case needs to alloc more memory which may fail in some devices (Intel Arc770, etc.) @@ -4444,6 +4464,9 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1})); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true)); + for (int bs : {1, 2, 3, 4, 5, 8, 512}) { for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32}) {