Commit ddf54ef

feat: support v1 engine on 310P
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Parent: 411d04b

7 files changed (+82, -14 lines)

vllm_ascend/attention/attention.py

Lines changed: 2 additions & 3 deletions
@@ -36,9 +36,8 @@
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ops.cache import concat_and_cache_mla
-from vllm_ascend.utils import enable_custom_op
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
-                               nd_to_nz_2d)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
+                               enable_custom_op, is_310p, nd_to_nz_2d)
 from vllm_ascend.worker.model_runner import (
     ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
 

vllm_ascend/attention/attention_v1.py

Lines changed: 45 additions & 0 deletions
@@ -30,6 +30,8 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
+                               nd_to_nz_2d, nd_to_nz_spec)
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -62,6 +64,9 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
+        if is_310p():
+            return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
+                    16)
         return (2, num_blocks, block_size, num_kv_heads, head_size)
 
     @staticmethod
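Note: the 310P branch above keeps the same number of KV-cache elements as the default layout; it only regroups them into 16-wide fractal blocks, which assumes num_kv_heads * head_size is a multiple of 16. A quick sanity check of the shape arithmetic with illustrative sizes (not taken from the commit):

import math

# Illustrative sizes only.
num_blocks, block_size, num_kv_heads, head_size = 1024, 128, 8, 128

nd_shape = (2, num_blocks, block_size, num_kv_heads, head_size)             # default layout
nz_shape = (2, num_blocks, num_kv_heads * head_size // 16, block_size, 16)  # 310P layout

assert math.prod(nd_shape) == math.prod(nz_shape)  # both hold 2 * 1024 * 128 * 1024 elements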
@@ -167,6 +172,16 @@ def build(self,
         query_start_loc = query_start_loc_cpu.to(self.runner.device,
                                                  non_blocking=True)
 
+        if is_310p():
+            if attn_state == AscendAttentionState.PrefillNoCache:
+                mask_nz = nd_to_nz_2d(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+            elif attn_state == AscendAttentionState.ChunkedPrefill:
+                mask_nz = nd_to_nz_spec(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
@@ -250,6 +265,7 @@ def forward(
                              self.head_size,
                              dtype=query.dtype,
                              device=query.device)
+        ori_output = output
         if trace_flag:
             torch.ops.vllm.unified_ascend_attention_with_output(
                 query=query,
@@ -294,6 +310,18 @@ def forward(
                 assert attn_metadata is not None
                 assert attn_metadata.attn_mask is not None
                 mask = attn_metadata.attn_mask
+                if is_310p():
+                    # align q k v output tensors
+                    query = aligned_16(query)
+                    key = aligned_16(key)
+                    value = aligned_16(value)
+                    output = aligned_16(output)
+
+                    # do reformat in case of broadcasted tensors
+                    mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
+                    mask = torch_npu.npu_format_cast(mask.contiguous(),
+                                                     ACL_FORMAT_FRACTAL_NZ)
+
                 torch_npu._npu_flash_attention(query=query,
                                                key=key,
                                                value=value,
@@ -303,6 +331,7 @@ def forward(
                                                num_heads=self.num_heads,
                                                num_kv_heads=self.num_kv_heads,
                                                out=output)
+                output = output[:num_tokens, :, :]
             elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
                 assert attn_metadata is not None
                 assert attn_metadata.attn_mask is not None
@@ -320,6 +349,10 @@ def forward(
                     scale_value=self.scale,
                     out=output)
             elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                if is_310p():
+                    # seq_lens needs to be transferred to the device for 310P
+                    attn_metadata.seq_lens = \
+                        attn_metadata.seq_lens.to(device=query.device)
                 torch_npu._npu_paged_attention(
                     query=query,
                     key_cache=self.key_cache,
@@ -353,6 +386,14 @@ def forward(
                     self.scale, None, True)
             else:
                 # use paged attention
+                assert attn_metadata is not None
+                assert attn_metadata.attn_mask is not None
+                if is_310p():
+                    # do reformat in case of broadcasted tensors
+                    attn_metadata.attn_mask = \
+                        torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
+                    attn_metadata.seq_lens = \
+                        attn_metadata.seq_lens.to(device=query.device)
                 torch_npu._npu_paged_attention_splitfuse(
                     query=query,
                     key_cache=self.key_cache,
@@ -365,6 +406,10 @@ def forward(
                     num_heads=self.num_heads,
                     scale_value=self.scale,
                     out=output)
+
+        # to make in-place change to the output tensor
+        if not id(ori_output) == id(output):
+            ori_output[:, :, :] = output[:num_tokens, :, :]
         return output.view(num_tokens, self.hidden_size)
 
 
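On the 310P prefill path above, q/k/v and the output buffer are padded to a multiple of 16 rows with aligned_16 (presumably a kernel alignment requirement), the kernel writes into the padded buffers, and the valid rows are copied back into the caller's original output tensor (ori_output) at the end of forward. A minimal sketch of that pad-run-slice-copy pattern, using a hypothetical pad_to_multiple helper in place of aligned_16 (names and shapes are illustrative, not from the repo):

import torch

def pad_to_multiple(t: torch.Tensor, multiple: int = 16) -> torch.Tensor:
    # Hypothetical stand-in for aligned_16: pad dim 0 up to the next multiple of 16.
    pad = (-t.shape[0]) % multiple
    if pad == 0:
        return t
    return torch.nn.functional.pad(t, [0, 0] * (t.dim() - 1) + [0, pad])

num_tokens = 37
query = torch.randn(num_tokens, 8, 128)
ori_output = torch.empty_like(query)   # buffer the caller holds a reference to

output = pad_to_multiple(ori_output)   # kernel writes into the padded buffer
q_pad = pad_to_multiple(query)
# ... the attention kernel would consume q_pad (and padded k/v) and fill `output` ...
output = output[:num_tokens, :, :]     # drop the padding rows

# Copy back so the caller-visible tensor is updated in place.
if output is not ori_output:
    ori_output[:, :, :] = output[:num_tokens, :, :]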

vllm_ascend/ops/rotary_embedding.py

Lines changed: 1 addition & 2 deletions
@@ -22,8 +22,7 @@
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
-from vllm_ascend.utils import enable_custom_op
-from vllm_ascend.utils import is_310p
+from vllm_ascend.utils import enable_custom_op, is_310p
 
 
 def custom_rotary_embedding_enabled(query, neox_style, head_size):

vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 3 additions & 2 deletions
@@ -17,16 +17,17 @@
 # Adapted from vllm/model_executor/models/qwen2_vl.py
 # This file is a part of the vllm-ascend project.
 
+import torch
 import vllm
 import vllm.distributed
 import vllm.envs as envs
 from torch.distributed import ProcessGroup
 from vllm.config import ParallelConfig
+from vllm.distributed.utils import \
+    stateless_init_torch_distributed_process_group
 from vllm.logger import logger
 
 from vllm_ascend.utils import NullHandle, is_310p
-from vllm.distributed.utils import \
-    stateless_init_torch_distributed_process_group
 
 
 def ascend_destroy_model_parallel():

vllm_ascend/platform.py

Lines changed: 5 additions & 3 deletions
@@ -28,7 +28,8 @@
 from vllm.platforms import Platform, PlatformEnum
 
 from vllm_ascend.ascend_config import check_ascend_config, init_ascend_config
-from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD, update_aclgraph_sizes
+from vllm_ascend.utils import (ASCEND_QUATIZATION_METHOD, is_310p,
+                               update_aclgraph_sizes)
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
@@ -205,8 +206,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             cache_config.block_size = 128
 
         if envs.VLLM_USE_V1:
-            # Activate custom ops for v1.
-            compilation_config.custom_ops = ["all"]
+            # Activate custom ops for v1, except on 310P
+            if not is_310p():
+                compilation_config.custom_ops = ["all"]
 
             # If ascend_scheduler_config is enabled,
             # extents original scheduler_config to use AscendScheduler.

vllm_ascend/utils.py

Lines changed: 15 additions & 0 deletions
@@ -116,6 +116,21 @@ def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
                             2).contiguous()
 
 
+def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
+    num_tokens = mask_tensor.shape[0]
+    max_seq_len = mask_tensor.shape[1]
+
+    tokens_pad = (num_tokens + 15) // 16 * 16
+    max_seq_len_pad = (max_seq_len + 15) // 16 * 16
+
+    mask_tensor_pad = \
+        torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
+    mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor
+    mask = mask_tensor_pad.reshape(
+        (1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
+    return mask
+
+
 def aligned_16(tensor: torch.Tensor):
     """Aligned tensor for 310P"""

vllm_ascend/worker/model_runner_v1.py

Lines changed: 11 additions & 4 deletions
@@ -74,7 +74,9 @@
 from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
-from vllm_ascend.utils import ProfileExecuteDuration, vllm_version_is
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
+                               ProfileExecuteDuration, is_310p,
+                               vllm_version_is)
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
 
 if TYPE_CHECKING:
@@ -1641,6 +1643,8 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             cache size of each layer
         """
         import torch_npu
+        acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p(
+        ) else ACL_FORMAT_FRACTAL_ND
         kv_caches: Dict[str, torch.Tensor] = {}
 
         self.input_batch = InputBatch(
@@ -1698,13 +1702,16 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                                     device=self.device)
                     kv_caches[layer_name] = (layer_kv_cache_nope,
                                              layer_kv_cache_pe)
-                    torch_npu.npu_format_cast(kv_caches[layer_name][0], 2)
-                    torch_npu.npu_format_cast(kv_caches[layer_name][1], 2)
+                    kv_caches[layer_name][0] = \
+                        torch_npu.npu_format_cast(kv_caches[layer_name][0], acl_format)
+                    kv_caches[layer_name][1] = \
+                        torch_npu.npu_format_cast(kv_caches[layer_name][1], acl_format)
                 else:
                     kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                         dtype=dtype,
                                                         device=self.device)
-                    torch_npu.npu_format_cast(kv_caches[layer_name], 2)
+                    kv_caches[layer_name] = \
+                        torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)
             else:
                 # TODO: add new branches when introducing more types of
                 # KV cache specs.
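Two things worth noting about the KV-cache change above: the ACL memory format is now chosen once per run (fractal NZ on 310P, plain ND elsewhere), and the result of npu_format_cast is assigned back into kv_caches instead of being discarded as the old code did. A minimal sketch of the selection pattern, with hypothetical stand-ins for the constants and helper (the real ones live in vllm_ascend/utils.py; the NZ value here is assumed for illustration):

# Hypothetical stand-ins; ACL_FORMAT_FRACTAL_ND matches the literal `2` the old code passed.
ACL_FORMAT_FRACTAL_ND = 2
ACL_FORMAT_FRACTAL_NZ = 29  # value assumed for illustration

def is_310p() -> bool:
    return True  # stub; the real helper inspects the Ascend SoC version

# Pick the format once, then reuse it for every cast and keep the returned tensor:
acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p() else ACL_FORMAT_FRACTAL_ND
# kv_caches[layer_name] = torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)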
