diff --git a/tests/singlecard/test_scheduler.py b/tests/singlecard/test_scheduler.py
index e22be2e540..d1c6062783 100644
--- a/tests/singlecard/test_scheduler.py
+++ b/tests/singlecard/test_scheduler.py
@@ -25,7 +25,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheGroupSpec)
+                                        KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
@@ -88,14 +88,26 @@ def create_scheduler(
                              model_config=model_config,
                              cache_config=cache_config)
 
-    kv_cache_config = KVCacheConfig(
-        num_blocks=10000,  # A large number of blocks to hold all requests
-        tensors={},
-        kv_cache_groups=[
-            KVCacheGroupSpec(['layer'],
-                             FullAttentionSpec(16, 1, 1, torch.float32, False))
-        ],
-    )
+    if vllm_version_is("0.9.0"):
+        kv_cache_config = KVCacheConfig(
+            num_blocks=10000,  # A large number of blocks to hold all requests
+            tensors={},
+            kv_cache_groups=[
+                KVCacheGroupSpec(['layer'],
+                                 FullAttentionSpec(16, 1, 1, torch.float32,
+                                                   False))
+            ],
+        )
+    else:
+        kv_cache_config = KVCacheConfig(
+            num_blocks=10000,  # A large number of blocks to hold all requests
+            kv_cache_tensors=[KVCacheTensor(size=1024, shared_by=[1])],
+            kv_cache_groups=[
+                KVCacheGroupSpec(['layer'],
+                                 FullAttentionSpec(16, 1, 1, torch.float32,
+                                                   False, None))
+            ],
+        )
     cache_config.num_gpu_blocks = 10000
     return AscendScheduler(
         vllm_config,
diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index 0135d953f8..42f5d9c69a 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -29,6 +29,8 @@
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
 
+from vllm_ascend.utils import vllm_version_is
+
 
 class AscendScheduler(Scheduler):
     """This Scheduler extends vllm's original v1 scheduler
@@ -127,10 +129,15 @@ def skip_cur_request():
                 continue
 
             assert num_new_tokens > 0
+
+            if vllm_version_is("0.9.0"):
+                blocks = computed_blocks.blocks
+            else:
+                blocks = computed_blocks.blocks[0]
+
             watermark = getattr(self.scheduler_config, "watermark", 0.01)
             if not self._check_watermark_for_prefill(request, num_new_tokens,
-                                                     computed_blocks.blocks,
-                                                     watermark):
+                                                     blocks, watermark):
                 # Scheduling would exceed watermark, skip.
                 skip_cur_request()
                 continue
@@ -323,8 +330,14 @@ def _check_watermark_for_prefill(self,
                                len(computed_blocks) * self.block_size)
         num_required_blocks = cdiv(num_new_tokens + num_computed_tokens,
                                    self.block_size)
-        req_blocks = self.kv_cache_manager.single_type_manager.req_to_blocks[
-            request.request_id]
+
+        if vllm_version_is("0.9.0"):
+            req_blocks = self.kv_cache_manager.single_type_manager.req_to_blocks[
+                request.request_id]
+        else:
+            req_blocks = self.kv_cache_manager.coordinator.get_blocks(
+                request.request_id)
+
         num_new_blocks = (num_required_blocks - len(req_blocks) -
                           len(computed_blocks))
         num_evictable_computed_blocks = sum(1 for blk in computed_blocks
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 647176c4ee..07ea679312 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1321,12 +1321,25 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             block_sizes=[self.cache_config.block_size],
         )
 
+        if not vllm_version_is("0.9.0"):
+            kv_cache_sizes = {}
+            for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+                assert len(kv_cache_tensor.shared_by) == 1, (
+                    "KV cache tensor shared by multiple layers is not supported in "
+                    "NPU.")
+                kv_cache_sizes[
+                    kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
+
         for kv_cache_group in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group.kv_cache_spec
             for layer_name in kv_cache_group.layer_names:
-                tensor_config = kv_cache_config.tensors[layer_name]
-                assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
-                num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
+                if vllm_version_is("0.9.0"):
+                    tensor_size = kv_cache_config.tensors[layer_name].size
+                else:
+                    tensor_size = kv_cache_sizes[layer_name]
+                assert tensor_size % kv_cache_spec.page_size_bytes == 0
+                num_blocks = tensor_size // kv_cache_spec.page_size_bytes
+
                 # `num_blocks` is the number of blocks the model runner can use.
                 # `kv_cache_config.num_blocks` is the number of blocks that
                 # KVCacheManager may allocate.
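Note for reviewers: all three files gate on vllm_ascend.utils.vllm_version_is, which
this diff only imports and calls. A minimal sketch of the expected behavior follows
(an illustrative assumption, not the actual vllm_ascend implementation):

    # Hypothetical sketch: compare the installed vllm release against a target string.
    from vllm import __version__ as vllm_version

    def vllm_version_is(target: str) -> bool:
        # True only for an exact match, e.g. vllm_version_is("0.9.0").
        return vllm_version == target

With such a gate, the "0.9.0" branches keep the old KVCacheConfig.tensors and
single_type_manager code paths, while newer vllm releases go through
kv_cache_tensors and kv_cache_manager.coordinator.get_blocks as shown above.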