
Commit 86987c6

whx-sjtu authored and hw_whx committed
fix ci problems
Signed-off-by: hw_whx <2952154980@qq.com>
1 parent 8baf750 commit 86987c6

9 files changed: +184 -119 lines changed


tests/test_scheduler.py

Lines changed: 31 additions & 20 deletions
@@ -1,14 +1,32 @@
-# SPDX-License-Identifier: Apache-2.0
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm-project/vllm/blob/main/tests/models/utils.py
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
 from typing import List, Optional

 from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
-from vllm_ascend.core.scheduler import AscendScheduler
 from vllm.v1.core.scheduler import SchedulerOutput
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus

+from vllm_ascend.core.scheduler import AscendScheduler
+
 EOS_TOKEN_ID = 50256

@@ -39,11 +57,11 @@ def create_scheduler(
     )
     cache_config.num_gpu_blocks = 10000
     return AscendScheduler(scheduler_config,
-                     model_config,
-                     cache_config,
-                     speculative_config=None,
-                     lora_config=None,
-                     log_stats=True)
+                           model_config,
+                           cache_config,
+                           speculative_config=None,
+                           lora_config=None,
+                           log_stats=True)


 def create_requests(
@@ -136,7 +154,6 @@ def test_schedule():
         assert scheduler.running[i] == request


-
 def test_stop_via_update_from_output():
     """Test stopping behavior through update_from_output"""
     scheduler = create_scheduler()
@@ -167,10 +184,8 @@ def test_stop_via_update_from_output():

     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
-        req_id_to_index={
-            req.request_id: i
-            for i, req in enumerate(requests)
-        },
+        req_id_to_index={req.request_id: i
+                         for i, req in enumerate(requests)},
         sampled_token_ids=[[EOS_TOKEN_ID],
                            [10,
                             11]],  # First request hits EOS, second continues
@@ -217,10 +232,8 @@ def test_stop_via_update_from_output():

     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
-        req_id_to_index={
-            req.request_id: i
-            for i, req in enumerate(requests)
-        },
+        req_id_to_index={req.request_id: i
+                         for i, req in enumerate(requests)},
         sampled_token_ids=[[10, 42, 12],
                            [13, 14]],  # First request hits stop token
         spec_token_ids=None,
@@ -265,10 +278,8 @@ def test_stop_via_update_from_output():

     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
-        req_id_to_index={
-            req.request_id: i
-            for i, req in enumerate(requests)
-        },
+        req_id_to_index={req.request_id: i
+                         for i, req in enumerate(requests)},
         sampled_token_ids=[[10, 11, 12],
                            [13]],  # First request exceeds max_tokens
         spec_token_ids=None,

vllm_ascend/attention/attention.py

Lines changed: 6 additions & 7 deletions
@@ -99,7 +99,7 @@ def get_decode_attn_mask(
             self.update_attn_cache(max_s, dtype, device)
         return (self.attn_mask_cache.index_select(
             0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())
-
+
     def get_splitfuse_attn_mask(
         self,
         seq_lens,
@@ -115,16 +115,15 @@ def get_splitfuse_attn_mask(
             # is not the same. Fix this in the future when kernel is ready.
             if self.attn_mask_cache[0][1] > 0:
                 attn_mask = self.get_attn_mask(  # type: ignore
-                        max_seq_len, dtype, device)
+                    max_seq_len, dtype, device)
                 attn_mask *= -10000
             else:
                 attn_mask = self.attn_mask_cache
-            return torch.index_select(attn_mask,
-                                      dim=0,
-                                      index=position)[:, :max_seq_len]
+            return torch.index_select(attn_mask, dim=0,
+                                      index=position)[:, :max_seq_len]
         total_q_len = sum(query_lens)
         attn_mask = torch.zeros((total_q_len, max_seq_len),
-                                dtype=self.vllm_config.model_config.dtype,
+                                dtype=dtype,
                                 device="cpu")

         current_row = 0
@@ -142,7 +141,7 @@ def get_splitfuse_attn_mask(
                 right_tensor.tril() == self.splitfuse_mask_value, 0)
             current_row += q_len

-        return attn_mask.to(self.device, non_blocking=True)
+        return attn_mask.to(device, non_blocking=True)


 class AscendAttentionBackend(AttentionBackend):

vllm_ascend/attention/attention_v1.py

Lines changed: 9 additions & 12 deletions
@@ -19,8 +19,6 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type

-import numpy as np
-
 import torch
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -220,16 +218,15 @@ def forward(
             assert attn_metadata is not None
             assert attn_metadata.attn_mask is not None
             mask = attn_metadata.attn_mask
-            torch_npu._npu_flash_attention(
-                query=query,
-                key=key,
-                value=value,
-                mask=mask,
-                seq_len=attn_metadata.seq_lens,
-                scale_value=self.scale,
-                num_heads=self.num_heads,
-                num_kv_heads=self.num_kv_heads,
-                out=output)
+            torch_npu._npu_flash_attention(query=query,
+                                           key=key,
+                                           value=value,
+                                           mask=mask,
+                                           seq_len=attn_metadata.seq_lens,
+                                           scale_value=self.scale,
+                                           num_heads=self.num_heads,
+                                           num_kv_heads=self.num_kv_heads,
+                                           out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
             block_tables = attn_metadata.block_tables
             torch_npu._npu_paged_attention(
Lines changed: 37 additions & 12 deletions
@@ -1,5 +1,22 @@
-from dataclasses import dataclass, asdict
-from typing import Union, Type
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from dataclasses import asdict, dataclass
+from typing import Type, Union

 from vllm.config import SchedulerConfig

@@ -9,34 +26,42 @@ class AscendSchedulerConfig(SchedulerConfig):
     enable_chunked_prefill: bool = False
     policy: str = "fcfs"
     num_scheduler_steps: int = 1
-    scheduler_cls: Union[str, Type[object]] = "vllm_ascend.core.scheduler.AscendScheduler"
-
+    scheduler_cls: Union[
+        str, Type[object]] = "vllm_ascend.core.scheduler.AscendScheduler"

     @classmethod
-    def initialize_from_config(cls, vllm_scheduler_config: SchedulerConfig, ascend_scheduler_config: dict):
+    def initialize_from_config(cls, vllm_scheduler_config: SchedulerConfig,
+                               ascend_scheduler_config: dict):
         scheduler_config = asdict(vllm_scheduler_config)
         # Override default values into original SchedulerConfig
         scheduler_config["enable_chunked_prefill"] = False
         scheduler_config["policy"] = "fcfs"
         scheduler_config["num_scheduler_steps"] = 1
-        scheduler_config["scheduler_cls"] = "vllm_ascend.core.scheduler.AscendScheduler"
+        scheduler_config[
+            "scheduler_cls"] = "vllm_ascend.core.scheduler.AscendScheduler"
         # Override params in original SchedulerConfig with params in additional_config.ascend_scheduler_config
         for k, v in ascend_scheduler_config.items():
             scheduler_config[k] = v
         # The "chunked_prefill_enabled" param of vllm's SchedulerConfig can't be initialized.
         scheduler_config.pop("chunked_prefill_enabled")
         return cls(**scheduler_config)

-
     def __post_init__(self) -> None:
         self.chunked_prefill_enabled = self.enable_chunked_prefill
         if self.policy != "fcfs":
-            raise NotImplementedError(f"currently AscendScheduler only supports fcfs policy, got {self.policy}")
+            raise NotImplementedError(
+                f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
+            )
         if self.is_multimodal_model:
-            raise NotImplementedError(f"currently AscendScheduler only supports LLM modles.")
+            raise NotImplementedError(
+                "currently AscendScheduler only supports LLM modles.")
         if self.num_scheduler_steps > 1:
-            raise NotImplementedError(f"currently AscendScheduler doesn't support multi-step.")
+            raise NotImplementedError(
+                "currently AscendScheduler doesn't support multi-step.")
         if self.send_delta_data:
-            raise NotImplementedError(f"currently AscendScheduler doesn't support send_delta_data.")
+            raise NotImplementedError(
+                "currently AscendScheduler doesn't support send_delta_data.")
         if self.delay_factor > 0:
-            raise NotImplementedError(f"currently AscendScheduler doesn't support scheduler_delay_factor.")
+            raise NotImplementedError(
+                "currently AscendScheduler doesn't support scheduler_delay_factor."
+            )
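
Note on the scheduler-config hunk above: initialize_from_config builds the Ascend config by dumping the stock SchedulerConfig with asdict(), forcing the Ascend defaults, and then letting keys from additional_config.ascend_scheduler_config override both. A minimal standalone sketch of that merge order follows; BaseConfig and the max_num_seqs key are illustrative assumptions, not code from this commit.

# Sketch only: the same dump-then-override pattern as initialize_from_config,
# reduced to a local dataclass so it runs without vLLM. Later update() calls win.
from dataclasses import asdict, dataclass


@dataclass
class BaseConfig:  # stand-in for vLLM's SchedulerConfig (fields are assumed)
    enable_chunked_prefill: bool = True
    policy: str = "priority"
    max_num_seqs: int = 256


merged = asdict(BaseConfig())  # stock values as a plain dict
merged.update({"enable_chunked_prefill": False, "policy": "fcfs"})  # Ascend defaults
merged.update({"max_num_seqs": 128})  # user overrides from additional_config
print(merged)  # {'enable_chunked_prefill': False, 'policy': 'fcfs', 'max_num_seqs': 128}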

0 commit comments
