-
Notifications
You must be signed in to change notification settings - Fork 137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CK TILE] GEMM and Batched GEMM SplitK support #1724
base: develop
Are you sure you want to change the base?
Changes from all commits
13707b4
7f523ec
6f677a8
9ad07c7
b64e852
86d7bc1
4a7f78d
d1dc19d
bce0f24
d1d7909
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ | |
#endif | ||
|
||
template <typename ALayout, typename BLayout, typename CLayout> | ||
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) | ||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) | ||
{ | ||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) | ||
// Memory friendly for Interwave scheduler | ||
|
@@ -78,7 +78,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) | |
#endif | ||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>; | ||
|
||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); | ||
const ck_tile::index_t k_grain = args.k_batch * K_Tile; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is it true that if we set the split_k=1 from cmd arg, the kernel will run only K_Tile for each kernel's unroll? what about if we want to disable split-k from cmd args, is it through split_k=0? or not considered? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is analogus to just round up K dimension in the case of split_k=1 |
||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; | ||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); | ||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); | ||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); | ||
|
||
|
@@ -106,17 +108,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) | |
has_hot_loop_v, | ||
tail_number_v>>; | ||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; | ||
auto kargs = Kernel::MakeKargs(args.p_a, | ||
args.p_b, | ||
args.p_c, | ||
args.M, | ||
args.N, | ||
args.K, | ||
args.stride_A, | ||
args.stride_B, | ||
args.stride_C); | ||
|
||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); | ||
auto kargs = Kernel::MakeKernelArgs(args); | ||
|
||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); | ||
constexpr dim3 blocks = Kernel::BlockSize(); | ||
|
||
if(!Kernel::IsSupportedArgument(kargs)) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
// SPDX-License-Identifier: MIT | ||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. | ||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
#pragma once | ||
|
||
|
@@ -56,6 +56,13 @@ struct CShuffleEpilogue | |
// No additional shared memory needed | ||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } | ||
|
||
CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() | ||
{ | ||
// TODO: At now CShuffle doesn't allow to vector store after permute. | ||
// It should be fixed and this function should return true. | ||
return false; | ||
} | ||
|
||
template <typename OAccTile> | ||
CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile) | ||
{ | ||
|
@@ -111,7 +118,9 @@ struct CShuffleEpilogue | |
} | ||
} | ||
|
||
template <typename ODramWindowTmp, typename OAccTile> | ||
template <typename ODramWindowTmp, | ||
typename OAccTile, | ||
memory_operation_enum out_memory_data_op = memory_operation_enum::set> | ||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile) | ||
Comment on lines
+123
to
124
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think if you could add third function parameter There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is not possible to pass it because it is enum. Then you cannot compare object from argument in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @bartekxk You're right. What about having a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can do it at now there are store_tile and update_tile, so it is just concept |
||
{ | ||
const auto& current_window_origin = o_dram_window_tmp.get_window_origin(); | ||
|
@@ -158,12 +167,26 @@ struct CShuffleEpilogue | |
// Store the tile data to the permuted location | ||
if constexpr(kPadM || kPadN) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @bartekxk By the way do we really need here this check? The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you sure we need it here? Looks like it could improve performance like for example here: #1752 |
||
{ | ||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); | ||
if constexpr(out_memory_data_op == memory_operation_enum::set) | ||
{ | ||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); | ||
} | ||
else | ||
{ | ||
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); | ||
} | ||
buffer_store_fence(); | ||
} | ||
else | ||
{ | ||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); | ||
if constexpr(out_memory_data_op == memory_operation_enum::set) | ||
{ | ||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); | ||
} | ||
else | ||
{ | ||
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); | ||
} | ||
} | ||
} | ||
}; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we not supporting batch (
b
) in this example?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is simple gemm, not batched.