[CK TILE] GEMM and Batched GEMM SplitK support #1724

Open
wants to merge 10 commits into
base: develop

bartekxk
Contributor

@bartekxk bartekxk commented Dec 6, 2024

No description provided.

Collaborator

@aosewski aosewski left a comment

Good work! However, I have a few things for you to reconsider.

include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp (outdated, resolved)
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp (outdated, resolved)
Comment on lines +123 to 124
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
Collaborator

Could you add a third function parameter, memory_operation_enum o_mem_data_op = out_memory_data_op? Then you wouldn't have to pass all the template parameters; you would pass only the memory op when you need one other than the default.

Contributor Author

I don't think it is possible to pass it as a function argument. The enum value would then be a runtime object, and you cannot compare an object taken from an argument inside if constexpr.
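
The constraint described above can be illustrated with a minimal sketch. The enum below is a hypothetical two-value stand-in for ck_tile's memory_operation_enum, and write_out / write_out_runtime are made-up names for illustration, not functions from the PR:

```cpp
#include <cassert>

// Hypothetical stand-in for the library enum; only two values are modeled here.
enum class memory_operation_enum { set, atomic_add };

// The operation is a non-type template parameter, so it is a constant
// expression and `if constexpr` can pick the branch at compile time.
template <memory_operation_enum op = memory_operation_enum::set>
void write_out(float& dst, float value)
{
    if constexpr(op == memory_operation_enum::set)
        dst = value;  // overwrite the destination
    else
        dst += value; // accumulate (plain add standing in for an atomic add)
}

// Here the operation is a function parameter, i.e. a runtime value.
// Only a regular `if` works; `if constexpr(op == ...)` would not compile.
inline void write_out_runtime(float& dst, float value, memory_operation_enum op)
{
    if(op == memory_operation_enum::set)
        dst = value;
    else
        dst += value;
}
```

The runtime variant still works, but both branches are compiled for every call, which matters when the two branches require different instructions or types.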

Collaborator

@aosewski aosewski Dec 11, 2024

@bartekxk You're right. What about having a store_tile API that takes the memory operation enum as its last parameter? For set it would do a store, while for atomic_add and the others it would do an update.

Contributor Author

We can do that. For now there are separate store_tile and update_tile functions, so this is just a concept.
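
One way the concept discussed in this thread could look is sketched below. It assumes the memory operation is passed as a compile-time tag (std::integral_constant) in the last parameter, so it stays usable in if constexpr. All names here (mem_op_constant, store_values, update_values, write_values) are hypothetical and operate on std::vector rather than ck_tile tiles; they are not the library's store_tile / update_tile signatures:

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for the library enum.
enum class memory_operation_enum { set, atomic_add };

// Compile-time tag that carries the enum value in its type.
template <memory_operation_enum op>
using mem_op_constant = std::integral_constant<memory_operation_enum, op>;

// Stand-in for a "store" entry point: overwrite the destination.
inline void store_values(std::vector<float>& dst, const std::vector<float>& src)
{
    dst = src;
}

// Stand-in for an "update" entry point: accumulate into the destination.
inline void update_values(std::vector<float>& dst, const std::vector<float>& src)
{
    for(std::size_t i = 0; i < dst.size() && i < src.size(); ++i)
        dst[i] += src[i];
}

// Single entry point: the last parameter selects store vs. update at compile time.
// When the tag is omitted, the default template argument (set) applies.
template <memory_operation_enum op = memory_operation_enum::set>
void write_values(std::vector<float>& dst,
                  const std::vector<float>& src,
                  mem_op_constant<op> = {})
{
    if constexpr(op == memory_operation_enum::set)
        store_values(dst, src);
    else
        update_values(dst, src);
}
```

A caller that wants accumulation, as a split-K epilogue would, passes mem_op_constant<memory_operation_enum::atomic_add>{} as the last argument; other callers omit it and get a plain store.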

include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp (outdated, resolved)
example/ck_tile/03_gemm/gemm_basic.hpp (outdated, resolved)
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp (outdated, resolved)
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp (outdated, resolved)
@@ -158,12 +167,26 @@ struct CShuffleEpilogue
// Store the tile data to the permuted location
if constexpr(kPadM || kPadN)
Collaborator

@aosewski aosewski Dec 11, 2024

@bartekxk By the way, do we really need this check here? The *_raw versions of the tile API just do things using assembly... I'm not sure we really need it here. The plain tile API should work as well, regardless of padding.

Contributor Author

Are you sure we need it here? It looks like it could improve performance, as for example in #1752.

@bartekxk bartekxk changed the title [CK TILE] GEMM SplitK support [CK TILE] GEMM and Batched GEMM SplitK support Dec 22, 2024
@@ -54,8 +54,7 @@ using CDataType = Types::CDataType;
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
-arg_parser.insert("b", "1", "batch size")
-.insert("m", "3840", "m dimension")
+arg_parser.insert("m", "3840", "m dimension")
Contributor

Do we not support batch (b) in this example?

Contributor Author

This is a simple GEMM, not a batched one.

@@ -78,7 +78,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
#endif
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;

const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
Contributor

Is it true that if we set split_k=1 from the command-line args, the kernel will run only K_Tile for each kernel's unroll? And what if we want to disable split-K from the command-line args: is that done through split_k=0, or is it not considered?

Contributor Author

I think that in the case of split_k=1 this is analogous to just rounding up the K dimension.
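
The arithmetic behind this answer can be sketched as follows. The diff above shows only k_grain = args.k_batch * K_Tile; the ceiling division that follows is an assumption about how the per-split K extent is derived, and k_per_split is a hypothetical helper name:

```cpp
#include <cassert>

// Hypothetical helper: K extent handled by one split, as a multiple of k_tile.
constexpr int k_per_split(int K, int k_batch, int k_tile)
{
    const int k_grain = k_batch * k_tile;        // from the diff: k_batch * K_Tile
    return (K + k_grain - 1) / k_grain * k_tile; // assumed: ceil(K / k_grain) * k_tile
}
```

With k_batch = 1, k_grain equals k_tile, so the result is simply K rounded up to the next multiple of the tile size, and a single split covers the whole K range. With k_batch > 1, each split covers an equal share, and the shares together cover at least K.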

@bartekxk bartekxk requested a review from carlushuang December 23, 2024 11:20
Contributor

@carlushuang carlushuang left a comment

LGTM
