diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 929c3b6a4906..fe9de65b52c6 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -346,7 +346,7 @@ def forward(
         if self.use_output:
             output_shape = output_shape if output_shape is not None else query.shape
-            output = torch.zeros(output_shape, dtype=output_dtype, device=query.device)
+            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
             hidden_size = output_shape[-1]
             # Reshape the query, key, and value tensors.
             # NOTE(woosuk): We do this outside the custom op to minimize the
@@ -705,7 +705,7 @@ def forward(
             self.calc_kv_scales(q, kv_c_normed, k_pe)

         if self.attn_backend.accept_output_buffer:
-            output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+            output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
             self.impl.forward(
                 self,
                 q,
@@ -722,7 +722,7 @@ def forward(
             )
         else:
             if self.attn_backend.accept_output_buffer:
-                output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                 torch.ops.vllm.unified_mla_attention_with_output(
                     q,
                     kv_c_normed,
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index fa4e34536135..9e0c125d9edb 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -530,7 +530,7 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         attn_type = self.attn_type
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 0fa71afa62ee..ee32f7e2904f 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -857,7 +857,7 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         if self.bmm1_scale is None:
             self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index 2595851e5042..902872bb25b3 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -767,7 +767,7 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
             # query = self.view_as_4d(query).permute(0, 2, 1, 3)
             # return torch.empty_like(query)
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index cce43b220da7..7c73611d4a58 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -485,7 +485,7 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
index 14184944934f..27b072106268 100644
--- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
+++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
@@ -130,7 +130,7 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         assert attn_metadata.use_cascade is False
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
index 5245c7f44925..8b7ce90a3cca 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -299,7 +299,7 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         assert attn_metadata.use_cascade is False
diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py
index aab90cfd1fe0..ee6ead9ad9b3 100644
--- a/vllm/v1/attention/backends/tree_attn.py
+++ b/vllm/v1/attention/backends/tree_attn.py
@@ -379,7 +379,7 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         # Cache the input KVs.
         key_cache, value_cache = kv_cache.unbind(0)
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 9d1d007a08e4..9746a0eb58bd 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -298,7 +298,7 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         assert attn_metadata.use_cascade is False
diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py
index 41c543c18adc..457b15ebdd82 100644
--- a/vllm/v1/attention/backends/xformers.py
+++ b/vllm/v1/attention/backends/xformers.py
@@ -354,7 +354,7 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         # Cache the input KVs.
         key_cache, value_cache = kv_cache.unbind(0)
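The pattern this diff moves to can be illustrated with a minimal, self-contained sketch (not the vLLM code; `attention_forward` and `run_kernel` are hypothetical stand-ins): the output buffer is allocated with `torch.empty` because the attention kernel overwrites every element on a real forward pass, and the buffer is only zero-filled on the early-return profiling path, where no kernel runs and uninitialized memory would otherwise be handed back to the caller.

```python
import torch


def attention_forward(query: torch.Tensor, attn_metadata, run_kernel) -> torch.Tensor:
    # torch.empty skips the memset that torch.zeros would issue on every call.
    output = torch.empty(query.shape, dtype=query.dtype, device=query.device)

    if attn_metadata is None:
        # Profiling run: the kernel is skipped, so the buffer must be zeroed
        # explicitly before being returned.
        return output.fill_(0)

    # Real run: the kernel writes every element of `output`, so the
    # uninitialized contents are never observed.
    run_kernel(query, out=output)
    return output
```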