Skip to content

Commit 26fa19c

Browse files
committed
better naming
1 parent 253c188 commit 26fa19c

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -83,15 +83,15 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
8383
// This is the BF16 {-136, -136} represented as an integer.
8484
#if defined(USE_ROCM)
8585
#if ROCM_VERSION >= 60200
86-
auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
87-
auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
86+
auto BF16_SCALE_FACTOR = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
87+
auto BF16_UNIT_VALUE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
8888
#else
89-
auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308});
90-
auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80});
89+
auto BF16_SCALE_FACTOR = __bfloat162bfloat162(__hip_bfloat16{0xC308});
90+
auto BF16_UNIT_VALUE = __bfloat162bfloat162(__hip_bfloat16{0x3F80});
9191
#endif
9292
#else
93-
static constexpr uint32_t BF16_BIAS = 0xC308C308;
94-
static constexpr uint32_t BF16_ONE = 0x3F803F80;
93+
static constexpr uint32_t BF16_SCALE_FACTOR = 0xC308C308;
94+
static constexpr uint32_t BF16_UNIT_VALUE = 0x3F803F80;
9595
#endif
9696

9797
// Finally, we construct the output numbers.
@@ -100,11 +100,11 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
100100
// Since this section is for Ampere+, we use bf16 fma to do the bias
101101
// subtraction
102102
#if defined(USE_ROCM)
103-
result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS);
103+
result.vals[ii] = __hfma2(result.vals[ii], BF16_UNIT_VALUE, BF16_SCALE_FACTOR);
104104
#else
105105
asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
106106
: "=r"(h[ii])
107-
: "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
107+
: "r"(h[ii]), "r"(BF16_UNIT_VALUE), "r"(BF16_SCALE_FACTOR));
108108
#endif
109109
}
110110

@@ -369,3 +369,4 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
369369
}
370370

371371
#endif
372+
git checkout main -- file.txt

0 commit comments

Comments
 (0)