diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 9aa1411b4a25..b94cc9ce5086 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -23,9 +23,14 @@ typedef __hip_bfloat162 __nv_bfloat162; typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16_raw __nv_bfloat16_raw; - + #if defined(HIP_FP8_TYPE_OCP) typedef __hip_fp8_e4m3 __nv_fp8_e4m3; typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3; + #else +// ROCm 6.2 fallback: only *_fnuz types exist +typedef __hip_fp8_e4m3_fnuz __nv_fp8_e4m3; +typedef __hip_fp8x4_e4m3_fnuz __nv_fp8x4_e4m3; + #endif #endif #include "core/registration.h" diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index dac9df6048f2..133a545045b1 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -25,6 +25,12 @@ #include "../attention/dtype_fp8.cuh" #include "../quantization/fp8/amd/quant_utils.cuh" +// ROCm 6.2 compatibility: map OCP fp8 types to FNUZ variants if OCP is absent +#if !defined(HIP_FP8_TYPE_OCP) +using __hip_fp8_e4m3 = __hip_fp8_e4m3_fnuz; +using __hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz; +#endif + #if defined(__HIPCC__) && \ (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) #define __HIP__GFX9__