
Commit 7427ef8

modify torchair backend

Author: angazenn
Signed-off-by: angazenn <zengyanjia@huawei.com>
Parent: d1984d2

File tree

  tests/e2e/multicard/test_torchair_graph_mode.py
  vllm_ascend/attention/attention_v1_torchair.py

2 files changed: +41, -31 lines

tests/e2e/multicard/test_torchair_graph_mode.py

Lines changed: 8 additions & 8 deletions
@@ -133,15 +133,15 @@ def _pangu_torchair_test_fixture(
         # use greedy sampler to make sure the generated results are fix
         vllm_output = vllm_model.generate_greedy(example_prompts, 5)

-    # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of
-    # DeepSeek-V3 with 2 hidden layers, thus the golden results seems
-    # inaccurate. This will only change if accuracy improves with the
-    # official weights of DeepSeek-V3.
+    # NOTE: vllm-ascend/pangu-pro-moe-pruning is only part of PanguProMoE-72B
+    # with 2 hidden layers, thus the golden results seems inaccurate.
+    # This will only change if accuracy changes with the official weights
+    # of PanguProMoE-72B.
     golden_results = [
-        'Hello, my name is Remempondeprecatedmiot忱',
-        'The president of the United States is Remem下的一个 rever ceremoni Segnali',
-        'The capital of France is Rememvoud administrativ Remem投',
-        'The future of AI isotope Segnali Zoeken精细化 supus',
+        'Hello, my name is Remempondeprecatedmiot忱',  # noqa
+        'The president of the United States is Remem下的一个 rever ceremoni Segnali',  # noqa
+        'The capital of France is Rememvoud administrativ Remem投',  # noqa
+        'The future of AI isotope Segnali Zoeken精细化 supus',  # noqa
     ]

     assert len(golden_results) == len(vllm_output)
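For context on how these golden strings are used: the fixture decodes greedily, so the outputs are deterministic and can be compared string-for-string against the pruned-model golden results. Below is a minimal sketch of that comparison step, assuming generate_greedy returns (token_ids, text) pairs as in vLLM's test runner; the loop itself is an illustration, not part of this diff.

# Hedged sketch of the golden-result check; the (token_ids, text) layout is an assumption.
for i in range(len(golden_results)):
    generated_text = vllm_output[i][1]  # second element assumed to be the decoded string
    assert golden_results[i] == generated_text, (
        f"prompt {i}: expected {golden_results[i]!r}, got {generated_text!r}")
    print(f"Generated text: {generated_text!r}")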

vllm_ascend/attention/attention_v1_torchair.py

Lines changed: 33 additions & 23 deletions
@@ -29,7 +29,8 @@

 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import (AscendAttentionBackendImpl,
-                                                AscendAttentionState)
+                                                AscendAttentionState,
+                                                AscendMetadata)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d)

@@ -365,7 +366,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: AscendTorchairMetadata,
+        attn_metadata: AscendMetadata,
         output: Optional[torch.Tensor] = None,
         trace_flag: bool = False,
     ) -> torch.Tensor:
@@ -410,11 +411,7 @@ def forward(
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
-                                      "PallasAttentionBackendImpl")
-        # View q k v to BSH.
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
+                                      "AscendAttentionTorchairBackendImpl")

         if kv_cache is not None and kv_cache[0].numel() > 0:
             key_cache, value_cache = kv_cache[0], kv_cache[1]
@@ -425,17 +422,19 @@ def forward(
             block_indices = slots_indices // block_size
             slots_indices = slots_indices % block_size
             indices = torch.cat((block_indices, slots_indices), dim=1)
-            torch_npu.npu_scatter_nd_update_(
-                key_cache, indices,
-                key.view(-1, self.num_kv_heads * self.head_size))
-            torch_npu.npu_scatter_nd_update_(
-                value_cache, indices,
-                value.view(-1, self.num_kv_heads * self.head_size))
+            torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
+            torch_npu.npu_scatter_nd_update_(value_cache, indices, value)

         if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             assert attn_metadata is not None
             assert attn_metadata.attn_mask is not None
             mask = attn_metadata.attn_mask
+
+            # View q k v to BSH.
+            query = query.view(-1, self.num_heads, self.head_size)
+            key = key.view(-1, self.num_kv_heads, self.head_size)
+            value = value.view(-1, self.num_kv_heads, self.head_size)
+
             if is_310p():
                 # align q k v output tensors
                 query = aligned_16(query)
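The cache-write path above turns each flat slot id into a (block index, offset-within-block) pair before the in-place NPU scatter. Below is a small CPU-only sketch of the same arithmetic in plain PyTorch; the shapes, values, and the indexing assignment standing in for the NPU op are illustrative, not taken from the source.

import torch

# Illustrative shapes: 3 tokens, block_size 4, a cache of 2 blocks.
block_size = 4
slot_mapping = torch.tensor([[1], [5], [6]])                # flat slot ids, shape (num_tokens, 1)

block_indices = slot_mapping // block_size                  # which cache block each token lands in
slot_indices = slot_mapping % block_size                    # offset inside that block
indices = torch.cat((block_indices, slot_indices), dim=1)   # (num_tokens, 2) scatter indices

# CPU stand-in for torch_npu.npu_scatter_nd_update_: write each key row into the cache.
num_kv_heads, head_size = 2, 8
key = torch.randn(3, num_kv_heads, head_size)
key_cache = torch.zeros(2, block_size, num_kv_heads, head_size)
key_cache[indices[:, 0], indices[:, 1]] = key               # same effect as the in-place NPU scatter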
@@ -458,6 +457,22 @@ def forward(
                 num_kv_heads=self.num_kv_heads,
                 out=output)
             output = output[:num_tokens, :, :]
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
+            assert attn_metadata is not None
+            assert attn_metadata.attn_mask is not None
+            compress_mask = attn_metadata.attn_mask
+            torch_npu._npu_flash_attention_qlens(
+                query=query,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                block_table=attn_metadata.block_tables,
+                mask=compress_mask,
+                seq_len=attn_metadata.query_lens,
+                context_lens=attn_metadata.seq_lens,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
             decode_meta = attn_metadata.decode
             assert decode_meta is not None
@@ -478,15 +493,10 @@ def forward(
                 input_layout='BSH',
                 block_size=block_size)
         else:
-            output = super().forward(
-                layer=layer,
-                query=query,
-                key=key,
-                value=value,
-                kv_cache=kv_cache,
-                attn_metadata=attn_metadata,
-                output=output,
-                trace_flag=trace_flag,
-            )
+            raise NotImplementedError(
+                "Torchair graph mode with non-MLA attention backend is still experimental."
+                "v1 scheduler(chunked prefill) is not supported at this moment. Please"
+                "setting 'ascend_scheduler_config':{'enabled':true} in additional_config"
+                "to use ascend scheduler.")

         return output.view(num_tokens, self.hidden_size)
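The new else branch fails fast instead of silently falling back to the base implementation, and its message points users at the ascend scheduler. Below is a hedged sketch of what that configuration might look like when constructing an engine; the additional_config keys follow the vllm-ascend graph-mode documentation, and the model path, prompt, and sampling settings are placeholders rather than anything from this commit.

from vllm import LLM, SamplingParams

# Sketch only: enable torchair graph mode together with the ascend scheduler
# that the NotImplementedError above asks for. Key names are assumed from the
# vllm-ascend docs; the model path is a placeholder.
llm = LLM(
    model="/path/to/your/model",
    additional_config={
        "torchair_graph_config": {"enabled": True},
        "ascend_scheduler_config": {"enabled": True},
    },
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)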
