Skip to content

Commit 0f2bdde

Browse files
committed
fix: ascend_scheduler adapt to v0.9.0
Signed-off-by: zzzzwwjj <1183291235@qq.com>
1 parent df58fb8 commit 0f2bdde

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

vllm_ascend/core/scheduler.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,15 @@ def skip_cur_request():
130130

131131
assert num_new_tokens > 0
132132
watermark = getattr(self.scheduler_config, "watermark", 0.01)
133-
if not self._check_watermark_for_prefill(
134-
request, num_new_tokens, computed_blocks, watermark):
133+
if not self._check_watermark_for_prefill(request, num_new_tokens,
134+
computed_blocks.blocks,
135+
watermark):
135136
# Scheduling would exceed watermark, skip.
136137
skip_cur_request()
137138
continue
138139

139140
new_blocks = self.kv_cache_manager.allocate_slots(
140-
request, num_new_tokens, computed_blocks)
141+
request, num_new_tokens, new_computed_blocks=computed_blocks)
141142
if new_blocks is None:
142143
# The request cannot be scheduled.
143144
break
@@ -155,9 +156,8 @@ def skip_cur_request():
155156

156157
if self.lora_config and request.lora_request:
157158
scheduled_loras.add(request.lora_request.lora_int_id)
158-
req_to_new_block_ids[request.request_id] = [
159-
b.block_id for b in computed_blocks + new_blocks
160-
]
159+
req_to_new_block_ids[request.request_id] = (
160+
self.kv_cache_manager.get_block_ids(request.request_id))
161161
# Update request info.
162162
num_scheduled_tokens[request.request_id] = num_new_tokens
163163
token_budget -= num_new_tokens
@@ -215,9 +215,8 @@ def skip_cur_request():
215215
# Schedule the request.
216216
scheduled_running_reqs.append(request)
217217
self.scheduled_req_ids.add(request.request_id)
218-
req_to_new_block_ids[request.request_id] = [
219-
b.block_id for b in new_blocks
220-
]
218+
req_to_new_block_ids[request.request_id] = (
219+
new_blocks.get_block_ids())
221220
num_scheduled_tokens[request.request_id] = num_new_tokens
222221
token_budget -= num_new_tokens
223222
req_index += 1
@@ -326,7 +325,8 @@ def _check_watermark_for_prefill(self,
326325
len(computed_blocks) * self.block_size)
327326
num_required_blocks = cdiv(num_new_tokens + num_computed_tokens,
328327
self.block_size)
329-
req_blocks = self.kv_cache_manager.req_to_blocks[request.request_id]
328+
req_blocks = self.kv_cache_manager.single_type_manager.req_to_blocks[
329+
request.request_id]
330330
num_new_blocks = (num_required_blocks - len(req_blocks) -
331331
len(computed_blocks))
332332
num_evictable_computed_blocks = sum(1 for blk in computed_blocks

0 commit comments

Comments
 (0)