@@ -346,7 +346,7 @@ def forward(
 
         if self.use_output:
             output_shape = output_shape if output_shape is not None else query.shape
-            output = torch.zeros(output_shape, dtype=output_dtype, device=query.device)
+            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
             hidden_size = output_shape[-1]
             # Reshape the query, key, and value tensors.
             # NOTE(woosuk): We do this outside the custom op to minimize the
@@ -705,7 +705,7 @@ def forward(
             self.calc_kv_scales(q, kv_c_normed, k_pe)
 
         if self.attn_backend.accept_output_buffer:
-            output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+            output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
             self.impl.forward(
                 self,
                 q,
@@ -722,7 +722,7 @@ def forward(
             )
         else:
             if self.attn_backend.accept_output_buffer:
-                output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                 torch.ops.vllm.unified_mla_attention_with_output(
                     q,
                     kv_c_normed,
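
A note on why this swap is safe, as a minimal sketch rather than vLLM's actual code: `torch.zeros` both allocates a buffer and launches a fill kernel to write zeros into it, while `torch.empty` only allocates. When the attention backend is guaranteed to write every element of the output buffer (which the `accept_output_buffer` checks in this diff appear to imply), the zero fill is pure overhead on the hot path. The `attention_kernel_stub` function below is hypothetical and merely stands in for such a full-write backend kernel.

```python
import torch

def attention_kernel_stub(q: torch.Tensor, out: torch.Tensor) -> None:
    # Hypothetical stand-in for a backend kernel that fills `out`
    # completely, so any prior contents of `out` are irrelevant.
    out.copy_(q * 2.0)

q = torch.randn(4, 8)

# torch.zeros allocates AND launches a fill kernel to write zeros,
# which is wasted work when the kernel overwrites the whole buffer.
out_zeros = torch.zeros_like(q)
attention_kernel_stub(q, out_zeros)

# torch.empty only allocates; its contents are uninitialized, which is
# safe here precisely because the kernel writes every element.
out_empty = torch.empty_like(q)
attention_kernel_stub(q, out_empty)

# Both paths produce identical results; only the allocation cost differs.
assert torch.equal(out_zeros, out_empty)
```

The flip side of this trade-off: if a backend ever wrote only part of the buffer (for example, skipping padded rows), `torch.empty` would expose uninitialized memory in the untouched region, so the change relies on backends that accept an output buffer honoring a full-write contract.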