@@ -58,18 +58,35 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
5858 }
5959}
6060
61- std::pair<int , int > Gemm::ComputeWarpPartition (int num_warps, Target target,
62- bool maybe_hopper_wgmma) const {
61+ Gemm::GemmInst Gemm::GetGemmInst (int block_size, Target target) const {
62+ int warp_size = TargetGetWarpSize (target);
63+ int num_warps = block_size / warp_size;
64+ bool allow_wgmma = TargetIsHopper (target) && (this ->M >= 64 ) &&
65+ (num_warps % 4 == 0 ) && CheckWGMMA ();
66+ if (allow_wgmma) {
67+ return GemmInst::kWGMMA ;
68+ } else if (TargetIsCDNA (target)) {
69+ return GemmInst::kMFMA ;
70+ } else if (TargetIsCuda (target)) {
71+ return GemmInst::kMMA ;
72+ } else {
73+ ICHECK (0 ) << " Unsupported target for gemm: " << target->str ();
74+ }
75+ }
76+
77+ std::pair<int , int > Gemm::ComputeWarpPartition (int block_size,
78+ GemmInst gemm_inst,
79+ Target target) const {
80+ int num_warps = block_size / TargetGetWarpSize (target);
6381 int m_warp = 1 , n_warp = 1 ;
6482 constexpr int kMPerWarp = 16 ; // Rows processed by a single warp
6583 constexpr int kNPerWarp = 8 ; // Columns processed by a single warp
66- bool allow_wgmma = TargetIsHopper (target) && maybe_hopper_wgmma &&
67- (this ->M >= 64 ) && (num_warps % 4 == 0 );
84+
6885 ICHECK (this ->M % kMPerWarp == 0 )
6986 << " M must be divisible by " << kMPerWarp << " , but got " << this ->M ;
7087 ICHECK (this ->N % kNPerWarp == 0 )
7188 << " N must be divisible by " << kNPerWarp << " , but got " << this ->N ;
72- if (allow_wgmma ) {
89+ if (gemm_inst == GemmInst:: kWGMMA ) {
7390 ICHECK (num_warps % 4 == 0 ) << " Warp-Group MMA requires 128×k threads." ;
7491
7592 constexpr int kGroup = 4 ; // Number of warps in a warp-group
@@ -268,16 +285,9 @@ bool Gemm::CheckWGMMA() const {
268285}
269286
270287Stmt Gemm::Lower (const LowerArgs &T, arith::Analyzer *analyzer) const {
271- int warp_size = 32 ;
272- if (TargetIsCDNA (T.target )) {
273- warp_size = 64 ;
274- }
275288 auto block_size = *as_const_int (T.thread_bounds ->extent );
276- bool maybe_wgmma = TargetIsHopper (T.target ) && (this ->M >= 64 ) &&
277- (block_size / warp_size % 4 == 0 ) && CheckWGMMA ();
278-
279- auto [warp_m, warp_n] =
280- ComputeWarpPartition (block_size / warp_size, T.target , maybe_wgmma);
289+ GemmInst gemm_inst = GetGemmInst (block_size, T.target );
290+ auto [warp_m, warp_n] = ComputeWarpPartition (block_size, gemm_inst, T.target );
281291
282292 std::stringstream ss;
283293 std::string op_name = " tl::gemm_ss" ;
@@ -295,7 +305,7 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
295305 // for cdna gemm, we need to specify kPack
296306 ss << " , " << kPack ;
297307 } else if (TargetIsHopper (T.target )) {
298- ss << " , " << (maybe_wgmma ? " true" : " false" );
308+ ss << " , " << (gemm_inst == GemmInst:: kWGMMA ? " true" : " false" );
299309 }
300310 if (wg_wait != 0 ) {
301311 ss << " , " << wg_wait;
@@ -321,10 +331,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
321331 ICHECK (C.scope () == " local.fragment" );
322332 auto thread_range = T.thread_bounds ;
323333 auto block_size = *as_const_int (thread_range->extent );
334+ GemmInst gemm_inst = GetGemmInst (block_size, T.target );
335+ auto [warp_m, warp_n] = ComputeWarpPartition (block_size, gemm_inst, T.target );
336+
324337 if (TargetIsVolta (T.target )) {
325- const int warp_size = 32 ;
326- auto [warp_m, warp_n] =
327- ComputeWarpPartition (block_size / warp_size, T.target );
328338 auto fragment =
329339 makeGemmVoltaFragmentC (M, N, M / warp_m, N / warp_n, C->dtype .bits ());
330340 results.Set (C, fragment->BindThreadRange (thread_range));
@@ -347,9 +357,6 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
347357 *as_const_int (B->shape [dim_B - 1 ]),
348358 false , trans_B ? 2 : 1 ));
349359 } else if (TargetIsAmpere (T.target ) || TargetIsTuring (T.target )) {
350- const int warp_size = 32 ;
351- auto [warp_m, warp_n] =
352- ComputeWarpPartition (block_size / warp_size, T.target );
353360 auto fragment =
354361 makeGemmFragmentC (M, N, M / warp_m, N / warp_n, C->dtype .bits ());
355362 results.Set (C, fragment->BindThreadRange (thread_range));
@@ -383,13 +390,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
383390 ICHECK (0 );
384391 }
385392 } else if (TargetIsHopper (T.target )) {
386- const int warp_size = 32 ;
387- bool maybe_wgmma =
388- (this ->M >= 64 ) && (block_size / warp_size % 4 == 0 ) && CheckWGMMA ();
389- auto [warp_m, warp_n] =
390- ComputeWarpPartition (block_size / warp_size, T.target , maybe_wgmma);
391393 auto fragment =
392- maybe_wgmma
394+ gemm_inst == GemmInst:: kWGMMA
393395 ? makeGemmFragmentCHopper (M, N, M / warp_m, N / warp_n,
394396 C->dtype .bits ())
395397 : makeGemmFragmentC (M, N, M / warp_m, N / warp_n, C->dtype .bits ());
@@ -401,7 +403,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
401403 const int64_t continuity =
402404 trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
403405 auto ABLayout =
404- maybe_wgmma
406+ gemm_inst == GemmInst:: kWGMMA
405407 ? makeGemmABLayoutHopper (mat_stride, mat_continuous, continuity,
406408 A->dtype .bits (), trans_A ? 1 : 2 )
407409 : makeGemmABLayout (mat_stride, mat_continuous, mat_continuous,
@@ -419,7 +421,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
419421 const int64_t continuity =
420422 trans_B ? mat_continuous : mat_continuous / warp_n;
421423 auto ABLayout =
422- maybe_wgmma
424+ gemm_inst == GemmInst:: kWGMMA
423425 ? makeGemmABLayoutHopper (mat_stride, mat_continuous, continuity,
424426 B->dtype .bits (), trans_B ? 2 : 1 )
425427 : makeGemmABLayout (mat_stride, mat_continuous, mat_continuous,
@@ -429,10 +431,6 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
429431 ICHECK (0 ) << " WGMMA only support B in shared." ;
430432 }
431433 } else if (TargetIsCDNA (T.target )) {
432- const int warp_size = 64 ;
433- auto [warp_m, warp_n] =
434- ComputeWarpPartition (block_size / warp_size, T.target );
435-
436434 auto fragment =
437435 makeGemmFragmentCCDNA (M, N, M / warp_m, N / warp_n, C->dtype .bits ());
438436 results.Set (C, fragment->BindThreadRange (thread_range));
0 commit comments