Skip to content

Commit b63a984

Browse files
authored
Restricting FP8 wvSplitk to MI300x (#439)
1 parent 955ba64 commit b63a984

File tree

2 files changed

+14 additions, -8 deletions (lines changed)

2 files changed

+14 additions, -8 deletions (lines changed)

csrc/rocm/custom_kernels.cu

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,11 @@
1010
#define __HIP__MI300_MI250__
1111
#endif
1212

13+
#if defined(__HIPCC__) && \
14+
(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
15+
#define __HIP__MI300__
16+
#endif
17+
1318
#if defined(NDEBUG)
1419
#undef NDEBUG
1520
#include <assert.h>
@@ -357,7 +362,7 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K,
357362
return rtn;
358363
}*/
359364

360-
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
365+
#if defined(__HIP__MI300__) // TODO: Add NAVI support
361366
template <int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, int UNRL, int M>
362367
__global__ void __launch_bounds__(WvPrGrp* THRDS)
363368
wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B,
@@ -534,7 +539,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
534539
n += CuCount * _WvPrGrp * YTILE;
535540
}
536541
}
537-
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
542+
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
538543
template <int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, int UNRL, int M>
539544
__global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N,
540545
const DTYPE* B, const DTYPE* __restrict__ A,
@@ -544,9 +549,9 @@ __global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N,
544549
const int CuCount) {
545550
UNREACHABLE_CODE
546551
}
547-
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
552+
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
548553

549-
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
554+
#if defined(__HIP__MI300__) // TODO: Add NAVI support
550555
template <int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, int UNRL, int M>
551556
__global__ void __launch_bounds__(WvPrGrp* THRDS)
552557
wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B,
@@ -722,7 +727,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
722727
n += CuCount * _WvPrGrp * YTILE;
723728
}
724729
}
725-
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
730+
#else // !defined(__HIP__MI300__) TODO: Add NAVI support
726731
template <int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, int UNRL, int M>
727732
__global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N,
728733
const DTYPE* B, const DTYPE* __restrict__ A,
@@ -731,7 +736,7 @@ __global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N,
731736
const int Otp, const int CuCount) {
732737
UNREACHABLE_CODE
733738
}
734-
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
739+
#endif // defined(__HIP__MI300__) TODO: Add NAVI support
735740

736741
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
737742
// This version targets cases where A[] fits LDS capacity

vllm/model_executor/layers/tuned_gemm.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from vllm import _custom_ops as ops
1111
from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM
1212
from vllm.platforms import current_platform
13-
from vllm.utils import is_navi
13+
from vllm.utils import is_mi250, is_navi
1414

1515
support_tuned_gemms = False
1616
if current_platform.is_rocm():
@@ -102,7 +102,8 @@ def scaled_mm(
102102
bias: Optional[torch.Tensor],
103103
) -> torch.Tensor:
104104
n = inp.shape[0]
105-
if n != 1:
105+
if (not VLLM_USE_ROCM_SKINNY_GEMM or n != 1
106+
or not current_platform.is_rocm() or is_mi250() or is_navi()):
106107
return torch._scaled_mm(inp,
107108
weight,
108109
out_dtype=out_dtype,

0 commit comments

Comments
 (0)