From 44fae395e0e2967507a70e95fce9b7da59daa297 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 17 Dec 2024 10:23:07 +0000 Subject: [PATCH 1/8] Grouped gemm simple code refactor --- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 3 +- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 8 +-- .../run_grouped_gemm_example.inc | 20 ++++---- example/ck_tile/17_grouped_gemm/utils.hpp | 38 -------------- include/ck_tile/core.hpp | 1 - include/ck_tile/core/arch/arch.hpp | 26 +++++++++- .../core/utility/amd_address_space.hpp | 37 -------------- include/ck_tile/host/host_tensor.hpp | 33 +++++++++++++ include/ck_tile/ops/gemm.hpp | 1 + .../ops/gemm/kernel/gemm_offset_block.hpp | 49 +++++++++++++++++++ .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 8 ++- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 35 +++++++++---- 12 files changed, 151 insertions(+), 108 deletions(-) delete mode 100644 example/ck_tile/17_grouped_gemm/utils.hpp delete mode 100644 include/ck_tile/core/utility/amd_address_space.hpp create mode 100644 include/ck_tile/ops/gemm/kernel/gemm_offset_block.hpp diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 14f3b4a5b8..6b51f696a3 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -15,7 +15,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/host.hpp" #include "grouped_gemm.hpp" -#include "utils.hpp" namespace { @@ -102,7 +101,7 @@ using Kernel = ck_tile::GroupedGemmKernel>; }; // namespace -std::size_t GetWorkspaceSize(const std::vector& gemm_descs) +std::size_t get_workspace_size(const std::vector& gemm_descs) { return ::Kernel::GetWorkSpaceSize(gemm_descs); } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 94af4711d1..db1cbb0673 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -46,8 +46,8 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -std::size_t GetWorkspaceSize(const std::vector& gemm_descs); +std::size_t get_workspace_size(const std::vector& gemm_descs); -float grouped_gemm_calc(const std::vector& gemm_descs, - const ck_tile::stream_config& s, - void* p_workspace_); +float grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* p_workspace_); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index cd5b1c2864..67d7856040 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -11,7 +11,7 @@ float invoke_gemm(int n_warmup, { ck_tile::DeviceMem gemm_workspace; - gemm_workspace.Realloc(GetWorkspaceSize(args)); + gemm_workspace.Realloc(get_workspace_size(args)); float ave_time = grouped_gemm( args, @@ -100,16 +100,16 @@ int run_grouped_gemm_example_with_layouts(int argc, const ck_tile::index_t N = Ns[i]; const ck_tile::index_t K = Ks[i]; - stride_As[i] = f_get_default_stride(M, N, stride_As[i], a_layout); - stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], b_layout); - stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{}); + stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], a_layout); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], b_layout); + stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], CLayout{}); - 
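// A minimal sketch of what the relocated helpers compute, assuming `Row` is
// ck_tile::tensor_layout::gemm::RowMajor (the alias used later in this example):
//
//   using Row = ck_tile::tensor_layout::gemm::RowMajor;
//   // A zero stride requests the packed leading dimension, i.e. K for a row-major M x K matrix.
//   const auto lda    = ck_tile::get_default_stride(M, K, /*stride=*/0, Row{});  // == K
//   // The descriptor then carries lengths {M, K} and strides {K, 1}.
//   const auto a_desc = ck_tile::host_tensor_descriptor(M, K, lda, Row{});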
a_m_k_tensors.push_back( - ck_tile::HostTensor(f_host_tensor_descriptor(M, K, stride_As[i], a_layout))); - b_k_n_tensors.push_back( - ck_tile::HostTensor(f_host_tensor_descriptor(K, N, stride_Bs[i], b_layout))); + a_m_k_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_As[i], a_layout))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], b_layout))); c_m_n_tensors.push_back(ck_tile::HostTensor( - f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{}))); + ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], CLayout{}))); std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc @@ -150,7 +150,7 @@ int run_grouped_gemm_example_with_layouts(int argc, for(int i = 0; i < group_count; ++i) { ck_tile::HostTensor c_m_n_host_ref( - f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{})); + ck_tile::host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{})); c_m_n_host_ref.SetZero(); ck_tile::reference_gemm( a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); diff --git a/example/ck_tile/17_grouped_gemm/utils.hpp b/example/ck_tile/17_grouped_gemm/utils.hpp deleted file mode 100644 index bb3cdf9fdc..0000000000 --- a/example/ck_tile/17_grouped_gemm/utils.hpp +++ /dev/null @@ -1,38 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -template -constexpr auto -f_host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout) -{ - using namespace ck_tile::literals; - - if constexpr(std::is_same_v) - { - return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); - } -} -template -constexpr auto -f_get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout) -{ - if(stride == 0) - { - if constexpr(std::is_same_v) - { - return col; - } - else - { - return row; - } - } - else - return stride; -} diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 41f3383c7f..3cf0c2595d 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -54,7 +54,6 @@ #include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/tensor/tile_window_utils.hpp" #include "ck_tile/core/tensor/update_tile.hpp" -#include "ck_tile/core/utility/amd_address_space.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional_with_tuple.hpp" diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index afcf982a63..d82029dd79 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -109,4 +109,28 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0) #endif } +#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4))) + +template +__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p) +{ + // cast a pointer in "Constant" address space (4) to "Generic" address space (0) + // only c-style pointer cast seems be able to be compiled +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + return (T*)p; // NOLINT(old-style-cast) +#pragma clang diagnostic pop +} + +template +__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p) +{ + // cast a pointer in "Generic" address space (0) to "Constant" address space (4) + // only c-style pointer cast seems be able to be compiled +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast) +#pragma clang diagnostic pop +} + } // namespace ck_tile diff --git a/include/ck_tile/core/utility/amd_address_space.hpp b/include/ck_tile/core/utility/amd_address_space.hpp deleted file mode 100644 index cb242bf0d5..0000000000 --- a/include/ck_tile/core/utility/amd_address_space.hpp +++ /dev/null @@ -1,37 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core/config.hpp" - -// Address Space for AMDGCN -// https://llvm.org/docs/AMDGPUUsage.html#address-space - -namespace ck_tile { - -#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4))) - -template -__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p) -{ - // cast a pointer in "Constant" address space (4) to "Generic" address space (0) - // only c-style pointer cast seems be able to be compiled -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wold-style-cast" - return (T*)p; // NOLINT(old-style-cast) -#pragma clang diagnostic pop -} - -template -__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p) -{ - // cast a pointer in "Generic" address space (0) to "Constant" address space (4) - // only c-style pointer cast seems be able to be compiled -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wold-style-cast" - return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast) -#pragma clang diagnostic pop -} - -} // namespace ck_tile diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index 3902cad178..10313e4207 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -678,4 +678,37 @@ struct HostTensor Descriptor mDesc; Data mData; }; + +template +auto host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout) +{ + using namespace ck_tile::literals; + + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } +} +template +auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout) +{ + if(stride == 0) + { + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; +} + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 2d38ef5925..6d7dc50806 100644 --- a/include/ck_tile/ops/gemm.hpp +++ 
b/include/ck_tile/ops/gemm.hpp @@ -26,6 +26,7 @@ #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_offset_block.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/gemm_offset_block.hpp b/include/ck_tile/ops/gemm/kernel/gemm_offset_block.hpp new file mode 100644 index 0000000000..50c13efb55 --- /dev/null +++ b/include/ck_tile/ops/gemm/kernel/gemm_offset_block.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +template +struct OffsettedBlockToCTileMap +{ + using tile_partitioner_type = TilePartitioner_; + + __host__ __device__ OffsettedBlockToCTileMap(ck_tile::index_t B2CkTileMap, + ck_tile::index_t M, + ck_tile::index_t N) + : B2CkTileMap_{B2CkTileMap}, M_{M}, N_{N} + { + } + + __host__ __device__ constexpr auto CalculateBottomIndex(const ck_tile::index_t idx_top) const + { + ck_tile::index_t block_1d_id = idx_top; + + const auto M0 = ck_tile::integer_divide_ceil(M_, tile_partitioner_type::MPerBlock); + const auto N0 = ck_tile::integer_divide_ceil(N_, tile_partitioner_type::NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0); + + block_1d_id = block_1d_id % (M0 * N0); + + ck_tile::index_t idx_N0 = block_1d_id % N0; + ck_tile::index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % B2CkTileMap_) ? B2CkTileMap_ : M0 % B2CkTileMap_; + + ck_tile::index_t idx_M00 = idx_M0 / B2CkTileMap_; + ck_tile::index_t idx_M01 = idx_M0 % B2CkTileMap_; + ck_tile::index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * B2CkTileMap_, + idx_N0_M01_local / M01_adapt); + } + + ck_tile::index_t B2CkTileMap_; + ck_tile::index_t M_; + ck_tile::index_t N_; +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 8ffe681f90..11909d3e58 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -62,12 +62,10 @@ struct GemmTile1DPartitioner return integer_divide_ceil(K, KPerBlock); } - CK_TILE_DEVICE auto operator()(index_t blockOffset, index_t NBlockSize) + CK_TILE_DEVICE auto operator()() { - index_t iM = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) / - GetNBlock(NBlockSize) * MPerBlock); - index_t iN = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) % - GetNBlock(NBlockSize) * NPerBlock); + index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * MPerBlock); + index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.x * NPerBlock); return make_tuple(iM, iN); } }; diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index f24fc47afc..28ee785951 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -8,7 +8,6 @@ #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/utility/literals.hpp" -#include "ck_tile/core/utility/amd_address_space.hpp" #include 
"ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" @@ -32,14 +31,19 @@ struct GroupedGemmHostArgs template struct GroupedGemmKernel { + using Hargs = GroupedGemmHostArgs; using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; using CLayout = remove_cvref_t; + using Block2ETileMap = OffsettedBlockToCTileMap; static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + // Block2CTileMap configuration parameter. + static constexpr index_t B2E_M01 = 8; + using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; @@ -47,12 +51,19 @@ struct GroupedGemmKernel struct GemmTransKernelArg { GroupedGemmHostArgs group_karg; + Block2ETileMap block_2_ctile_map_; ck_tile::index_t block_start; ck_tile::index_t block_end; GemmTransKernelArg() = default; - GemmTransKernelArg(GroupedGemmHostArgs&& karg, index_t bl_start, index_t bl_end) - : group_karg{karg}, block_start{bl_start}, block_end{bl_end} + GemmTransKernelArg(GroupedGemmHostArgs&& karg, + Block2ETileMap block_2_ctile_map_karg, + index_t bl_start, + index_t bl_end) + : group_karg{karg}, + block_2_ctile_map_{block_2_ctile_map_karg}, + block_start{bl_start}, + block_end{bl_end} { } }; @@ -64,8 +75,6 @@ struct GroupedGemmKernel __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); } - using Hargs = GroupedGemmHostArgs; - __host__ static constexpr auto GridSize(const std::vector& gemm_descs) { index_t grid_size = 0; @@ -100,13 +109,15 @@ struct GroupedGemmKernel const index_t stride_c = gemm_descs[i].stride_C; const auto dim3 = TilePartitioner::GridSize(M, N); - const index_t grid_size_grp = dim3.x * 1 * 1; + const index_t grid_size_grp = dim3.x; const index_t block_start = grid_size; const index_t block_end = grid_size + grid_size_grp; grid_size += grid_size_grp; + auto grouped_block_2_ctile_map = Block2ETileMap(B2E_M01, M, N); + auto karg = GroupedGemmHostArgs{type_convert(gemm_descs[i].a_ptr), type_convert(gemm_descs[i].b_ptr), type_convert(gemm_descs[i].c_ptr), @@ -117,7 +128,8 @@ struct GroupedGemmKernel stride_b, stride_c}; - gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); + gemm_kernel_args_.emplace_back( + std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); } return gemm_kernel_args_; @@ -128,9 +140,12 @@ struct GroupedGemmKernel return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } - CK_TILE_DEVICE void Run(const Hargs& kargs, const index_t block_start) const + CK_TILE_DEVICE void Run(const Hargs& kargs, const Block2ETileMap& block_2_tile_map) const { - const auto [i_m, i_n] = TilePartitioner{}(block_start, kargs.N); + const auto [i_M, i_N] = block_2_tile_map.CalculateBottomIndex(ck_tile::get_block_1d_id()); + index_t i_m = __builtin_amdgcn_readfirstlane(i_M * TilePartitioner::MPerBlock); + index_t i_n = __builtin_amdgcn_readfirstlane(i_N * TilePartitioner::NPerBlock); + // options const ADataType* a_start = static_cast(kargs.a_ptr); const BDataType* b_start = static_cast(kargs.b_ptr); @@ -303,7 +318,7 @@ struct GroupedGemmKernel group_id = index_t((left + right) / 2); } - Run(gemm_desc_ptr[group_id].group_karg, gemm_desc_ptr[group_id].block_start); + Run(gemm_desc_ptr[group_id].group_karg, gemm_desc_ptr[group_id].block_2_ctile_map_); } }; From dc48a147cb9f6c1aa950234273eafb6f6fb0c5e3 Mon Sep 17 
00:00:00 2001 From: Mateusz Ozga Date: Sun, 5 Jan 2025 16:24:49 +0000 Subject: [PATCH 2/8] Offset invoker --- include/ck_tile/core/arch/arch.hpp | 43 +++++++--- include/ck_tile/ops/gemm.hpp | 3 +- .../ops/gemm/kernel/gemm_offset_block.hpp | 49 ----------- .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 49 +++++++++-- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 81 ++++++++----------- 5 files changed, 110 insertions(+), 115 deletions(-) delete mode 100644 include/ck_tile/ops/gemm/kernel/gemm_offset_block.hpp diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index d82029dd79..6556f4a8fb 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -12,18 +12,37 @@ namespace ck_tile { -enum struct address_space_enum +template +struct safe_underlying_type; + +template +struct safe_underlying_type { - generic, + using type = std::underlying_type_t; +}; + +template +struct safe_underlying_type +{ + using type = void; +}; + +template +using safe_underlying_type_t = typename safe_underlying_type::value>::type; + +enum struct address_space_enum : std::uint8_t +{ + generic = 0, global, lds, sgpr, - vgpr, + constant, + vgpr }; -enum struct memory_operation_enum +enum struct memory_operation_enum : std::uint8_t { - set, + set = 0, atomic_add, atomic_max, add @@ -109,7 +128,13 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0) #endif } -#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4))) +#define CK_CONSTANT_ADDRESS_SPACE \ + __attribute__((address_space( \ + static_cast>(address_space_enum::constant)))) + +#define CK_GENERIC_ADDRESS_SPACE \ + __attribute__((address_space( \ + static_cast>(address_space_enum::generic)))) template __device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p) @@ -118,7 +143,7 @@ __device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* // only c-style pointer cast seems be able to be compiled #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wold-style-cast" - return (T*)p; // NOLINT(old-style-cast) + return (T*)(p); // NOLINT(old-style-cast) #pragma clang diagnostic pop } @@ -126,7 +151,7 @@ template __host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p) { // cast a pointer in "Generic" address space (0) to "Constant" address space (4) - // only c-style pointer cast seems be able to be compiled + // only c-style pointer cast seems be able to be compiled; #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wold-style-cast" return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast) diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 6d7dc50806..5bbe0601b7 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
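// A short sketch of how the strongly-typed address-space enum and the pointer casts above
// fit together; `kargs_dev_buf` is an illustrative name, not taken from this patch:
//
//   // address_space_enum::constant has underlying value 4, so CK_CONSTANT_ADDRESS_SPACE
//   // expands to __attribute__((address_space(4))), the AMDGPU "constant" address space.
//   static_assert(static_cast<ck_tile::safe_underlying_type_t<ck_tile::address_space_enum>>(
//                     ck_tile::address_space_enum::constant) == 4);
//
//   // Host side: tag the generic pointer to the kernel-argument buffer as "constant"...
//   const auto* kargs_const   = ck_tile::cast_pointer_to_constant_address_space(kargs_dev_buf);
//   // ...device side: turn it back into a generic pointer before dereferencing.
//   const auto* kargs_generic = ck_tile::cast_pointer_to_generic_address_space(kargs_const);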
#pragma once @@ -26,7 +26,6 @@ #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" -#include "ck_tile/ops/gemm/kernel/gemm_offset_block.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/gemm_offset_block.hpp b/include/ck_tile/ops/gemm/kernel/gemm_offset_block.hpp deleted file mode 100644 index 50c13efb55..0000000000 --- a/include/ck_tile/ops/gemm/kernel/gemm_offset_block.hpp +++ /dev/null @@ -1,49 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { -template -struct OffsettedBlockToCTileMap -{ - using tile_partitioner_type = TilePartitioner_; - - __host__ __device__ OffsettedBlockToCTileMap(ck_tile::index_t B2CkTileMap, - ck_tile::index_t M, - ck_tile::index_t N) - : B2CkTileMap_{B2CkTileMap}, M_{M}, N_{N} - { - } - - __host__ __device__ constexpr auto CalculateBottomIndex(const ck_tile::index_t idx_top) const - { - ck_tile::index_t block_1d_id = idx_top; - - const auto M0 = ck_tile::integer_divide_ceil(M_, tile_partitioner_type::MPerBlock); - const auto N0 = ck_tile::integer_divide_ceil(N_, tile_partitioner_type::NPerBlock); - - block_1d_id = block_1d_id % (M0 * N0); - - block_1d_id = block_1d_id % (M0 * N0); - - ck_tile::index_t idx_N0 = block_1d_id % N0; - ck_tile::index_t idx_M0 = block_1d_id / N0; - - const auto M01_adapt = (idx_M0 < M0 - M0 % B2CkTileMap_) ? B2CkTileMap_ : M0 % B2CkTileMap_; - - ck_tile::index_t idx_M00 = idx_M0 / B2CkTileMap_; - ck_tile::index_t idx_M01 = idx_M0 % B2CkTileMap_; - ck_tile::index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; - - return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * B2CkTileMap_, - idx_N0_M01_local / M01_adapt); - } - - ck_tile::index_t B2CkTileMap_; - ck_tile::index_t M_; - ck_tile::index_t N_; -}; -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 11909d3e58..9606ccc181 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "ck_tile/core.hpp" namespace ck_tile { + template struct GemmTilePartitioner { @@ -23,7 +24,7 @@ struct GemmTilePartitioner return dim3(GridDimX, GridDimY, GridDimZ); } - CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) + CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) -> index_t { return integer_divide_ceil(K, kK); } @@ -45,28 +46,60 @@ struct GemmTile1DPartitioner static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N) + CK_TILE_HOST static constexpr auto + GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3 { index_t GridDimX = (M + MPerBlock - 1) / MPerBlock; index_t GridDimY = (N + NPerBlock - 1) / NPerBlock; return dim3(GridDimX * GridDimY, 1, 1); } - CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N) + CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N) -> index_t { return integer_divide_ceil(N, NPerBlock); } - CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) + CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) -> index_t { return integer_divide_ceil(K, KPerBlock); } - CK_TILE_DEVICE auto operator()() + CK_TILE_DEVICE auto + operator()(index_t blockIdx, index_t NBlockSize) noexcept(noexcept(GetNBlock(NBlockSize) != 0)) + -> tuple { - index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * MPerBlock); - index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.x * NPerBlock); + const index_t NBlock = GetNBlock(NBlockSize); + + const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlock); + const index_t iN = __builtin_amdgcn_readfirstlane(fast_mod(blockIdx, NBlock)); + return make_tuple(iM, iN); } + + private: + template + CK_TILE_DEVICE auto fast_mod(const TType input, const TType ceil) noexcept(noexcept(ceil != 0)) + -> std::enable_if_t::is_integer, TType> + { + return input >= ceil ? input - (input / ceil) * ceil : input; + } +}; + +template +struct InvokeOffsetCallculationFor1DPartitioner +{ + template + CK_TILE_DEVICE constexpr auto operator()(TType block_start, TType NBlockSize) + -> const std::enable_if_t::is_integer, tuple> + { + const auto [iM, iN] = PartitionerFn{}(blockIdx.x - block_start, NBlockSize); + const auto iM_to_img_corr = iM * PartitionerFn::MPerBlock; + const auto iN_to_img_corr = iN * PartitionerFn::NPerBlock; + + const TType i_m = __builtin_amdgcn_readfirstlane(iM_to_img_corr); + const TType i_n = __builtin_amdgcn_readfirstlane(iN_to_img_corr); + + return make_tuple(i_m, i_n); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 28ee785951..bf821415f0 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
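// A worked example of the 1D partitioning above (tile sizes chosen only for illustration):
// with NPerBlock = 128 and N = 512, GetNBlock(N) == 4. For the block whose offset inside
// its group is blockIdx.x - block_start == 6, the partitioner yields
//   iM = 6 / 4 == 1,   iN = fast_mod(6, 4) == 2,
// and InvokeOffsetCallculationFor1DPartitioner scales these by the tile sizes, so the block
// works on the output tile whose top-left element is
//   (i_m, i_n) = (1 * MPerBlock, 2 * NPerBlock) = (128, 256)   // assuming MPerBlock = 128.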
#pragma once @@ -31,19 +31,14 @@ struct GroupedGemmHostArgs template struct GroupedGemmKernel { - using Hargs = GroupedGemmHostArgs; using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; using CLayout = remove_cvref_t; - using Block2ETileMap = OffsettedBlockToCTileMap; static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; - // Block2CTileMap configuration parameter. - static constexpr index_t B2E_M01 = 8; - using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; @@ -51,31 +46,25 @@ struct GroupedGemmKernel struct GemmTransKernelArg { GroupedGemmHostArgs group_karg; - Block2ETileMap block_2_ctile_map_; ck_tile::index_t block_start; ck_tile::index_t block_end; GemmTransKernelArg() = default; - GemmTransKernelArg(GroupedGemmHostArgs&& karg, - Block2ETileMap block_2_ctile_map_karg, - index_t bl_start, - index_t bl_end) - : group_karg{karg}, - block_2_ctile_map_{block_2_ctile_map_karg}, - block_start{bl_start}, - block_end{bl_end} + GemmTransKernelArg(GroupedGemmHostArgs&& karg, index_t bl_start, index_t bl_end) + : group_karg{karg}, block_start{bl_start}, block_end{bl_end} { } }; - __host__ static size_t GetWorkSpaceSize(const std::vector& gemm_descs) + __host__ static auto GetWorkSpaceSize(const std::vector& gemm_descs) + -> std::size_t { return gemm_descs.size() * sizeof(GemmTransKernelArg); } - __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + __host__ static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); } - __host__ static constexpr auto GridSize(const std::vector& gemm_descs) + __host__ static constexpr auto GridSize(const std::vector& gemm_descs) { index_t grid_size = 0; for(const auto& it_desc : gemm_descs) @@ -86,7 +75,8 @@ struct GroupedGemmKernel return dim3(grid_size, 1, 1); } - CK_TILE_HOST static auto MakeKargs(const std::vector& gemm_descs) + CK_TILE_HOST static auto MakeKargs(const std::vector& gemm_descs) + -> std::vector { std::vector gemm_kernel_args_; index_t group_count = ck_tile::type_convert(gemm_descs.size()); @@ -116,8 +106,6 @@ struct GroupedGemmKernel grid_size += grid_size_grp; - auto grouped_block_2_ctile_map = Block2ETileMap(B2E_M01, M, N); - auto karg = GroupedGemmHostArgs{type_convert(gemm_descs[i].a_ptr), type_convert(gemm_descs[i].b_ptr), type_convert(gemm_descs[i].c_ptr), @@ -128,35 +116,34 @@ struct GroupedGemmKernel stride_b, stride_c}; - gemm_kernel_args_.emplace_back( - std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); + gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); } return gemm_kernel_args_; } - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t { return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } - CK_TILE_DEVICE void Run(const Hargs& kargs, const Block2ETileMap& block_2_tile_map) const + CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs) const { - const auto [i_M, i_N] = block_2_tile_map.CalculateBottomIndex(ck_tile::get_block_1d_id()); - index_t i_m = __builtin_amdgcn_readfirstlane(i_M * TilePartitioner::MPerBlock); - index_t i_n = __builtin_amdgcn_readfirstlane(i_N * TilePartitioner::NPerBlock); + const auto [i_m, i_n] = InvokeOffsetCallculationFor1DPartitioner{}( + kargs.block_start, kargs.group_karg.N); // options - const ADataType* a_start = 
static_cast(kargs.a_ptr); - const BDataType* b_start = static_cast(kargs.b_ptr); + const ADataType* a_start = static_cast(kargs.group_karg.a_ptr); + const BDataType* b_start = static_cast(kargs.group_karg.b_ptr); + // Convert pointers to tensor views auto a_tensor_view = [&]() { if constexpr(std::is_same_v) { return make_naive_tensor_view( a_start, - make_tuple(kargs.M, kargs.K), - make_tuple(kargs.stride_A, 1), + make_tuple(kargs.group_karg.M, kargs.group_karg.K), + make_tuple(kargs.group_karg.stride_A, 1), number{}, number<1>{}); } @@ -164,8 +151,8 @@ struct GroupedGemmKernel { return make_naive_tensor_view( a_start, - make_tuple(kargs.M, kargs.K), - make_tuple(1, kargs.stride_A), + make_tuple(kargs.group_karg.M, kargs.group_karg.K), + make_tuple(1, kargs.group_karg.stride_A), number<1>{}, number<1>{}); } @@ -176,8 +163,8 @@ struct GroupedGemmKernel { return make_naive_tensor_view( b_start, - make_tuple(kargs.N, kargs.K), - make_tuple(1, kargs.stride_B), + make_tuple(kargs.group_karg.N, kargs.group_karg.K), + make_tuple(1, kargs.group_karg.stride_B), number<1>{}, number<1>{}); } @@ -185,8 +172,8 @@ struct GroupedGemmKernel { return make_naive_tensor_view( b_start, - make_tuple(kargs.N, kargs.K), - make_tuple(kargs.stride_B, 1), + make_tuple(kargs.group_karg.N, kargs.group_karg.K), + make_tuple(kargs.group_karg.stride_B, 1), number{}, number<1>{}); } @@ -240,20 +227,20 @@ struct GroupedGemmKernel // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; - const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K); + const index_t num_loop = TilePartitioner::GetLoopNum(kargs.group_karg.K); // Run GEMM cooperatively by whole wokrgroup. auto c_block_tile = GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr); - CDataType* c_start = static_cast(kargs.c_ptr); + CDataType* c_start = static_cast(kargs.group_karg.c_ptr); auto c_tensor_view = [&]() { if constexpr(std::is_same_v) { return make_naive_tensor_view( c_start, - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_C, 1), + make_tuple(kargs.group_karg.M, kargs.group_karg.N), + make_tuple(kargs.group_karg.stride_C, 1), number{}, number<1>{}); } @@ -261,8 +248,8 @@ struct GroupedGemmKernel { return make_naive_tensor_view( c_start, - make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_C), + make_tuple(kargs.group_karg.M, kargs.group_karg.N), + make_tuple(1, kargs.group_karg.stride_C), number<1>{}, number<1>{}); } @@ -293,7 +280,7 @@ struct GroupedGemmKernel } CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - int group_count) const + index_t group_count) const { const index_t block_id = ck_tile::get_block_1d_id(); const auto gemm_desc_ptr = reinterpret_cast( @@ -315,10 +302,10 @@ struct GroupedGemmKernel { left = group_id; } - group_id = index_t((left + right) / 2); + group_id = index_t((left + right) >> 1); } - Run(gemm_desc_ptr[group_id].group_karg, gemm_desc_ptr[group_id].block_2_ctile_map_); + Run(gemm_desc_ptr[group_id]); } }; From ae36a63e08e560eaa3fcc31c2d68b1087d4d3cf5 Mon Sep 17 00:00:00 2001 From: Mateusz Ozga Date: Mon, 6 Jan 2025 14:07:26 +0000 Subject: [PATCH 3/8] Invoke generic Run, and replace name of parrtitioner variable --- example/ck_tile/03_gemm/gemm_basic.cpp | 6 +- .../ck_tile/16_batched_gemm/batched_gemm.cpp | 6 +- include/ck_tile/core/arch/arch.hpp | 4 - .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 70 +++--- .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 20 +- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 217 ++++-------------- 6 
files changed, 99 insertions(+), 224 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 4c630375f4..49a8f54b7a 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -63,8 +63,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& kOutputRank, 1, 0, - TilePartitioner::kM, - TilePartitioner::kN>>, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock>>, ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem>>; diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index b9c9eaa583..5851ac5c05 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -63,8 +63,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre kOutputRank, 1, 0, - TilePartitioner::kM, - TilePartitioner::kN>>, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock>>, ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem>>; diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 6556f4a8fb..5927ddb41a 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -132,10 +132,6 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0) __attribute__((address_space( \ static_cast>(address_space_enum::constant)))) -#define CK_GENERIC_ADDRESS_SPACE \ - __attribute__((address_space( \ - static_cast>(address_space_enum::generic)))) - template __device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p) { diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index c81a64f7ad..b03d2d944d 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -174,7 +174,7 @@ struct GemmKernel if constexpr(std::is_same_v) { - if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false) + if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false) { return false; } @@ -185,7 +185,7 @@ struct GemmKernel } else { - if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false) + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) { return false; } @@ -197,7 +197,7 @@ struct GemmKernel if constexpr(std::is_same_v) { - if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false) + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) { return false; } @@ -208,7 +208,7 @@ struct GemmKernel } else { - if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false) + if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false) { return false; } @@ -220,7 +220,7 @@ struct GemmKernel if constexpr(std::is_same_v) { - if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false) + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) { return false; } @@ -231,7 +231,7 @@ struct GemmKernel } else { - if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false) + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) { return false; } @@ -323,17 +323,17 @@ struct GemmKernel const auto& a_tensor_view = views.at(I0); if constexpr(std::is_same_v) { - return pad_tensor_view( - a_tensor_view, - make_tuple(number{}, number{}), - sequence{}); + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } else { - return pad_tensor_view( - a_tensor_view, - make_tuple(number{}, number{}), - sequence{}); + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } }(); @@ -341,17 +341,17 @@ struct GemmKernel const auto& b_tensor_view = views.at(I1); if constexpr(std::is_same_v) { - return pad_tensor_view( - b_tensor_view, - make_tuple(number{}, number{}), - sequence{}); + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } else { - return pad_tensor_view( - b_tensor_view, - make_tuple(number{}, number{}), - sequence{}); + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } }(); @@ -359,17 +359,17 @@ struct GemmKernel const auto& c_tensor_view = views.at(I2); if constexpr(std::is_same_v) { - return pad_tensor_view( - c_tensor_view, - make_tuple(number{}, number{}), - sequence{}); + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } else { - return pad_tensor_view( - c_tensor_view, - make_tuple(number{}, number{}), - sequence{}); + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } }(); @@ -383,19 +383,19 @@ struct GemmKernel const auto& a_pad_view = views.at(I0); const auto& a_block_window = make_tile_window( a_pad_view, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {i_m, 0}); const auto& b_pad_view = views.at(I1); const auto& b_block_window = make_tile_window( b_pad_view, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {i_n, 0}); const auto& c_pad_view = views.at(I2); auto c_block_window = make_tile_window( c_pad_view, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {i_m, i_n}); return make_tuple(a_block_window, b_block_window, c_block_window); @@ -426,7 +426,7 @@ struct GemmKernel // Create 
Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); - ; + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 9606ccc181..ff2a469702 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -12,27 +12,27 @@ struct GemmTilePartitioner { using BlockGemmShape = remove_cvref_t; - static constexpr index_t kM = BlockGemmShape::kM; - static constexpr index_t kN = BlockGemmShape::kN; - static constexpr index_t kK = BlockGemmShape::kK; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size) { - index_t GridDimX = (M + kM - 1) / kM; - index_t GridDimY = (N + kN - 1) / kN; + index_t GridDimX = (M + MPerBlock - 1) / MPerBlock; + index_t GridDimY = (N + NPerBlock - 1) / NPerBlock; index_t GridDimZ = batch_size; return dim3(GridDimX, GridDimY, GridDimZ); } CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) -> index_t { - return integer_divide_ceil(K, kK); + return integer_divide_ceil(K, KPerBlock); } CK_TILE_DEVICE auto operator()() { - const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kM); - const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kN); + const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * MPerBlock); + const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * NPerBlock); return make_tuple(iM, iN); } }; @@ -71,14 +71,14 @@ struct GemmTile1DPartitioner const index_t NBlock = GetNBlock(NBlockSize); const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlock); - const index_t iN = __builtin_amdgcn_readfirstlane(fast_mod(blockIdx, NBlock)); + const index_t iN = __builtin_amdgcn_readfirstlane(modulo(blockIdx, NBlock)); return make_tuple(iM, iN); } private: template - CK_TILE_DEVICE auto fast_mod(const TType input, const TType ceil) noexcept(noexcept(ceil != 0)) + CK_TILE_DEVICE auto modulo(const TType input, const TType ceil) noexcept(noexcept(ceil != 0)) -> std::enable_if_t::is_integer, TType> { return input >= ceil ? 
input - (input / ceil) * ceil : input; diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index bf821415f0..36f4bf03b0 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -13,44 +13,55 @@ #include "ck_tile/ops/common.hpp" #include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" + namespace ck_tile { -struct GroupedGemmHostArgs +struct GroupedGemmHostArgs : public ck_tile::GemmHostArgs { - const void* a_ptr; - const void* b_ptr; - void* c_ptr; - index_t M; - index_t N; - index_t K; - index_t stride_A; - index_t stride_B; - index_t stride_C; + CK_TILE_HOST GroupedGemmHostArgs() noexcept = default; + CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* c_ptr_, + ck_tile::index_t M_, + ck_tile::index_t N_, + ck_tile::index_t K_, + ck_tile::index_t stride_A_, + ck_tile::index_t stride_B_, + ck_tile::index_t stride_C_) + : GemmHostArgs(a_ptr_, b_ptr_, c_ptr_, 1, M_, N_, K_, stride_A_, stride_B_, stride_C_) + { + } }; template -struct GroupedGemmKernel +struct GroupedGemmKernel : public GemmKernel { - using TilePartitioner = remove_cvref_t; - using GemmPipeline = remove_cvref_t; - using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; + using Base = GemmKernel; + using GemmKernelArgs = typename Base::GemmKernelArgs; + + static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + static constexpr index_t KBatch = 1; + struct GemmTransKernelArg { - GroupedGemmHostArgs group_karg; + GemmKernelArgs group_karg; ck_tile::index_t block_start; ck_tile::index_t block_end; GemmTransKernelArg() = default; - GemmTransKernelArg(GroupedGemmHostArgs&& karg, index_t bl_start, index_t bl_end) + GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end) : group_karg{karg}, block_start{bl_start}, block_end{bl_end} { } @@ -106,15 +117,16 @@ struct GroupedGemmKernel grid_size += grid_size_grp; - auto karg = GroupedGemmHostArgs{type_convert(gemm_descs[i].a_ptr), - type_convert(gemm_descs[i].b_ptr), - type_convert(gemm_descs[i].c_ptr), - M, - N, - K, - stride_a, - stride_b, - stride_c}; + auto karg = GemmKernelArgs{type_convert(gemm_descs[i].a_ptr), + type_convert(gemm_descs[i].b_ptr), + type_convert(gemm_descs[i].c_ptr), + M, + N, + K, + stride_a, + stride_b, + stride_c, + /*KBatch*/ KBatch}; gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); } @@ -132,151 +144,18 @@ struct GroupedGemmKernel const auto [i_m, i_n] = InvokeOffsetCallculationFor1DPartitioner{}( kargs.block_start, kargs.group_karg.N); - // options - const ADataType* a_start = static_cast(kargs.group_karg.a_ptr); - const BDataType* b_start = static_cast(kargs.group_karg.b_ptr); - - // Convert pointers to tensor views - auto a_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - a_start, - make_tuple(kargs.group_karg.M, kargs.group_karg.K), - make_tuple(kargs.group_karg.stride_A, 1), - 
number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - a_start, - make_tuple(kargs.group_karg.M, kargs.group_karg.K), - make_tuple(1, kargs.group_karg.stride_A), - number<1>{}, - number<1>{}); - } - }(); - - auto b_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - b_start, - make_tuple(kargs.group_karg.N, kargs.group_karg.K), - make_tuple(1, kargs.group_karg.stride_B), - number<1>{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - b_start, - make_tuple(kargs.group_karg.N, kargs.group_karg.K), - make_tuple(kargs.group_karg.stride_B, 1), - number{}, - number<1>{}); - } - }(); - - auto a_pad_view = [&]() { - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - // clang-format on - - auto a_block_window = make_tile_window( - a_pad_view, - make_tuple(number{}, number{}), - {i_m, 0}); - - auto b_pad_view = [&]() { - if constexpr(std::is_same_v) - { - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); + const typename Base::SplitKBatchOffset splitk_batch_offset(kargs.group_karg, blockIdx.z); - auto b_block_window = make_tile_window( - b_pad_view, - make_tuple(number{}, number{}), - {i_n, 0}); + // options + const ADataType* a_ptr = static_cast(kargs.group_karg.a_ptr); + const BDataType* b_ptr = static_cast(kargs.group_karg.b_ptr); + CDataType* c_ptr = static_cast(kargs.group_karg.c_ptr); // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; - const index_t num_loop = TilePartitioner::GetLoopNum(kargs.group_karg.K); - - // Run GEMM cooperatively by whole wokrgroup. 
- auto c_block_tile = - GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr); - - CDataType* c_start = static_cast(kargs.group_karg.c_ptr); - auto c_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - c_start, - make_tuple(kargs.group_karg.M, kargs.group_karg.N), - make_tuple(kargs.group_karg.stride_C, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - c_start, - make_tuple(kargs.group_karg.M, kargs.group_karg.N), - make_tuple(1, kargs.group_karg.stride_C), - number<1>{}, - number<1>{}); - } - }(); - - auto c_pad_view = [&]() { - if constexpr(std::is_same_v) - { - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - auto CBlockWindow_pad = make_tile_window( - c_pad_view, - make_tuple(number{}, number{}), - {i_m, i_n}); - - EpiloguePipeline{}(CBlockWindow_pad, c_block_tile); + this->RunGemm( + a_ptr, b_ptr, c_ptr, smem_ptr, kargs.group_karg, splitk_batch_offset, i_m, i_n); } CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, From cbba68047952fc74240ad4e4a668171f87e29934 Mon Sep 17 00:00:00 2001 From: Mateusz Ozga Date: Mon, 6 Jan 2025 14:54:49 +0000 Subject: [PATCH 4/8] Tests fix type --- test/ck_tile/batched_gemm/test_batched_gemm_util.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index e7e9b3d679..08fe500d3c 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
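// A concrete illustration of the per-group dispatch that the refactored Run() above relies on
// (group sizes are made up for this example): suppose three groups need 12, 20 and 8 output
// tiles, so MakeKargs records the [block_start, block_end) ranges [0, 12), [12, 32) and
// [32, 40) and the grid is launched with 40 workgroups. For blockIdx.x == 30 the bisection in
// operator() starts at group_id = (0 + 3) >> 1 == 1, immediately satisfies 12 <= 30 < 32, and
// calls Run(gemm_desc_ptr[1]); the 1D partitioner then works with the group-local offset
// 30 - block_start == 18.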
#pragma once #include @@ -73,8 +73,8 @@ class TestCkTileBatchedGemm : public ::testing::Test kOutputRank, 1, 0, - TilePartitioner::kM, - TilePartitioner::kN>>, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock>>, ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem>>; From b781628184e6fa55b2f7fe2f6a23da78e168279d Mon Sep 17 00:00:00 2001 From: Mateusz Ozga Date: Tue, 7 Jan 2025 07:46:32 +0000 Subject: [PATCH 5/8] Removed namespaces --- .../run_grouped_gemm_example.inc | 86 +++++++++---------- include/ck_tile/host/host_tensor.hpp | 10 +-- 2 files changed, 48 insertions(+), 48 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index ac2ced03b9..44e43d082f 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -10,13 +10,13 @@ float invoke_gemm(int n_warmup, const std::vector& args) { - ck_tile::DeviceMem gemm_workspace; + DeviceMem gemm_workspace; gemm_workspace.Realloc(get_workspace_size(args)); - float ave_time = grouped_gemm( - args, - ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, - gemm_workspace.GetDeviceBuffer()); + float ave_time = + grouped_gemm(args, + stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); std::string op_name{"Grouped Gemm"}; @@ -61,12 +61,12 @@ int run_grouped_gemm_example_with_layouts(int argc, const int repeat = arg_parser.get_int("repeat"); const int warmup = arg_parser.get_int("warmup"); - std::vector Ms = arg_parser.get_int_vec("Ms"); - std::vector Ns = arg_parser.get_int_vec("Ns"); - std::vector Ks = arg_parser.get_int_vec("Ks"); - std::vector stride_As = arg_parser.get_int_vec("stride_As"); - std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); - std::vector stride_Cs = arg_parser.get_int_vec("stride_Cs"); + std::vector Ms = arg_parser.get_int_vec("Ms"); + std::vector Ns = arg_parser.get_int_vec("Ns"); + std::vector Ks = arg_parser.get_int_vec("Ks"); + std::vector stride_As = arg_parser.get_int_vec("stride_As"); + std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); + std::vector stride_Cs = arg_parser.get_int_vec("stride_Cs"); if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs)) { @@ -83,17 +83,17 @@ int run_grouped_gemm_example_with_layouts(int argc, } } - std::vector> a_m_k_tensors; - std::vector> b_k_n_tensors; - std::vector> c_m_n_tensors; + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> c_m_n_tensors; a_m_k_tensors.reserve(group_count); b_k_n_tensors.reserve(group_count); c_m_n_tensors.reserve(group_count); - std::vector> a_m_k_dev_buf; - std::vector> b_k_n_dev_buf; - std::vector> c_m_n_dev_buf; + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> c_m_n_dev_buf; a_m_k_dev_buf.reserve(group_count); b_k_n_dev_buf.reserve(group_count); @@ -104,34 +104,34 @@ int run_grouped_gemm_example_with_layouts(int argc, for(int i = 0; i < group_count; ++i) { - const ck_tile::index_t M = Ms[i]; - const ck_tile::index_t N = Ns[i]; - const ck_tile::index_t K = Ks[i]; + const index_t M = Ms[i]; + const index_t N = Ns[i]; + const index_t K = Ks[i]; - stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], a_layout); - stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], b_layout); - stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], CLayout{}); + stride_As[i] = get_default_stride(M, N, stride_As[i], 
a_layout); + stride_Bs[i] = get_default_stride(K, N, stride_Bs[i], b_layout); + stride_Cs[i] = get_default_stride(M, N, stride_Cs[i], CLayout{}); - a_m_k_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(M, K, stride_As[i], a_layout))); - b_k_n_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], b_layout))); - c_m_n_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], CLayout{}))); + a_m_k_tensors.push_back( + HostTensor(host_tensor_descriptor(M, K, stride_As[i], a_layout))); + b_k_n_tensors.push_back( + HostTensor(host_tensor_descriptor(K, N, stride_Bs[i], b_layout))); + c_m_n_tensors.push_back( + HostTensor(host_tensor_descriptor(M, N, stride_Cs[i], CLayout{}))); std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl; - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k_tensors[i]); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n_tensors[i]); + FillUniformDistribution{-5.f, 5.f}(a_m_k_tensors[i]); + FillUniformDistribution{-5.f, 5.f}(b_k_n_tensors[i]); - a_m_k_dev_buf.push_back(std::make_unique( - a_m_k_tensors[i].get_element_space_size_in_bytes())); - b_k_n_dev_buf.push_back(std::make_unique( - b_k_n_tensors[i].get_element_space_size_in_bytes())); - c_m_n_dev_buf.push_back(std::make_unique( - c_m_n_tensors[i].get_element_space_size_in_bytes())); + a_m_k_dev_buf.push_back( + std::make_unique(a_m_k_tensors[i].get_element_space_size_in_bytes())); + b_k_n_dev_buf.push_back( + std::make_unique(b_k_n_tensors[i].get_element_space_size_in_bytes())); + c_m_n_dev_buf.push_back( + std::make_unique(c_m_n_tensors[i].get_element_space_size_in_bytes())); a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); @@ -157,12 +157,12 @@ int run_grouped_gemm_example_with_layouts(int argc, { for(int i = 0; i < group_count; ++i) { - ck_tile::HostTensor c_m_n_host_ref( - ck_tile::host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{})); + HostTensor c_m_n_host_ref( + host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{})); c_m_n_host_ref.SetZero(); - ck_tile::reference_gemm( + reference_gemm( a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); - pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref); + pass &= check_err(c_m_n_tensors[i], c_m_n_host_ref); } std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; } @@ -181,8 +181,8 @@ int run_grouped_gemm_example(int argc, char* argv[]) const std::string a_layout = arg_parser.get_str("a_layout"); const std::string b_layout = arg_parser.get_str("b_layout"); - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using Row = tensor_layout::gemm::RowMajor; + using Col = tensor_layout::gemm::ColumnMajor; if(a_layout == "R" && b_layout == "C") { diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index 10313e4207..2babb2afe9 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -684,13 +684,13 @@ auto host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride { using namespace ck_tile::literals; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { - return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}); } else { - return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}); } } template @@ -698,7 +698,7 @@ auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TL { if(stride == 0) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { return col; } From 774f9033eb3834f35872bf88060d62b7c1684371 Mon Sep 17 00:00:00 2001 From: Mateusz Ozga Date: Tue, 7 Jan 2025 07:54:53 +0000 Subject: [PATCH 6/8] Add template param to avoid implicit cast --- include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index ff2a469702..f7a34dafcb 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -64,9 +64,10 @@ struct GemmTile1DPartitioner return integer_divide_ceil(K, KPerBlock); } - CK_TILE_DEVICE auto + template + CK_TILE_DEVICE auto constexpr operator()(index_t blockIdx, index_t NBlockSize) noexcept(noexcept(GetNBlock(NBlockSize) != 0)) - -> tuple + -> const tuple { const index_t NBlock = GetNBlock(NBlockSize); From 0c8a5793eec42d936020129554836c44fb2888c5 Mon Sep 17 00:00:00 2001 From: Mateusz Ozga Date: Tue, 7 Jan 2025 08:10:54 +0000 Subject: [PATCH 7/8] Remove generic function --- .../run_grouped_gemm_example.inc | 86 +++++++++---------- .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 17 ++-- 2 files changed, 50 insertions(+), 53 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 44e43d082f..ac2ced03b9 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -10,13 +10,13 @@ float invoke_gemm(int n_warmup, const std::vector& args) { - DeviceMem gemm_workspace; + ck_tile::DeviceMem gemm_workspace; gemm_workspace.Realloc(get_workspace_size(args)); - float ave_time = - grouped_gemm(args, - stream_config{nullptr, true, 1, n_warmup, n_repeat}, - gemm_workspace.GetDeviceBuffer()); + float ave_time = grouped_gemm( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); std::string op_name{"Grouped Gemm"}; @@ -61,12 +61,12 @@ int run_grouped_gemm_example_with_layouts(int argc, const int repeat = arg_parser.get_int("repeat"); const int warmup = arg_parser.get_int("warmup"); - std::vector Ms = arg_parser.get_int_vec("Ms"); - std::vector Ns = arg_parser.get_int_vec("Ns"); - std::vector Ks = arg_parser.get_int_vec("Ks"); - std::vector stride_As = arg_parser.get_int_vec("stride_As"); - std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); - std::vector stride_Cs = arg_parser.get_int_vec("stride_Cs"); + std::vector Ms = arg_parser.get_int_vec("Ms"); + std::vector Ns = arg_parser.get_int_vec("Ns"); + std::vector Ks = arg_parser.get_int_vec("Ks"); + std::vector stride_As = arg_parser.get_int_vec("stride_As"); + std::vector stride_Bs = 
+    std::vector stride_Cs = arg_parser.get_int_vec("stride_Cs");
     if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
     {
@@ -83,17 +83,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
         }
     }
-    std::vector> a_m_k_tensors;
-    std::vector> b_k_n_tensors;
-    std::vector> c_m_n_tensors;
+    std::vector> a_m_k_tensors;
+    std::vector> b_k_n_tensors;
+    std::vector> c_m_n_tensors;
     a_m_k_tensors.reserve(group_count);
     b_k_n_tensors.reserve(group_count);
     c_m_n_tensors.reserve(group_count);
-    std::vector> a_m_k_dev_buf;
-    std::vector> b_k_n_dev_buf;
-    std::vector> c_m_n_dev_buf;
+    std::vector> a_m_k_dev_buf;
+    std::vector> b_k_n_dev_buf;
+    std::vector> c_m_n_dev_buf;
     a_m_k_dev_buf.reserve(group_count);
     b_k_n_dev_buf.reserve(group_count);
     c_m_n_dev_buf.reserve(group_count);
@@ -104,34 +104,34 @@ int run_grouped_gemm_example_with_layouts(int argc,
     for(int i = 0; i < group_count; ++i)
     {
-        const index_t M = Ms[i];
-        const index_t N = Ns[i];
-        const index_t K = Ks[i];
+        const ck_tile::index_t M = Ms[i];
+        const ck_tile::index_t N = Ns[i];
+        const ck_tile::index_t K = Ks[i];
-        stride_As[i] = get_default_stride(M, N, stride_As[i], a_layout);
-        stride_Bs[i] = get_default_stride(K, N, stride_Bs[i], b_layout);
-        stride_Cs[i] = get_default_stride(M, N, stride_Cs[i], CLayout{});
+        stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], a_layout);
+        stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], b_layout);
+        stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], CLayout{});
-        a_m_k_tensors.push_back(
-            HostTensor(host_tensor_descriptor(M, K, stride_As[i], a_layout)));
-        b_k_n_tensors.push_back(
-            HostTensor(host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
-        c_m_n_tensors.push_back(
-            HostTensor(host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
+        a_m_k_tensors.push_back(ck_tile::HostTensor(
+            ck_tile::host_tensor_descriptor(M, K, stride_As[i], a_layout)));
+        b_k_n_tensors.push_back(ck_tile::HostTensor(
+            ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
+        c_m_n_tensors.push_back(ck_tile::HostTensor(
+            ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
         std::cout << "gemm[" << i << "]"
                   << " a_m_k: " << a_m_k_tensors[i].mDesc
                   << " b_k_n: " << b_k_n_tensors[i].mDesc
                   << " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
-        FillUniformDistribution{-5.f, 5.f}(a_m_k_tensors[i]);
-        FillUniformDistribution{-5.f, 5.f}(b_k_n_tensors[i]);
+        ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k_tensors[i]);
+        ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n_tensors[i]);
-        a_m_k_dev_buf.push_back(
-            std::make_unique(a_m_k_tensors[i].get_element_space_size_in_bytes()));
-        b_k_n_dev_buf.push_back(
-            std::make_unique(b_k_n_tensors[i].get_element_space_size_in_bytes()));
-        c_m_n_dev_buf.push_back(
-            std::make_unique(c_m_n_tensors[i].get_element_space_size_in_bytes()));
+        a_m_k_dev_buf.push_back(std::make_unique(
+            a_m_k_tensors[i].get_element_space_size_in_bytes()));
+        b_k_n_dev_buf.push_back(std::make_unique(
+            b_k_n_tensors[i].get_element_space_size_in_bytes()));
+        c_m_n_dev_buf.push_back(std::make_unique(
+            c_m_n_tensors[i].get_element_space_size_in_bytes()));
         a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
         b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
@@ -157,12 +157,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
     {
         for(int i = 0; i < group_count; ++i)
         {
-            HostTensor c_m_n_host_ref(
-                host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
+            ck_tile::HostTensor c_m_n_host_ref(
+                ck_tile::host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
             c_m_n_host_ref.SetZero();
-            reference_gemm(
+            ck_tile::reference_gemm(
                 a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
-            pass &= check_err(c_m_n_tensors[i], c_m_n_host_ref);
+            pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref);
         }
         std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
     }
@@ -181,8 +181,8 @@ int run_grouped_gemm_example(int argc, char* argv[])
     const std::string a_layout = arg_parser.get_str("a_layout");
     const std::string b_layout = arg_parser.get_str("b_layout");
-    using Row = tensor_layout::gemm::RowMajor;
-    using Col = tensor_layout::gemm::ColumnMajor;
+    using Row = ck_tile::tensor_layout::gemm::RowMajor;
+    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
     if(a_layout == "R" && b_layout == "C")
     {
diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
index f7a34dafcb..d94d13eba5 100644
--- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
+++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
@@ -64,10 +64,9 @@ struct GemmTile1DPartitioner
         return integer_divide_ceil(K, KPerBlock);
     }
-    template
     CK_TILE_DEVICE auto constexpr
     operator()(index_t blockIdx, index_t NBlockSize) noexcept(noexcept(GetNBlock(NBlockSize) != 0))
-        -> const tuple
+        -> const tuple
     {
         const index_t NBlock = GetNBlock(NBlockSize);
@@ -78,9 +77,8 @@ struct GemmTile1DPartitioner
     }
     private:
-    template
-    CK_TILE_DEVICE auto modulo(const TType input, const TType ceil) noexcept(noexcept(ceil != 0))
-        -> std::enable_if_t::is_integer, TType>
+    CK_TILE_DEVICE auto constexpr modulo(index_t input, index_t ceil) noexcept(noexcept(ceil != 0)) -> index_t
     {
         return input >= ceil ? input - (input / ceil) * ceil : input;
     }
@@ -89,16 +87,15 @@ struct GemmTile1DPartitioner
 template
 struct InvokeOffsetCallculationFor1DPartitioner
 {
-    template
-    CK_TILE_DEVICE constexpr auto operator()(TType block_start, TType NBlockSize)
-        -> const std::enable_if_t::is_integer, tuple>
+    CK_TILE_DEVICE constexpr auto operator()(index_t block_start, index_t NBlockSize)
+        -> const tuple
     {
         const auto [iM, iN] = PartitionerFn{}(blockIdx.x - block_start, NBlockSize);
         const auto iM_to_img_corr = iM * PartitionerFn::MPerBlock;
         const auto iN_to_img_corr = iN * PartitionerFn::NPerBlock;
-        const TType i_m = __builtin_amdgcn_readfirstlane(iM_to_img_corr);
-        const TType i_n = __builtin_amdgcn_readfirstlane(iN_to_img_corr);
+        const index_t i_m = __builtin_amdgcn_readfirstlane(iM_to_img_corr);
+        const index_t i_n = __builtin_amdgcn_readfirstlane(iN_to_img_corr);
         return make_tuple(i_m, i_n);
     }

From 2f80a6a061c0f4fbca66c25f598c92392ac04182 Mon Sep 17 00:00:00 2001
From: Mateusz Ozga
Date: Tue, 7 Jan 2025 08:17:10 +0000
Subject: [PATCH 8/8] Constant value

---
 .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
index 36f4bf03b0..24f6beab58 100644
--- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
+++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
@@ -3,17 +3,11 @@
 #pragma once
-#include
-#include
-
 #include "ck_tile/core/numeric/math.hpp"
 #include "ck_tile/core/utility/literals.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
-#include "ck_tile/core.hpp"
-#include "ck_tile/ops/common.hpp"
-#include "ck_tile/host.hpp"
-
 #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
+#include "ck_tile/host.hpp"
 namespace ck_tile {
@@ -29,9 +23,12 @@ struct GroupedGemmHostArgs : public ck_tile::GemmHostArgs
                        ck_tile::index_t stride_A_,
                        ck_tile::index_t stride_B_,
                        ck_tile::index_t stride_C_)
-        : GemmHostArgs(a_ptr_, b_ptr_, c_ptr_, 1, M_, N_, K_, stride_A_, stride_B_, stride_C_)
+        : GemmHostArgs(a_ptr_, b_ptr_, c_ptr_, KBatch, M_, N_, K_, stride_A_, stride_B_, stride_C_)
     {
     }
+
+    private:
+    static constexpr index_t KBatch = 1;
 };
 template
@@ -126,7 +123,7 @@ struct GroupedGemmKernel : public GemmKernel