
Commit 7d0b0f6

feat: support v1 engine on 310P
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
1 parent 5ae959c commit 7d0b0f6

File tree

vllm_ascend/attention/attention_v1.py
vllm_ascend/platform.py
vllm_ascend/utils.py
vllm_ascend/worker/model_runner_v1.py

4 files changed: +73 additions, -6 deletions


vllm_ascend/attention/attention_v1.py

Lines changed: 43 additions & 0 deletions
@@ -30,6 +30,8 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
+                               nd_to_nz_2d, nd_to_nz_spec)
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -62,6 +64,9 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
+        if is_310p():
+            return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
+                    16)
         return (2, num_blocks, block_size, num_kv_heads, head_size)
 
     @staticmethod
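
Note: the NZ layout returned for 310P packs the combined head dimension into 16-element fractal tiles, so it holds exactly the same number of elements as the default ND layout whenever num_kv_heads * head_size is a multiple of 16. A quick sanity check (the parameter values below are arbitrary examples, not taken from this commit):

import math

# Element-count check for the two cache layouts returned by get_kv_cache_shape.
num_blocks, block_size, num_kv_heads, head_size = 64, 128, 8, 128

nd_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
nz_shape = (2, num_blocks, num_kv_heads * head_size // 16, block_size, 16)

assert num_kv_heads * head_size % 16 == 0   # NZ tiling assumes 16-alignment
assert math.prod(nd_shape) == math.prod(nz_shape)
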
@@ -160,6 +165,16 @@ def build(self, num_reqs, num_actual_tokens, max_query_len,
         query_start_loc = query_start_loc_cpu.to(self.runner.device,
                                                  non_blocking=True)
 
+        if is_310p():
+            if attn_state == AscendAttentionState.PrefillNoCache:
+                mask_nz = nd_to_nz_2d(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+            elif attn_state == AscendAttentionState.ChunkedPrefill:
+                mask_nz = nd_to_nz_spec(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+
         attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
                                        block_tables=block_table,
                                        query_start_loc=query_start_loc,
@@ -240,6 +255,7 @@ def forward(
                             self.head_size,
                             dtype=query.dtype,
                             device=query.device)
+        ori_output = output
         if trace_flag:
             torch.ops.vllm.unified_ascend_attention_with_output(
                 query=query,
@@ -284,6 +300,18 @@ def forward(
                 assert attn_metadata is not None
                 assert attn_metadata.attn_mask is not None
                 mask = attn_metadata.attn_mask
+                if is_310p():
+                    # align q k v output tensors
+                    query = aligned_16(query)
+                    key = aligned_16(key)
+                    value = aligned_16(value)
+                    output = aligned_16(output)
+
+                    # do reformat in case of broadcasted tensors
+                    mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
+                    mask = torch_npu.npu_format_cast(mask.contiguous(),
+                                                     ACL_FORMAT_FRACTAL_NZ)
+
                 torch_npu._npu_flash_attention(query=query,
                                                key=key,
                                                value=value,
@@ -293,6 +321,7 @@ def forward(
                                                num_heads=self.num_heads,
                                                num_kv_heads=self.num_kv_heads,
                                                out=output)
+                output = output[:num_tokens, :, :]
             elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
                 assert attn_metadata is not None
                 assert attn_metadata.attn_mask is not None
@@ -310,6 +339,10 @@ def forward(
                     scale_value=self.scale,
                     out=output)
             elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                if is_310p():
+                    # # seq_lens_tensor needs to be transferred to the device for 310P
+                    attn_metadata.seq_lens = \
+                        attn_metadata.seq_lens.to(device=self.key_cache.device)
                 torch_npu._npu_paged_attention(
                     query=query,
                     key_cache=self.key_cache,
@@ -343,6 +376,12 @@ def forward(
                                         self.scale, None, True)
             else:
                 # use paged attention
+                if is_310p():
+                    # do reformat in case of broadcasted tensors
+                    attn_metadata.attn_mask = \
+                        torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
+                    attn_metadata.seq_lens = \
+                        attn_metadata.seq_lens.to(device=self.key_cache.device)
                 torch_npu._npu_paged_attention_splitfuse(
                     query=query,
                     key_cache=self.key_cache,
@@ -355,6 +394,10 @@ def forward(
                     num_heads=self.num_heads,
                     scale_value=self.scale,
                     out=output)
+
+        # to make in-place change to the output tensor
+        if not id(ori_output) == id(output):
+            ori_output[:, :, :] = output[:num_tokens, :, :]
         return output.view(num_tokens, self.hidden_size)
 
 
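Note: most of the forward() changes above address two 310P constraints: the fused attention kernels expect operand shapes padded to multiples of 16 (hence aligned_16 on query/key/value/output and the mask being repeated and cast to the NZ format), and padding means the kernel may write into a new buffer rather than the caller-provided one, so the valid rows are copied back to preserve the in-place output contract. A minimal, NPU-free sketch of that pad/compute/copy-back pattern (pad_rows_to_16 is a hypothetical stand-in for aligned_16, and the "kernel" is a placeholder):

import torch

def pad_rows_to_16(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical equivalent of aligned_16: pad dim 0 up to a multiple of 16.
    n = t.shape[0]
    n_pad = (n + 15) // 16 * 16
    if n_pad == n:
        return t
    padded = t.new_zeros((n_pad, *t.shape[1:]))
    padded[:n] = t
    return padded

num_tokens, num_heads, head_size = 5, 8, 128
query = torch.randn(num_tokens, num_heads, head_size)
ori_output = torch.empty(num_tokens, num_heads, head_size)

output = pad_rows_to_16(ori_output)   # padded buffer the kernel writes into
padded_query = pad_rows_to_16(query)
output[:] = padded_query * 2.0        # placeholder for the attention kernel
output = output[:num_tokens, :, :]

# output no longer aliases ori_output, so copy the valid rows back in place,
# mirroring the id(ori_output) != id(output) check in the diff.
if id(ori_output) != id(output):
    ori_output[:, :, :] = output[:num_tokens, :, :]
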
vllm_ascend/platform.py

Lines changed: 5 additions & 3 deletions
@@ -24,7 +24,8 @@
 from vllm.logger import logger
 from vllm.platforms import Platform, PlatformEnum
 
-from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD, update_aclgraph_sizes
+from vllm_ascend.utils import (ASCEND_QUATIZATION_METHOD, is_310p,
+                               update_aclgraph_sizes)
 
 CUSTOM_OP_ENABLED = False
 try:
@@ -219,8 +220,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
            cache_config.block_size = 128
 
        if envs.VLLM_USE_V1:
-            # Activate custom ops for v1.
-            compilation_config.custom_ops = ["all"]
+            # Activate custom ops for v1, except on 310P
+            if not is_310p():
+                compilation_config.custom_ops = ["all"]
            # If ascend_scheduler_config exists in additional_config,
            # extents original scheduler_config to use AscendScheduler.
 
vllm_ascend/utils.py

Lines changed: 15 additions & 0 deletions
@@ -100,6 +100,21 @@ def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
                            2).contiguous()
 
 
+def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
+    num_tokens = mask_tensor.shape[0]
+    max_seq_len = mask_tensor.shape[1]
+
+    tokens_pad = (num_tokens + 15) // 16 * 16
+    max_seq_len_pad = (max_seq_len + 15) // 16 * 16
+
+    mask_tensor_pad = \
+        torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
+    mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor
+    mask = mask_tensor_pad.reshape(
+        (1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
+    return mask
+
+
 def aligned_16(tensor: torch.Tensor):
     """Aligned tensor for 310P"""
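Note: the new helper turns a (num_tokens, max_seq_len) attention mask into the 4-D NZ fractal layout (1, max_seq_len_pad // 16, tokens_pad, 16), padding both axes up to multiples of 16 first. A quick shape check (the function body is copied from the hunk above; the example mask is arbitrary):

import torch

def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
    num_tokens, max_seq_len = mask_tensor.shape
    tokens_pad = (num_tokens + 15) // 16 * 16
    max_seq_len_pad = (max_seq_len + 15) // 16 * 16
    mask_tensor_pad = torch.zeros((1, tokens_pad, max_seq_len_pad),
                                  dtype=mask_tensor.dtype,
                                  device=mask_tensor.device)
    mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor
    return mask_tensor_pad.reshape(
        (1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)

mask = torch.ones(5, 40, dtype=torch.bool)   # 5 query tokens, 40 key positions
nz = nd_to_nz_spec(mask)
print(nz.shape)  # torch.Size([1, 3, 16, 16]): 40 -> padded to 48 -> three 16-wide fractal blocks
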
vllm_ascend/worker/model_runner_v1.py

Lines changed: 10 additions & 3 deletions
@@ -64,6 +64,8 @@
 from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
+                               is_310p)
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
 
 if TYPE_CHECKING:
@@ -1263,6 +1265,8 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
            cache size of each layer
        """
        import torch_npu
+        acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p(
+        ) else ACL_FORMAT_FRACTAL_ND
        kv_caches: Dict[str, torch.Tensor] = {}
 
        self.input_batch = InputBatch(
@@ -1312,13 +1316,16 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                        device=self.device)
                    kv_caches[layer_name] = (layer_kv_cache_nope,
                                             layer_kv_cache_pe)
-                    torch_npu.npu_format_cast(kv_caches[layer_name][0], 2)
-                    torch_npu.npu_format_cast(kv_caches[layer_name][1], 2)
+                    kv_caches[layer_name][0] = \
+                        torch_npu.npu_format_cast(kv_caches[layer_name][0], acl_format)
+                    kv_caches[layer_name][1] = \
+                        torch_npu.npu_format_cast(kv_caches[layer_name][1], acl_format)
                else:
                    kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                        dtype=dtype,
                                                        device=self.device)
-                    torch_npu.npu_format_cast(kv_caches[layer_name], 2)
+                    kv_caches[layer_name] = \
+                        torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)
            else:
                # TODO: add new branches when introducing more types of
                # KV cache specs.
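
Note: two things change here: the ACL format for the KV cache is chosen per SoC (NZ on 310P, plain ND elsewhere), and the value returned by npu_format_cast is now assigned back, which suggests the cast returns the converted tensor rather than modifying its input in place (the old calls discarded the result). A minimal, NPU-free sketch of the selection pattern; the numeric format IDs below are assumptions for illustration, not values taken from this commit:

# ACL format IDs as commonly defined by CANN (assumed values, illustration only).
ACL_FORMAT_FRACTAL_ND = 2    # plain ND layout
ACL_FORMAT_FRACTAL_NZ = 29   # fractal NZ layout used by many 310P kernels

def pick_kv_cache_format(on_310p: bool) -> int:
    return ACL_FORMAT_FRACTAL_NZ if on_310p else ACL_FORMAT_FRACTAL_ND

acl_format = pick_kv_cache_format(on_310p=True)
assert acl_format == ACL_FORMAT_FRACTAL_NZ
# On a real NPU the cache would then be converted and the result kept:
#   kv_caches[layer_name] = torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)
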

0 commit comments
