Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion csrc/quantization/activation_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@
typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
typedef __hip_bfloat16_raw __nv_bfloat16_raw;

#if defined(HIP_FP8_TYPE_OCP)
typedef __hip_fp8_e4m3 __nv_fp8_e4m3;
typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3;
#else
// ROCm 6.2 fallback: only *_fnuz types exist
typedef __hip_fp8_e4m3_fnuz __nv_fp8_e4m3;
typedef __hip_fp8x4_e4m3_fnuz __nv_fp8x4_e4m3;
#endif
#endif

#include "core/registration.h"
Expand Down
6 changes: 6 additions & 0 deletions csrc/rocm/attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
#include "../attention/dtype_fp8.cuh"
#include "../quantization/fp8/amd/quant_utils.cuh"

// ROCm 6.2 compatibility: map OCP fp8 types to FNUZ variants if OCP is absent
#if !defined(HIP_FP8_TYPE_OCP)
using __hip_fp8_e4m3 = __hip_fp8_e4m3_fnuz;
using __hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz;
#endif

#if defined(__HIPCC__) && \
(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
#define __HIP__GFX9__
Expand Down