diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py new file mode 100644 index 0000000000..a68decfe91 --- /dev/null +++ b/tests/test_scheduler.py @@ -0,0 +1,335 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/blob/main/tests/models/utils.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import List, Optional + +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.sampling_params import SamplingParams +from vllm.v1.core.scheduler import SchedulerOutput +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus + +from vllm_ascend.core.scheduler import AscendScheduler + +EOS_TOKEN_ID = 50256 + + +def create_scheduler( + model: str = "/data/weights/Qwen2.5-72B-Instruct", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 8192, +) -> AscendScheduler: + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_num_batched_tokens, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + cache_config = CacheConfig( + block_size=16, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + ) + cache_config.num_gpu_blocks = 10000 + return AscendScheduler(scheduler_config, + model_config, + cache_config, + speculative_config=None, + lora_config=None, + log_stats=True) + + +def create_requests( + num_requests: int, + num_tokens: int = 10, + mm_positions: Optional[List[PlaceholderRange]] = None, + max_tokens: int = 16, + stop_token_ids: Optional[List[int]] = None, +): + sampling_params = SamplingParams(ignore_eos=False, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids) + requests = [] + for i in range(num_requests): + if mm_positions is not None: + mm_position = mm_positions[i] + mm_inputs = [MultiModalKwargs({})] * len(mm_position) + else: + mm_position = None + mm_inputs = None + request = Request( + request_id=f"{i}", + prompt=None, + prompt_token_ids=[i] * num_tokens, + sampling_params=sampling_params, + multi_modal_inputs=mm_inputs, + multi_modal_placeholders=mm_position, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + arrival_time=0, + ) + requests.append(request) + return requests + + +def test_add_requests(): + scheduler = create_scheduler() + requests = create_requests(num_requests=10) + + for i, request in enumerate(requests): + scheduler.add_request(request) + assert request.request_id in scheduler.requests + assert len(scheduler.waiting) == i + 1 + + +def test_finish_request(): + scheduler = create_scheduler() + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + for i, request in enumerate(requests): + 
scheduler.finish_requests(request.request_id, + RequestStatus.FINISHED_ABORTED) + assert request.request_id not in scheduler.requests + assert len(scheduler.waiting) == 9 - i + + +def test_get_num_unfinished_requests(): + scheduler = create_scheduler() + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + for i, request in enumerate(requests): + scheduler.finish_requests(request.request_id, + RequestStatus.FINISHED_STOPPED) + assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1 + + +def test_schedule(): + scheduler = create_scheduler() + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + # Test initial scheduling + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == len(requests) + assert len(output.scheduled_cached_reqs) == 0 + assert len(output.finished_req_ids) == 0 + # Verify all requests are scheduled. + for req_id, num_tokens in output.num_scheduled_tokens.items(): + assert num_tokens == len(requests[int(req_id)].prompt_token_ids) + + # Verify requests moved from waiting to running + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == len(requests) + for i, request in enumerate(requests): + assert scheduler.running[i] == request + + +def test_stop_via_update_from_output(): + """Test stopping behavior through update_from_output""" + scheduler = create_scheduler() + + # Test case 1: Stop on EOS token + requests = create_requests(num_requests=2, max_tokens=10) + for req in requests: + req.num_computed_tokens = req.num_tokens + scheduler.requests[req.request_id] = req + scheduler.running.append(req) + scheduler.scheduled_req_ids.add(req.request_id) + + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 1, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=3, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [], + requests[1].request_id: [10] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[]) + + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={req.request_id: i + for i, req in enumerate(requests)}, + sampled_token_ids=[[EOS_TOKEN_ID], + [10, + 11]], # First request hits EOS, second continues + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify first request stopped, second continues + assert len(scheduler.running) == 1 + assert scheduler.running[0].request_id == requests[1].request_id + assert requests[0].status == RequestStatus.FINISHED_STOPPED + assert requests[0].request_id in scheduler.finished_req_ids + assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID] + assert list(requests[1].output_token_ids) == [10, 11] + + # Test case 2: Stop on custom stop token + scheduler = create_scheduler() + requests = create_requests(num_requests=2, + max_tokens=10, + stop_token_ids=[42, 43]) + for req in requests: + req.num_computed_tokens = req.num_tokens + scheduler.requests[req.request_id] = req + scheduler.running.append(req) + scheduler.scheduled_req_ids.add(req.request_id) + + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=5, + 
scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [10, 42], + requests[1].request_id: [13] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[]) + + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={req.request_id: i + for i, req in enumerate(requests)}, + sampled_token_ids=[[10, 42, 12], + [13, 14]], # First request hits stop token + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify first request stopped on custom token + assert len(scheduler.running) == 1 + assert scheduler.running[0].request_id == requests[1].request_id + assert requests[0].status == RequestStatus.FINISHED_STOPPED + assert requests[0].stop_reason == 42 + assert requests[0].request_id in scheduler.finished_req_ids + assert list(requests[0].output_token_ids) == [10, 42] + assert list(requests[1].output_token_ids) == [13, 14] + + # Test case 3: Stop on max tokens + scheduler = create_scheduler() + requests = create_requests(num_requests=2, max_tokens=2) + for req in requests: + req.num_computed_tokens = req.num_tokens + scheduler.requests[req.request_id] = req + scheduler.running.append(req) + scheduler.scheduled_req_ids.add(req.request_id) + + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 1 + }, + total_num_scheduled_tokens=4, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [10, 11], + requests[1].request_id: [] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[]) + + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={req.request_id: i + for i, req in enumerate(requests)}, + sampled_token_ids=[[10, 11, 12], + [13]], # First request exceeds max_tokens + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify first request stopped due to length + assert len(scheduler.running) == 1 + assert scheduler.running[0].request_id == requests[1].request_id + assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED + assert requests[0].request_id in scheduler.finished_req_ids + assert list(requests[0].output_token_ids) == [10, 11 + ] # Truncated to max_tokens + assert list(requests[1].output_token_ids) == [13] + + # Test case 4: Ignore EOS flag + scheduler = create_scheduler() + requests = create_requests(num_requests=1, max_tokens=10) + requests[0].sampling_params.ignore_eos = True + requests[0].num_computed_tokens = requests[0].num_tokens + scheduler.requests[requests[0].request_id] = requests[0] + scheduler.running.append(requests[0]) + scheduler.scheduled_req_ids.add(requests[0].request_id) + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={requests[0].request_id: 3}, + total_num_scheduled_tokens=3, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [EOS_TOKEN_ID, 10] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[]) + + model_output = ModelRunnerOutput( + req_ids=[requests[0].request_id], + req_id_to_index={requests[0].request_id: 0}, + sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], + spec_token_ids=None, + logprobs=None, + 
prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify request continues past EOS + assert len(scheduler.running) == 1 + assert not requests[0].is_finished() + assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11] diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py index 29fc1dcdf0..588b43ca65 100644 --- a/vllm_ascend/attention/attention.py +++ b/vllm_ascend/attention/attention.py @@ -43,7 +43,7 @@ ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata) -def generate_attn_mask(max_seq_len: int, dtype=torch.float16): +def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None): # Construct lower triangle matrix. mask_flag = torch.tril( torch.ones((max_seq_len, max_seq_len), @@ -52,10 +52,11 @@ def generate_attn_mask(max_seq_len: int, dtype=torch.float16): mask_flag = ~mask_flag # Currently for fp16 dtype, the mask value should be set to -inf. # TODO: Eliminate this part in the future. - if dtype == torch.float16: - mask_value = torch.finfo(torch.float32).min - else: - mask_value = 1 + if mask_value is None: + if dtype == torch.float16: + mask_value = torch.finfo(torch.float32).min + else: + mask_value = 1 attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value).to(dtype) return attn_mask @@ -66,12 +67,14 @@ class AttentionMaskBuilder: def __init__(self, attn_mask: torch.Tensor): self._seq_len_cached = attn_mask.shape[0] self.attn_mask_cache = attn_mask + self.splitfuse_mask_value = -10000 @classmethod def initialize_from_len(cls, max_seq_len: int, - dtype: torch.dtype = torch.float16): - return cls(generate_attn_mask(max_seq_len, dtype)) + dtype: torch.dtype = torch.float16, + mask_value: Optional[int] = None): + return cls(generate_attn_mask(max_seq_len, dtype, mask_value)) def update_attn_cache(self, seqlen: int, dtype: torch.dtype, device: torch.device): @@ -97,6 +100,49 @@ def get_decode_attn_mask( return (self.attn_mask_cache.index_select( 0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous()) + def get_splitfuse_attn_mask( + self, + seq_lens, + query_lens, + position, + dtype, + device, + ) -> torch.Tensor: + max_seq_len = max(seq_lens, default=0) + if max_seq_len <= self._seq_len_cached: + self.update_attn_cache(max_seq_len, dtype, device) + # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation + # is not the same. Fix this in the future when kernel is ready. 
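+            # generate_attn_mask() fills masked positions with float32-min for
+            # fp16 and with 1 for other dtypes, while the split-fuse kernel
+            # expects a large negative bias, so a positive cached mask value
+            # is rescaled to -10000 below before being gathered by position.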
+ if self.attn_mask_cache[0][1] > 0: + attn_mask = self.get_attn_mask( # type: ignore + max_seq_len, dtype, device) + attn_mask *= -10000 + else: + attn_mask = self.attn_mask_cache + return torch.index_select(attn_mask, dim=0, + index=position)[:, :max_seq_len] + total_q_len = sum(query_lens) + attn_mask = torch.zeros((total_q_len, max_seq_len), + dtype=dtype, + device="cpu") + + current_row = 0 + for i in range(len(query_lens)): + seq_len = seq_lens[i] + q_len = query_lens[i] + context_len = seq_len - q_len + + assert context_len >= 0 + attn_mask[current_row:current_row + q_len, + context_len:] = self.splitfuse_mask_value + right_tensor = attn_mask[current_row:current_row + q_len, + context_len:seq_len] + right_tensor.mask_fill_( + right_tensor.tril() == self.splitfuse_mask_value, 0) + current_row += q_len + + return attn_mask.to(device, non_blocking=True) + class AscendAttentionBackend(AttentionBackend): diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 1416e0cd8c..1708ba0161 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -16,6 +16,7 @@ # from dataclasses import dataclass +from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Type import torch @@ -50,7 +51,7 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - return (2, num_blocks, block_size, num_kv_heads * head_size) + return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -83,6 +84,12 @@ def copy_blocks( value_caches[dst_indices] = value_caches[src_indices] +class AscendAttentionState(Enum): + PrefillOnly = 0 + DecodeOnly = 1 + ChunkedPrefill = 2 + + @dataclass class AscendMetadata: # (batch_size, max_blocks_per_seq). @@ -104,6 +111,8 @@ class AscendMetadata: # FlashAttention has better performance than PageAtttention, # but it does not support decode requests. is_only_prefill: bool = False + # Current state of this attention run. + attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill attn_mask: Optional[torch.Tensor] = None @@ -139,7 +148,8 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.seq_len_cpu_tensor = None + self.key_cache = None + self.value_cache = None def forward( self, @@ -190,30 +200,52 @@ def forward( # TODO: Remove this contiguous in the future. value = value.contiguous() + if kv_cache.numel() > 0: + if self.key_cache is None: + self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] + slots = attn_metadata.slot_mapping + torch_npu._npu_reshape_and_cache(key=key, + value=value, + key_cache=self.key_cache, + value_cache=self.value_cache, + slot_indices=slots) + if hasattr(layer, 'quant_method'): # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata pass + # V0-Style scheduler situation. 
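+        # With the prefill-first AscendScheduler each batch is either pure
+        # prefill or pure decode, so the dedicated flash-attention and paged
+        # attention kernels below can be used instead of the split-fuse path.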
+ elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly: + assert attn_metadata is not None + assert attn_metadata.attn_mask is not None + mask = attn_metadata.attn_mask + torch_npu._npu_flash_attention(query=query, + key=key, + value=value, + mask=mask, + seq_len=attn_metadata.seq_lens, + scale_value=self.scale, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + out=output) + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + block_tables = attn_metadata.block_tables + torch_npu._npu_paged_attention( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=block_tables, + context_lens=attn_metadata.context_lens, + out=output) + # Normal V1 situation. else: - if kv_cache.numel() > 0: - key_cache, value_cache = kv_cache[0], kv_cache[1] - num_blocks, block_size, _ = key_cache.shape - key_cache = key_cache.view(num_blocks, block_size, - self.num_kv_heads, self.head_size) - value_cache = value_cache.view(num_blocks, block_size, - self.num_kv_heads, - self.head_size) - slots = attn_metadata.slot_mapping - torch_npu._npu_reshape_and_cache(key=key, - value=value, - key_cache=key_cache, - value_cache=value_cache, - slot_indices=slots) - # use paged attention torch_npu._npu_paged_attention_splitfuse( query=query, - key_cache=key_cache, - value_cache=value_cache, + key_cache=self.key_cache, + value_cache=self.value_cache, mask=attn_metadata.attn_mask, block_table=attn_metadata.block_tables, seq_len=attn_metadata.seq_lens, diff --git a/vllm_ascend/core/__init__.py b/vllm_ascend/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py new file mode 100644 index 0000000000..f716aff061 --- /dev/null +++ b/vllm_ascend/core/schedule_config.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+from dataclasses import asdict, dataclass
+from typing import Type, Union
+
+from vllm.config import SchedulerConfig
+
+
+@dataclass
+class AscendSchedulerConfig(SchedulerConfig):
+    enable_chunked_prefill: bool = False
+    policy: str = "fcfs"
+    num_scheduler_steps: int = 1
+    scheduler_cls: Union[
+        str, Type[object]] = "vllm_ascend.core.scheduler.AscendScheduler"
+
+    @classmethod
+    def initialize_from_config(cls, vllm_scheduler_config: SchedulerConfig,
+                               ascend_scheduler_config: dict):
+        scheduler_config = asdict(vllm_scheduler_config)
+        # Override default values of the original SchedulerConfig.
+        scheduler_config["enable_chunked_prefill"] = False
+        scheduler_config["policy"] = "fcfs"
+        scheduler_config["num_scheduler_steps"] = 1
+        scheduler_config[
+            "scheduler_cls"] = "vllm_ascend.core.scheduler.AscendScheduler"
+        # Override params in original SchedulerConfig with params in additional_config.ascend_scheduler_config
+        for k, v in ascend_scheduler_config.items():
+            scheduler_config[k] = v
+        # The "chunked_prefill_enabled" param of vllm's SchedulerConfig can't be initialized.
+        scheduler_config.pop("chunked_prefill_enabled")
+        return cls(**scheduler_config)
+
+    def __post_init__(self) -> None:
+        self.chunked_prefill_enabled = self.enable_chunked_prefill
+        if self.policy != "fcfs":
+            raise NotImplementedError(
+                f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
+            )
+        if self.is_multimodal_model:
+            raise NotImplementedError(
+                "currently AscendScheduler only supports LLM models.")
+        if self.num_scheduler_steps > 1:
+            raise NotImplementedError(
+                "currently AscendScheduler doesn't support multi-step.")
+        if self.send_delta_data:
+            raise NotImplementedError(
+                "currently AscendScheduler doesn't support send_delta_data.")
+        if self.delay_factor > 0:
+            raise NotImplementedError(
+                "currently AscendScheduler doesn't support scheduler_delay_factor."
+            )
diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
new file mode 100644
index 0000000000..5f138fac5c
--- /dev/null
+++ b/vllm_ascend/core/scheduler.py
@@ -0,0 +1,419 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import deque
+
+from vllm.logger import init_logger
+from vllm.utils import cdiv
+from vllm.v1.core.scheduler import Scheduler
+from vllm.v1.core.scheduler_output import NewRequestData, SchedulerOutput
+from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.request import Request, RequestStatus
+
+logger = init_logger(__name__)
+
+
+class AscendScheduler(Scheduler):
+    """ This Scheduler extends vllm's original v1 scheduler
+    with a prefill-first scheduling strategy.
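+
+    When chunked prefill is disabled, waiting (prefill) requests are
+    scheduled ahead of running (decode) requests, and decode requests are
+    only scheduled in steps where no prefill request could be scheduled.
+    With chunked prefill enabled, scheduling is delegated to the parent
+    vllm v1 Scheduler.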
""" + + def schedule(self) -> SchedulerOutput: + if self.scheduler_config.chunked_prefill_enabled: + return super().schedule() + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] + preempted_reqs: list[Request] = [] + + req_to_new_block_ids: dict[str, list[int]] = {} + num_scheduled_tokens: dict[str, int] = {} + token_budget = self.max_num_scheduled_tokens + # Spec decode-related. + scheduled_spec_decode_tokens: dict[str, list[int]] = {} + + # Record scheduled LoRA requests. + scheduled_loras: set[int] = set() + + # Use a temporary deque to collect requests that need to be skipped + # and put back at the head of the waiting queue later + skipped_waiting_requests: deque[Request] = deque() + + # Schedule prefill requests first. + while self.waiting and token_budget > 0: + if len(scheduled_new_reqs) == self.max_num_running_reqs: + break + + request = self.waiting[0] + + def skip_cur_request(): + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + + # Check that adding the request still respects the max_loras + # constraint. + if self.lora_config and request.lora_request and ( + len(scheduled_loras) == self.lora_config.max_loras and + request.lora_request.lora_int_id not in scheduled_loras): + # Scheduling would exceed max_loras, skip. + skip_cur_request() + continue + + prompt_limit = self._get_prompt_limit(request) + # Get already-cached tokens. + computed_blocks, num_computed_tokens = self.kv_cache_manager.get_computed_blocks( + request) + num_new_tokens = request.num_prompt_tokens - num_computed_tokens + if (0 < self.scheduler_config.long_prefill_token_threshold < + num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + max_tokens_in_kvcache = self.cache_config.num_gpu_blocks * self.block_size + prompt_limit = min(prompt_limit, max_tokens_in_kvcache) + + # Finish request that exceeds prompt_limit or kv cache size. + if num_new_tokens > prompt_limit: + logger.warning( + "Input prompt (%d tokens) is too long" + " and exceeds limit of %d", + num_new_tokens, + prompt_limit, + ) + request.status = RequestStatus.FINISHED_IGNORED + self.finished_req_ids.add(request.request_id) # type: ignore + self.waiting.popleft() + continue + + if num_new_tokens > token_budget: + # Scheduling would exceed token_budget, skip. + skip_cur_request() + continue + + if not self._check_watermark_for_prefill(num_new_tokens): + # Scheduling would exceed watermark, skip. + skip_cur_request() + continue + + assert num_new_tokens > 0 + new_blocks = self.kv_cache_manager.allocate_slots( + request, num_new_tokens, computed_blocks) + if new_blocks is None: + # The request cannot be scheduled. + break + + self.waiting.popleft() + self.running.append(request) # type: ignore + self.scheduled_req_ids.add(request.request_id) + # Check request status. + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError(f"Invalid request status: {request.status}") + + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) + req_to_new_block_ids[request.request_id] = [ + b.block_id for b in computed_blocks + new_blocks + ] + # Update request info. 
+ num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: + self.waiting.extendleft(skipped_waiting_requests) + + # If no prefill requests are scheduled, + # Schedule decode requests next. + if len(self.scheduled_req_ids) == 0: + req_index = 0 + while req_index < len( + self.running) and token_budget > 0: # type: ignore + request = self.running[req_index] # type: ignore + if request.request_id in self.scheduled_req_ids: + # This request has already been scheduled. + req_index += 1 + continue + + num_new_tokens = (request.num_tokens_with_spec - + request.num_computed_tokens) + if (0 < self.scheduler_config.long_prefill_token_threshold < + num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens == 1 + + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, num_new_tokens) + if new_blocks is None: + # The request cannot be scheduled. + # Preempt the lowest-priority request. + preempted_req = self.running.pop() # type: ignore + self.kv_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + self.waiting.appendleft(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. + can_schedule = False + break + else: + # The request can be scheduled. + can_schedule = True + break + if not can_schedule: + break + assert new_blocks is not None + + # Schedule the request. + scheduled_running_reqs.append(request) + self.scheduled_req_ids.add(request.request_id) + req_to_new_block_ids[request.request_id] = [ + b.block_id for b in new_blocks + ] + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + + # Speculative decode related. + if request.spec_token_ids: + num_scheduled_spec_tokens = (num_new_tokens + + request.num_computed_tokens - + request.num_tokens) + if num_scheduled_spec_tokens > 0: + # Trim spec_token_ids list to num_scheduled_spec_tokens. + del request.spec_token_ids[num_scheduled_spec_tokens:] + scheduled_spec_decode_tokens[request.request_id] = ( + request.spec_token_ids) + + # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 + assert len(self.running) <= self.max_num_running_reqs # type: ignore + assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + + len(scheduled_running_reqs) <= len( + self.running) # type: ignore + ) + + # Get the longest common prefix among all requests in the running queue. + # This can be potentially used for cascade attention. + num_common_prefix_blocks = 0 + if self.running: # type: ignore + any_request = self.running[0] # type: ignore + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request, len(self.running))) # type: ignore + + # Construct the scheduler output. 
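+        # Newly scheduled requests carry their full request data, while
+        # resumed and already-running requests are sent as cached request
+        # data that only references the newly allocated block ids.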
+ new_reqs_data = [ + NewRequestData.from_request(req, + req_to_new_block_ids[req.request_id]) + for req in scheduled_new_reqs + ] + resumed_reqs_data = [ + self._make_cached_request_data( + req, + num_scheduled_tokens[req.request_id], + len(scheduled_spec_decode_tokens.get(req.request_id, ())), + req_to_new_block_ids[req.request_id], + resumed_from_preemption=True, + ) for req in scheduled_resumed_reqs + ] + running_reqs_data = [ + self._make_cached_request_data( + req, + num_scheduled_tokens[req.request_id], + len(scheduled_spec_decode_tokens.get(req.request_id, ())), + req_to_new_block_ids[req.request_id], + resumed_from_preemption=False, + ) for req in scheduled_running_reqs + ] + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=resumed_reqs_data + running_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=num_common_prefix_blocks, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, # type: ignore + free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + ) + + # Advance the number of computed tokens for the request AFTER + # the request is scheduled. + # 1. The scheduler_output of the current step has to include the + # original number of scheduled tokens to determine input IDs. + # 2. Advance the number of computed tokens here allowing us to + # schedule the prefill request again immediately in the next + # scheduling step. + # 3. If some tokens (e.g. spec tokens) are rejected later, the number of + # computed tokens will be adjusted in update_from_output. + for req_id, num_scheduled_token in num_scheduled_tokens.items(): + self.requests[req_id].num_computed_tokens += num_scheduled_token + + self.finished_req_ids = set() # type: ignore + return scheduler_output + + def _check_watermark_for_prefill(self, num_new_tokens, watermark=0.01): + watermark_blocks = self.cache_config.num_gpu_blocks * watermark + num_required_blocks = cdiv(num_new_tokens, self.block_size) + if (self.kv_cache_manager.free_block_queue.num_free_blocks - + num_required_blocks) < watermark_blocks: + return False + return True + + def _get_prompt_limit(self, request: Request) -> int: + if (self.scheduler_config.chunked_prefill_enabled + and not self.scheduler_config.is_multi_step): + prompt_limit = self.scheduler_config.max_model_len + else: + prompt_limit = min( + self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens, + ) + + # Model is fine tuned with long context. Return the fine tuned max_len. 
+ if request.lora_request and request.lora_request.long_lora_max_len: + assert prompt_limit <= request.lora_request.long_lora_max_len + return request.lora_request.long_lora_max_len + else: + return prompt_limit + + def update_from_output( + self, + scheduler_output: SchedulerOutput, + model_runner_output: ModelRunnerOutput, + ) -> EngineCoreOutputs: + if self.scheduler_config.chunked_prefill_enabled: + return super().update_from_output(scheduler_output, + model_runner_output) + sampled_token_ids = model_runner_output.sampled_token_ids + spec_token_ids = model_runner_output.spec_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict + spec_decoding_stats = None + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + + new_running: list[Request] = [] + outputs: list[EngineCoreOutput] = [] + + for request in self.running: # type: ignore + req_id = request.request_id + num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) + if num_tokens_scheduled == 0: + # The request was not scheduled in this step. + new_running.append(request) + continue + + req_index = model_runner_output.req_id_to_index[req_id] + generated_token_ids = sampled_token_ids[req_index] + + scheduled_spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + if scheduled_spec_token_ids: + num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - + len(generated_token_ids)) + request.num_computed_tokens -= num_tokens_rejected + + if spec_decoding_stats is not None: + spec_decoding_stats.observe( + num_draft_tokens=len(scheduled_spec_token_ids), + num_accepted_tokens=len(generated_token_ids) - 1) + + cached_encoder_input_ids = ( + self.encoder_cache_manager.get_cached_input_ids(request)) + # OPTIMIZATION: Avoid list(set) if the set is empty. + if cached_encoder_input_ids: + for input_id in list(cached_encoder_input_ids): + mm_positions = request.mm_positions[input_id] + start_pos = mm_positions["offset"] + num_tokens = mm_positions["length"] + if start_pos + num_tokens <= request.num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + self.encoder_cache_manager.free_encoder_input( + request, input_id) + + # Add newly generated spec token ids to the request. + if spec_token_ids is not None: + request.spec_token_ids = spec_token_ids[req_index] + + stopped = False + new_logprobs = None + new_token_ids = generated_token_ids + + # Append generated tokens and check for stop. Note that if + # a request is still being prefilled, we expect the model runner + # to return empty token ids for the request. + for num_new, output_token_id in enumerate(new_token_ids, 1): + request.append_output_token_ids(output_token_id) + + # Check for stop and update request state. + # This must be called before we make the EngineCoreOutput. + stopped = self._check_stop(request) + if stopped: + self._free_request(request) + del new_token_ids[num_new:] # Trim new tokens if needed. + break + + # Extract sample logprobs if needed. + if request.sampling_params.logprobs is not None and logprobs: + # NOTE: once we support N tokens per step (spec decode), + # the outer lists can be of length > 1. + new_logprobs = logprobs.slice(req_index, req_index + 1) + + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids: + # Add EngineCoreOutput for this Request. 
+ outputs.append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids, + finish_reason=request.get_finished_reason(), + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + stop_reason=request.stop_reason, + events=request.take_events())) + else: + # Invariant: EngineCore returns no partial prefill outputs. + assert not prompt_logprobs_tensors + + self.scheduled_req_ids.remove(req_id) + if not stopped: + new_running.append(request) + + self.running = new_running + engine_core_outputs = EngineCoreOutputs( + outputs=outputs, + scheduler_stats=self.make_stats(), + ) + + return engine_core_outputs diff --git a/vllm_ascend/core/v1_engine_core_init.py b/vllm_ascend/core/v1_engine_core_init.py new file mode 100644 index 0000000000..bd558c2f31 --- /dev/null +++ b/vllm_ascend/core/v1_engine_core_init.py @@ -0,0 +1,75 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import queue +from typing import Type + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.engine.mm_input_cache import MMInputCacheServer +from vllm.v1.executor.abstract import Executor +from vllm.version import __version__ as VLLM_VERSION + +from vllm_ascend.core.scheduler import AscendScheduler + +logger = init_logger(__name__) + + +def engine_core_init_with_ascend_scheduler( + self, + vllm_config: VllmConfig, + executor_class: Type[Executor], + log_stats: bool, +): + assert vllm_config.model_config.runner_type != "pooling" + + logger.info("Initializing a V1 LLM engine (v%s) with config: %s", + VLLM_VERSION, vllm_config) + + self.log_stats = log_stats + + # Setup Model. + self.model_executor = executor_class(vllm_config) + + # Setup KV Caches and update CacheConfig after profiling. + num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(vllm_config) + vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks + vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks + + # Setup scheduler. + self.scheduler = AscendScheduler( + scheduler_config=vllm_config.scheduler_config, + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + lora_config=vllm_config.lora_config, + speculative_config=vllm_config.speculative_config, + log_stats=self.log_stats, + ) + + # Setup MM Input Mapper. + self.mm_input_cache_server = MMInputCacheServer(vllm_config.model_config) + + # Setup batch queue for pipeline parallelism. + # Batch queue for scheduled batches. This enables us to asynchronously + # schedule and execute batches, and is required by pipeline parallelism + # to eliminate pipeline bubbles. 
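+    # The queue is only created when the executor can run more than one
+    # batch concurrently; otherwise batch_queue stays None.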
+ self.batch_queue_size = self.model_executor.max_concurrent_batches + self.batch_queue = None + if self.batch_queue_size > 1: + logger.info("Batch queue is enabled with size %d", + self.batch_queue_size) + self.batch_queue = queue.Queue(self.batch_queue_size) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index e2fe3d8998..c38059313f 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -144,6 +144,25 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: ) cache_config.enable_prefix_caching = False + if envs.VLLM_USE_V1: + # Activate custom ops for v1. + vllm_config.compilation_config.custom_ops = ["all"] + additional_config = vllm_config.additional_config + if additional_config and additional_config.get( + "ascend_scheduler_config", None) is not None: + additional_scheduler_config = additional_config.get( + "ascend_scheduler_config") + from vllm_ascend.core.schedule_config import \ + AscendSchedulerConfig + ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config( + vllm_config.scheduler_config, additional_scheduler_config) + vllm_config.scheduler_config = ascend_scheduler_config + from vllm.v1.engine.core import EngineCore + + from vllm_ascend.core.v1_engine_core_init import \ + engine_core_init_with_ascend_scheduler + EngineCore.__init__ = engine_core_init_with_ascend_scheduler + @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla): diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 68c9032a40..1c721462b0 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -25,7 +25,7 @@ import torch import torch.distributed import torch.nn as nn -from vllm.attention import AttentionType +from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention from vllm.config import VllmConfig from vllm.distributed.parallel_state import get_pp_group @@ -49,14 +49,13 @@ from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch -from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend, +from vllm_ascend.attention.attention import AttentionMaskBuilder +from vllm_ascend.attention.attention_v1 import (AscendAttentionState, AscendMetadata) if TYPE_CHECKING: from vllm.v1.core.scheduler_output import SchedulerOutput -NPU_PAGED_ATTENTION_MASK_VALUE = -10000 - logger = init_logger(__name__) @@ -107,6 +106,24 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() + self.attn_backend = get_attn_backend( + self.head_size, + self.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + use_mla=self.model_config.use_mla, + ) + if self.attn_backend is None: + error_msg = ( + f"Error with get_att_backend: {self.head_size=}, " + f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, " + f"{self.model_config.is_attention_free=}, " + f"{self.model_config.use_mla=}") + logger.error(error_msg) + raise NotImplementedError( + "Non-Attention backend is not supported by V1 NPUModelRunner.") + # Multi-modal data support self.input_registry = INPUT_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY @@ -240,13 +257,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): # the size of the pre-constructed mask matrix based on requirements. 
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000) self.attn_mask_len = min(self.max_model_len, int(mask_len)) - self.attn_mask_npu = torch.full( - (self.attn_mask_len, self.attn_mask_len), - NPU_PAGED_ATTENTION_MASK_VALUE, - device=self.device, - dtype=self.vllm_config.model_config.dtype) - self.attn_mask_npu.masked_fill_( - self.attn_mask_npu.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0) + self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len( + self.attn_mask_len, self.dtype) def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler @@ -403,35 +415,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def get_model(self) -> nn.Module: return self.model - def make_attention_mask(self, seq_lens, query_lens, - position) -> torch.Tensor: - max_seq_len = max(seq_lens, default=0) - if max_seq_len <= self.attn_mask_len: - return torch.index_select(self.attn_mask_npu, - dim=0, - index=position)[:, :max_seq_len] - - total_q_len = sum(query_lens) - attn_mask = torch.zeros((total_q_len, max_seq_len), - dtype=self.vllm_config.model_config.dtype, - device="cpu") - - current_row = 0 - for i in range(len(query_lens)): - seq_len = seq_lens[i] - q_len = query_lens[i] - context_len = seq_len - q_len - - assert context_len >= 0 - attn_mask[current_row:current_row + q_len, - context_len:] = NPU_PAGED_ATTENTION_MASK_VALUE - right_tensor = attn_mask[current_row:current_row + q_len, - context_len:seq_len] - right_tensor.mask_fill_( - right_tensor.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0) - current_row += q_len - - return attn_mask.to(self.device, non_blocking=True) + def make_attention_mask(self, seq_lens, query_lens, position, + attn_state) -> torch.Tensor: + # Chunk Prefill situation. + if attn_state == AscendAttentionState.ChunkedPrefill: + return self.attn_mask_builder.get_splitfuse_attn_mask( + seq_lens, query_lens, position, self.dtype, self.device) + # Prefill-only situation. + elif attn_state == AscendAttentionState.PrefillOnly: + max_seq_len = max(seq_lens, default=0) + return self.attn_mask_builder.get_attn_mask( + max_seq_len, self.dtype, self.device) + # Decode-only situation. 
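+        # No mask is built for decode: _npu_paged_attention is invoked
+        # without a mask and attends over the cached context given by
+        # context_lens.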
+ else: + return None def _process_reqs( self, @@ -465,6 +462,9 @@ def _process_reqs( cu_num_tokens = np.cumsum(num_scheduled_tokens) cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, num_scheduled_tokens) + sample_indices = cu_num_tokens - 1 + sample_indices = torch.from_numpy(sample_indices).to(self.device, + non_blocking=True) arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets positions_np = self.positions_np[:total_num_scheduled_tokens] @@ -494,9 +494,18 @@ def _process_reqs( slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to( self.device, non_blocking=True) + attn_state = AscendAttentionState.ChunkedPrefill + if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): + attn_state = AscendAttentionState.PrefillOnly + elif np.all(num_scheduled_tokens == 1): + attn_state = AscendAttentionState.DecodeOnly + else: + attn_state = AscendAttentionState.ChunkedPrefill + attn_mask = self.make_attention_mask(seq_lens=seq_lens, query_lens=num_scheduled_tokens, - position=positions) + position=positions, + attn_state=attn_state) attn_metadata = AscendMetadata( seq_lens=query_lens, @@ -505,6 +514,7 @@ def _process_reqs( block_tables=( self.input_batch.block_table.get_device_tensor()[:num_reqs]), attn_mask=attn_mask, + attn_state=attn_state, ) # Prepare input_ids @@ -531,7 +541,7 @@ def _process_reqs( attn_metadata=attn_metadata, ) - return hidden_states[cu_num_tokens - 1] + return hidden_states[sample_indices] @torch.inference_mode() def execute_model( @@ -809,6 +819,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ + import torch_npu if len(kv_cache_config.groups) > 1: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " @@ -821,13 +832,14 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: assert tensor_config.size % layer_spec.page_size_bytes == 0 num_blocks = tensor_config.size // layer_spec.page_size_bytes if isinstance(layer_spec, FullAttentionSpec): - kv_cache_shape = AscendAttentionBackend.get_kv_cache_shape( + kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, layer_spec.head_size) dtype = layer_spec.dtype kv_caches[layer_name] = torch.zeros(kv_cache_shape, dtype=dtype, device=self.device) + torch_npu.npu_format_cast(kv_caches[layer_name], 2) else: raise NotImplementedError diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index ba1d984880..5d3fb145f4 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -40,6 +40,7 @@ from vllm.v1.worker.worker_base import WorkerBase from vllm_ascend.device_allocator.camem import CaMemAllocator +from vllm_ascend.utils import try_register_lib from vllm_ascend.worker.model_runner_v1 import NPUModelRunner logger = init_logger(__name__) @@ -74,6 +75,12 @@ def __init__(self, self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config + # Try to import mindie_turbo to accelerate vLLM inference. + try_register_lib( + "mindie_turbo", + "MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo." + ) + if self.cache_config.cache_dtype == "auto": self.cache_dtype = self.model_config.dtype else:
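For reference, a minimal usage sketch of the scheduler added above, assuming an LLM entry point that forwards additional_config into VllmConfig and the V1 engine enabled (envs.VLLM_USE_V1); the model name and sampling values below are placeholders, not taken from this patch. Any non-None "ascend_scheduler_config" entry makes check_and_update_config build an AscendSchedulerConfig and install engine_core_init_with_ascend_scheduler; keys inside the entry override the defaults defined in schedule_config.py.

from vllm import LLM, SamplingParams

# Placeholder model; any key under "ascend_scheduler_config" overrides the
# AscendSchedulerConfig defaults (enable_chunked_prefill=False, policy="fcfs",
# num_scheduler_steps=1).
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    additional_config={"ascend_scheduler_config": {}},
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)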