1010 #define __HIP__MI300_MI250__
1111#endif
1212
13+ #if defined(__HIPCC__) && \
14+ (defined (__gfx940__) || defined (__gfx941__) || defined (__gfx942__))
15+ #define __HIP__MI300__
16+ #endif
17+
1318#if defined(NDEBUG)
1419 #undef NDEBUG
1520 #include < assert.h>
@@ -357,7 +362,7 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K,
357362 return rtn;
358363}*/
359364
360- #if defined(__HIP__MI300_MI250__ ) // TODO: Add NAVI support
365+ #if defined(__HIP__MI300__ ) // TODO: Add NAVI support
361366template <int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, int UNRL, int M>
362367__global__ void __launch_bounds__ (WvPrGrp* THRDS)
363368 wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B,
@@ -534,7 +539,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
534539 n += CuCount * _WvPrGrp * YTILE;
535540 }
536541}
537- #else // !defined(__HIP__MI300_MI250__ ) TODO: Add NAVI support
542+ #else // !defined(__HIP__MI300__ ) TODO: Add NAVI support
538543template <int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, int UNRL, int M>
539544__global__ void wvSpltKQ_hf_sml_ (const int K, const int Kp, const int N,
540545 const DTYPE* B, const DTYPE* __restrict__ A,
@@ -544,9 +549,9 @@ __global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N,
544549 const int CuCount) {
545550 UNREACHABLE_CODE
546551}
547- #endif // defined(__HIP__MI300_MI250__ ) TODO: Add NAVI support
552+ #endif // defined(__HIP__MI300__ ) TODO: Add NAVI support
548553
549- #if defined(__HIP__MI300_MI250__ ) // TODO: Add NAVI support
554+ #if defined(__HIP__MI300__ ) // TODO: Add NAVI support
550555template <int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, int UNRL, int M>
551556__global__ void __launch_bounds__ (WvPrGrp* THRDS)
552557 wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B,
@@ -722,7 +727,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
722727 n += CuCount * _WvPrGrp * YTILE;
723728 }
724729}
725- #else // !defined(__HIP__MI300_MI250__ ) TODO: Add NAVI support
730+ #else // !defined(__HIP__MI300__ ) TODO: Add NAVI support
726731template <int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, int UNRL, int M>
727732__global__ void wvSpltKQ_hf_ (const int K, const int Kp, const int N,
728733 const DTYPE* B, const DTYPE* __restrict__ A,
@@ -731,7 +736,7 @@ __global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N,
731736 const int Otp, const int CuCount) {
732737 UNREACHABLE_CODE
733738}
734- #endif // defined(__HIP__MI300_MI250__ ) TODO: Add NAVI support
739+ #endif // defined(__HIP__MI300__ ) TODO: Add NAVI support
735740
736741#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
737742// This version targets cases where A[] fits LDS capacity
0 commit comments