Skip to content

Commit 221bf72

Browse files
authored
output type conversion fix (#27159)
1 parent b3aba04 commit 221bf72

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

vllm/model_executor/layers/batch_invariant.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,7 @@ def matmul_kernel_persistent(
134134
bias_ptrs = bias_ptr + offs_cn
135135
bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)
136136
accumulator += bias
137-
if c_ptr.dtype.element_ty == tl.float8e4nv:
138-
c = accumulator.to(tl.float8e4nv)
139-
else:
140-
c = accumulator.to(tl.float16)
137+
c = accumulator.to(c_ptr.dtype.element_ty)
141138
tl.store(c_ptrs, c, mask=c_mask)
142139

143140

0 commit comments

Comments
 (0)