diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh
index f01427cc3d0c..feda497d0210 100644
--- a/csrc/quantization/fp8/amd/quant_utils.cuh
+++ b/csrc/quantization/fp8/amd/quant_utils.cuh
@@ -19,12 +19,24 @@ __device__ __forceinline__ fp8_type cvt_c10(float const r) {
   return {};
 }
 
+// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
+// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
+// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
+// on ROCm instantiates both OCP and FNUZ kernels, we need to replace
+// the new HW cvt with something reasonable that doesn't rely on the
+// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
 template <>
 __device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
+  #if HIP_FP8_TYPE_OCP
   return c10::Float8_e4m3fn(
       __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
                              __hip_fp8_e4m3::__default_interpret),
       c10::Float8_e4m3fn::from_bits());
+  #else
+  // Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
+  // HW cvt above is faster when it is available (ROCm 6.3 or newer).
+  return static_cast<c10::Float8_e4m3fn>(r);
+  #endif
 }
 
 template <>