
Commit f025df0

[Attn][bugfix] fix long seq precision issue (vllm-project#334)

Fixes vllm-project/vllm-ascend#321. This PR is a temporary workaround for the long-sequence precision issue; it will be reverted once the root cause is fixed.

cc @rjg-lyh @wangxiyuan

Co-authored-by: rjg-lyh <1318825571@qq.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: wangxiyuan <wangxiyuan@huawei.com>
1 parent 6d0f1d8 commit f025df0
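Nothing in the commit states the failure mechanism outright; the TODO in the diff points at a torch-npu synchronization issue, and the change itself simply moves the attention output buffer from a function local to an instance attribute (`self.output`). One plausible reading: with asynchronous kernel dispatch, a local buffer's memory can be returned to the allocator when its last Python reference dies, even while a device kernel is still writing into it, and pinning the buffer on `self` extends its lifetime. A minimal sketch of the pattern (simplified and illustrative, not the actual `vllm_ascend` code):

```python
import torch


class AttentionImpl:
    """Simplified stand-in for the Ascend attention backend (illustrative)."""

    def __init__(self, num_heads: int, head_size: int) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        # TODO: FIXME revert me when torch-npu sync issue is solved
        # Keeping a reference on the instance means the buffer survives
        # past forward(), even after the caller drops its result.
        self.output: torch.Tensor = None

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        num_tokens = query.shape[0]
        # Before this commit the buffer was a local: `output = torch.empty(...)`.
        self.output = torch.empty(num_tokens,
                                  self.num_heads,
                                  self.head_size,
                                  dtype=query.dtype,
                                  device=query.device)
        # ... attention kernels write their results into self.output via out= ...
        return self.output.view(num_tokens, self.num_heads * self.head_size)
```

The trade-off is that the module keeps the most recent output alive between calls and `forward()` is no longer re-entrant; both are acceptable for a stopgap explicitly marked for reversion.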

2 files changed: 15 additions, 13 deletions


docs/source/user_guide/release_notes.md

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@
 - Pin modelscope<1.23.0 on vLLM v0.7.3 to resolve: https://github.com/vllm-project/vllm/pull/13807

 ### Known issues
-- In [some cases](https://github.com/vllm-project/vllm-ascend/issues/321), especially when the input/output is very long, the accuracy of output may be incorrect. You may see many `!` in the output. We are working on it. It'll be fixed in the next release.
+- In some cases, especially when the input/output is very long with a VL model, the accuracy of output may be incorrect. You may see many `!` or other unreadable characters in the output. We are working on it. It'll be fixed in the next release.
 - Improved and reduced the garbled code in model output. But if you still hit the issue, try changing generation config values such as `temperature` and try again. Any [feedback](https://github.com/vllm-project/vllm-ascend/issues/267) is welcome. [#277](https://github.com/vllm-project/vllm-ascend/pull/277)

 ## v0.7.1rc1
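The second known issue above suggests adjusting generation config values such as `temperature`. As a concrete illustration (the model name and parameter values here are illustrative assumptions, not part of this commit), that can be done through vLLM's offline API:

```python
from vllm import LLM, SamplingParams

# Illustrative model; substitute whatever model you are serving.
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")

# Lowering temperature (and tightening top_p) makes sampling less random,
# which can reduce garbled characters in the output.
params = SamplingParams(temperature=0.6, top_p=0.9, max_tokens=256)

outputs = llm.generate(["Explain what vLLM is in one paragraph."], params)
print(outputs[0].outputs[0].text)
```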

vllm_ascend/attention.py

Lines changed: 14 additions & 12 deletions

@@ -717,6 +717,8 @@ def __init__(
         self.query_len_cpu_tensor = None
         self.key_cache = None
         self.value_cache = None
+        # TODO: FIXME revert me when torch-npu sync issue is solved
+        self.output: torch.Tensor = None

     def forward(
         self,
@@ -755,11 +757,11 @@ def forward(
         value = value.contiguous()
         attn_type = self.attn_type

-        output = torch.empty(num_tokens,
-                             self.num_heads,
-                             self.head_size,
-                             dtype=query.dtype,
-                             device=query.device)
+        self.output = torch.empty(num_tokens,
+                                  self.num_heads,
+                                  self.head_size,
+                                  dtype=query.dtype,
+                                  device=query.device)

         if kv_cache.numel() > 0:
             if self.key_cache is None:
@@ -792,7 +794,7 @@ def forward(
                 block_tables,
                 isPrefill,
                 attn_metadata,
-                output,
+                self.output,
                 seq_lens_tensor_cpu=self.seq_lens_tensor_cpu)
         else:
             if self.key_cache is not None:
@@ -831,7 +833,7 @@ def forward(
                         is_causal=causal_attn and mask is None,
                         scale=self.scale).squeeze(0).movedim(
                             query.dim() - 2, 0)
-                    output[start_q:end_q, :, :] = sub_out
+                    self.output[start_q:end_q, :, :] = sub_out
                     start_q, start_kv = end_q, end_kv
                 else:
                     assert attn_metadata.attn_mask is not None
@@ -849,7 +851,7 @@ def forward(
                         scale_value=self.scale,
                         num_heads=self.num_heads,
                         num_kv_heads=self.num_kv_heads,
-                        out=output)
+                        out=self.output)
             elif attn_metadata.num_decode_tokens == 0 and not attn_metadata.chunked_prefill_enabled:
                 assert kv_cache is not None
                 assert attn_metadata.prefill_metadata is not None
@@ -875,7 +877,7 @@ def forward(
                     num_kv_heads=self.num_kv_heads,
                     num_heads=self.num_heads,
                     scale_value=self.scale,
-                    out=output)
+                    out=self.output)
             # Splitfuse
             else:
                 assert kv_cache is not None
@@ -897,7 +899,7 @@ def forward(
                     num_kv_heads=self.num_kv_heads,
                     num_heads=self.num_heads,
                     scale_value=self.scale,
-                    out=output)
+                    out=self.output)
         # Decode only
         else:
             assert kv_cache is not None
@@ -915,9 +917,9 @@ def forward(
                 scale_value=self.scale,
                 block_table=block_tables,
                 context_lens=self.seq_lens_tensor_cpu,
-                out=output)
+                out=self.output)

-        return output.view(num_tokens, self.hidden_size)
+        return self.output.view(num_tokens, self.hidden_size)


 class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
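For completeness, the eventual root-cause fix would presumably restore the function-local buffer and rely on correct device synchronization instead. A sketch of what explicit synchronization could look like, assuming the torch-npu runtime mirrors the `torch.cuda` interface (`torch.npu.synchronize` is an assumption based on that parallel, not something this commit uses):

```python
import torch
# Assumption: importing torch_npu registers the `torch.npu` device runtime,
# mirroring the `torch.cuda` interface.
import torch_npu  # noqa: F401


def forward_with_sync(query: torch.Tensor, num_heads: int,
                      head_size: int) -> torch.Tensor:
    num_tokens = query.shape[0]
    output = torch.empty(num_tokens, num_heads, head_size,
                         dtype=query.dtype, device=query.device)
    # ... attention kernels would write into `output` via out= ...
    # Block until all queued NPU work has completed before the buffer's
    # lifetime is left to Python reference counting again.
    torch.npu.synchronize()
    return output.view(num_tokens, num_heads * head_size)
```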
