@@ -83,15 +83,15 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
83
83
// This is the BF16 {-136, -136} represented as an integer.
84
84
#if defined(USE_ROCM)
85
85
#if ROCM_VERSION >= 60200
86
- auto BF16_BIAS = __bfloat162bfloat162 (__hip_bfloat16 (__hip_bfloat16_raw{0xC308 }));
87
- auto BF16_ONE = __bfloat162bfloat162 (__hip_bfloat16 (__hip_bfloat16_raw{0x3F80 }));
86
+ auto BF16_SCALE_FACTOR = __bfloat162bfloat162 (__hip_bfloat16 (__hip_bfloat16_raw{0xC308 }));
87
+ auto BF16_UNIT_VALUE = __bfloat162bfloat162 (__hip_bfloat16 (__hip_bfloat16_raw{0x3F80 }));
88
88
#else
89
- auto BF16_BIAS = __bfloat162bfloat162 (__hip_bfloat16{0xC308 });
90
- auto BF16_ONE = __bfloat162bfloat162 (__hip_bfloat16{0x3F80 });
89
+ auto BF16_SCALE_FACTOR = __bfloat162bfloat162 (__hip_bfloat16{0xC308 });
90
+ auto BF16_UNIT_VALUE = __bfloat162bfloat162 (__hip_bfloat16{0x3F80 });
91
91
#endif
92
92
#else
93
- static constexpr uint32_t BF16_BIAS = 0xC308C308 ;
94
- static constexpr uint32_t BF16_ONE = 0x3F803F80 ;
93
+ static constexpr uint32_t BF16_SCALE_FACTOR = 0xC308C308 ;
94
+ static constexpr uint32_t BF16_UNIT_VALUE = 0x3F803F80 ;
95
95
#endif
96
96
97
97
// Finally, we construct the output numbers.
@@ -100,11 +100,11 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
100
100
// Since this section is for Ampere+, we use bf16 fma to do the bias
101
101
// subtraction
102
102
#if defined(USE_ROCM)
103
- result.vals [ii] = __hfma2 (result.vals [ii], BF16_ONE, BF16_BIAS );
103
+ result.vals [ii] = __hfma2 (result.vals [ii], BF16_UNIT_VALUE, BF16_SCALE_FACTOR );
104
104
#else
105
105
asm (" fma.rn.bf16x2 %0, %1, %2, %3;\n "
106
106
: " =r" (h[ii])
107
- : " r" (h[ii]), " r" (BF16_ONE ), " r" (BF16_BIAS ));
107
+ : " r" (h[ii]), " r" (BF16_UNIT_VALUE ), " r" (BF16_SCALE_FACTOR ));
108
108
#endif
109
109
}
110
110
@@ -369,3 +369,4 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
369
369
}
370
370
371
371
#endif
372
+ git checkout main -- file.txt
0 commit comments