-
Notifications
You must be signed in to change notification settings - Fork 136
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CK-Tile Grouped GEMM refactor and post PR fixes #1756
base: develop
Are you sure you want to change the base?
Conversation
#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4))) | ||
|
||
template <typename T> | ||
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p) | ||
{ | ||
// cast a pointer in "Constant" address space (4) to "Generic" address space (0) | ||
// only c-style pointer cast seems be able to be compiled | ||
#pragma clang diagnostic push | ||
#pragma clang diagnostic ignored "-Wold-style-cast" | ||
return (T*)p; // NOLINT(old-style-cast) | ||
#pragma clang diagnostic pop | ||
} | ||
|
||
template <typename T> | ||
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p) | ||
{ | ||
// cast a pointer in "Generic" address space (0) to "Constant" address space (4) | ||
// only c-style pointer cast seems be able to be compiled | ||
#pragma clang diagnostic push | ||
#pragma clang diagnostic ignored "-Wold-style-cast" | ||
return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast) | ||
#pragma clang diagnostic pop | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems like here the address space id is not consistent with the one here:
https://github.com/ROCm/composable_kernel/pull/1756/files#diff-5fbd9be40988c2586d9a6d3568593c4a9c5a5fb3a718eded871578acd7f34b8cR15-R22
Please make it consistent and add constant
to the enum.
@@ -109,4 +109,28 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0) | |||
#endif | |||
} | |||
|
|||
#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to use enumerator address_space_enum
here?
{ | ||
using namespace ck_tile::literals; | ||
|
||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're already in ck_tile
namespace.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO, it's good practice to maintain namespaces even inside them. This indicates its exact place of origin.
index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * MPerBlock); | ||
index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.x * NPerBlock); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is incorrect.
using TilePartitioner = remove_cvref_t<TilePartitioner_>; | ||
using GemmPipeline = remove_cvref_t<GemmPipeline_>; | ||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>; | ||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>; | ||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>; | ||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>; | ||
using Block2ETileMap = OffsettedBlockToCTileMap<TilePartitioner>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use CK-Tile naming convention and not the old CK's one.
// Block2CTileMap configuration parameter. | ||
static constexpr index_t B2E_M01 = 8; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently this shouldn't be needed.
|
||
const index_t block_start = grid_size; | ||
const index_t block_end = grid_size + grid_size_grp; | ||
|
||
grid_size += grid_size_grp; | ||
|
||
auto grouped_block_2_ctile_map = Block2ETileMap(B2E_M01, M, N); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be the offseted TilePartitioner. Take a look here: https://github.com/ROCm/composable_kernel/blob/develop/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
@@ -62,12 +62,10 @@ struct GemmTile1DPartitioner | |||
return integer_divide_ceil(K, KPerBlock); | |||
} | |||
|
|||
CK_TILE_DEVICE auto operator()(index_t blockOffset, index_t NBlockSize) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right now the function in https://github.com/ROCm/composable_kernel/pull/1756/files#diff-1452d09e42d9aed38e087f552e940394aae58d72f2991fa842fa15caced37854R55 is not used anywhere.
Will be off from tomorrow till 2nd Jan. Please pass on review to ie. @bartekxk
@@ -26,6 +26,7 @@ | |||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" | |||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" | |||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" | |||
#include "ck_tile/ops/gemm/kernel/gemm_offset_block.hpp" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
run python3 remod.py
under include/ck_tile
, you may find we will sorting the headers before include into this gemm.hpp. Actually headers under include/ck_tile/ops/
better use this script to be autogenerated :) in case ppl argue about alphbetical order
This pull-request contains changes as following:
get_stride
function, etc., to a different location,