Skip to content

Commit 2357480

Browse files
[BugFix] Fix UB in per_token_group_quant.cu (#24913)
Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com>
1 parent f11e3c5 commit 2357480

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

csrc/quantization/fp8/per_token_group_quant.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
#include "../vectorization_utils.cuh"
1313
#include "../../dispatch_utils.h"
1414

15-
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
16-
unsigned mask = 0xffff;
15+
__device__ __forceinline__ float GroupReduceMax(float val) {
16+
unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
1717

1818
val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
1919
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
@@ -86,7 +86,7 @@ __global__ void per_token_group_quant_8bit_kernel(
8686
threads_per_group, // stride in group
8787
scalar_op_cache); // scalar handler
8888

89-
local_absmax = GroupReduceMax(local_absmax, lane_id);
89+
local_absmax = GroupReduceMax(local_absmax);
9090

9191
float y_s = local_absmax / max_8bit;
9292
if constexpr (SCALE_UE8M0) {

0 commit comments

Comments
 (0)