Commit a0a8296

BoyuanFeng and gemini-code-assist[bot] authored and committed
remove attn output view kernel (vllm-project#26680)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Boyuan Feng <fby.1994@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent f52cc28 commit a0a8296

File tree

10 files changed (+12, -12 lines)


vllm/attention/layer.py

Lines changed: 3 additions & 3 deletions

@@ -346,7 +346,7 @@ def forward(
 
         if self.use_output:
             output_shape = output_shape if output_shape is not None else query.shape
-            output = torch.zeros(output_shape, dtype=output_dtype, device=query.device)
+            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
             hidden_size = output_shape[-1]
             # Reshape the query, key, and value tensors.
             # NOTE(woosuk): We do this outside the custom op to minimize the

@@ -705,7 +705,7 @@ def forward(
             self.calc_kv_scales(q, kv_c_normed, k_pe)
 
         if self.attn_backend.accept_output_buffer:
-            output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+            output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
             self.impl.forward(
                 self,
                 q,

@@ -722,7 +722,7 @@ def forward(
             )
         else:
             if self.attn_backend.accept_output_buffer:
-                output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                 torch.ops.vllm.unified_mla_attention_with_output(
                     q,
                     kv_c_normed,
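The three hunks above all make the same change: the attention output buffer is allocated with torch.empty instead of torch.zeros. Backends that accept an output buffer are expected to write every element before it is read, so zero-initialization only costs an extra fill kernel per forward pass. A minimal standalone sketch of the idea follows; the shapes and dtype are illustrative assumptions, not vLLM code.

import torch

# Hypothetical output shape, for illustration only.
num_tokens, hidden_size = 8, 256
query = torch.randn(num_tokens, hidden_size)

# Before: zero-initialized, even though the attention kernel is expected to
# overwrite every element before the buffer is read.
output = torch.zeros(query.shape, dtype=query.dtype, device=query.device)

# After: uninitialized allocation; correctness is unchanged as long as the
# kernel fills the entire buffer, and the redundant fill kernel disappears.
output = torch.empty(query.shape, dtype=query.dtype, device=query.device)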

vllm/v1/attention/backends/flash_attn.py

Lines changed: 1 addition & 1 deletion

@@ -530,7 +530,7 @@ def forward(
 
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
 
         attn_type = self.attn_type
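This hunk, and the matching ones in the other backends below, compensates for the allocation change: during a profiling run there is no attn_metadata and no kernel writes the buffer, so returning it as-is would now expose uninitialized memory. Zero-filling in place with output.fill_(0) restores the previously zeroed result. A hedged sketch of the pattern, with illustrative names rather than the real vLLM signatures:

import torch
from typing import Optional

def forward_sketch(output: torch.Tensor, attn_metadata: Optional[object]) -> torch.Tensor:
    # Profiling run: no attention is computed, but the buffer now arrives
    # uninitialized (torch.empty), so zero it explicitly before returning.
    if attn_metadata is None:
        return output.fill_(0)
    # ... a real backend writes every element of `output` here ...
    return output

# Caller side: the buffer is allocated with torch.empty rather than torch.zeros.
out = torch.empty(8, 256)
assert forward_sketch(out, None).abs().sum() == 0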

vllm/v1/attention/backends/flashinfer.py

Lines changed: 1 addition & 1 deletion

@@ -857,7 +857,7 @@ def forward(
 
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
 
         if self.bmm1_scale is None:
             self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale

vllm/v1/attention/backends/flex_attention.py

Lines changed: 1 addition & 1 deletion

@@ -767,7 +767,7 @@ def forward(
 
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
             # query = self.view_as_4d(query).permute(0, 2, 1, 3)
             # return torch.empty_like(query)

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 1 addition & 1 deletion

@@ -485,7 +485,7 @@ def forward(
 
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
 
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in

vllm/v1/attention/backends/rocm_aiter_unified_attn.py

Lines changed: 1 addition & 1 deletion

@@ -130,7 +130,7 @@ def forward(
 
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
 
         assert attn_metadata.use_cascade is False

vllm/v1/attention/backends/rocm_attn.py

Lines changed: 1 addition & 1 deletion

@@ -299,7 +299,7 @@ def forward(
 
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
 
         assert attn_metadata.use_cascade is False

vllm/v1/attention/backends/tree_attn.py

Lines changed: 1 addition & 1 deletion

@@ -379,7 +379,7 @@ def forward(
 
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
 
         # Cache the input KVs.
         key_cache, value_cache = kv_cache.unbind(0)

vllm/v1/attention/backends/triton_attn.py

Lines changed: 1 addition & 1 deletion

@@ -298,7 +298,7 @@ def forward(
 
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
 
         assert attn_metadata.use_cascade is False

vllm/v1/attention/backends/xformers.py

Lines changed: 1 addition & 1 deletion

@@ -354,7 +354,7 @@ def forward(
 
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
 
         # Cache the input KVs.
         key_cache, value_cache = kv_cache.unbind(0)
