diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml
index b5022355be..99afe503ba 100644
--- a/.github/workflows/vllm_ascend_test.yaml
+++ b/.github/workflows/vllm_ascend_test.yaml
@@ -112,7 +112,13 @@ jobs:
             pytest -sv tests/singlecard/test_scheduler.py
             # guided decoding doesn't work, fix it later
             # pytest -sv tests/singlecard/test_guided_decoding.py.py
-            pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py
+            # test_ascend_config.py should be run separately because it regenerates the global config many times.
+            pytest -sv tests/singlecard/test_ascend_config.py
+            pytest -sv tests/singlecard/ \
+              --ignore=tests/singlecard/test_offline_inference.py \
+              --ignore=tests/singlecard/test_scheduler.py \
+              --ignore=tests/singlecard/test_guided_decoding.py \
+              --ignore=tests/singlecard/test_ascend_config.py
           else
             pytest -sv tests/multicard/test_ilama_lora_tp2.py
             VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py
@@ -128,11 +134,14 @@ jobs:
             # guided decoding doesn't work, fix it later
             # pytest -sv tests/singlecard/test_guided_decoding.py.py
             pytest -sv tests/singlecard/test_camem.py
+            # test_ascend_config.py should be run separately because it regenerates the global config many times.
+            pytest -sv tests/singlecard/test_ascend_config.py
             pytest -sv tests/singlecard/ \
               --ignore=tests/singlecard/test_offline_inference.py \
               --ignore=tests/singlecard/test_scheduler.py \
               --ignore=tests/singlecard/test_guided_decoding.py \
-              --ignore=tests/singlecard/test_camem.py
+              --ignore=tests/singlecard/test_camem.py \
+              --ignore=tests/singlecard/test_ascend_config.py
           else
             pytest -sv tests/multicard/test_ilama_lora_tp2.py
             # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py will raise error.
diff --git a/docs/source/index.md b/docs/source/index.md
index b11222efc6..c84e1d80b3 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -46,6 +46,7 @@
 faqs
 user_guide/suppoted_features
 user_guide/supported_models
 user_guide/env_vars
+user_guide/additional_config
 user_guide/release_notes
 :::
diff --git a/docs/source/user_guide/additional_config.md b/docs/source/user_guide/additional_config.md
new file mode 100644
index 0000000000..d2d4234d77
--- /dev/null
+++ b/docs/source/user_guide/additional_config.md
@@ -0,0 +1,70 @@
+# Additional Configuration
+
+Additional configuration is a mechanism provided by vLLM that allows plugins to control their own inner behavior. vLLM Ascend uses this mechanism to make the project more flexible.
+
+## How to use
+
+Additional configuration can be used in both online and offline mode. Take Qwen3 as an example:
+
+**Online mode**:
+
+```bash
+vllm serve Qwen/Qwen3-8B --additional-config='{"config_key":"config_value"}'
+```
+
+**Offline mode**:
+
+```python
+from vllm import LLM
+
+LLM(model="Qwen/Qwen3-8B", additional_config={"config_key":"config_value"})
+```
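+
+For instance, the snippet below enables the Ascend scheduler for the V1 engine. The values are only illustrative; all available options are listed in the next section.
+
+```python
+from vllm import LLM
+
+llm = LLM(model="Qwen/Qwen3-8B",
+          additional_config={"ascend_scheduler_config": {"enabled": True}})
+```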
+
+### Configuration options
+
+The following table lists the additional configuration options available in vLLM Ascend:
+
+| Name | Type | Default | Description |
+| ---- | ---- | ------- | ----------- |
+| `torchair_graph_config` | dict | `{}` | The config options for torchair graph mode |
+| `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
+| `expert_tensor_parallel_size` | int | `1` | Expert tensor parallel size for the model to use. |
+
+The details of each config option are as follows:
+
+**torchair_graph_config**
+
+| Name | Type | Default | Description |
+| ---- | ---- | ------- | ----------- |
+| `enabled` | bool | `False` | Whether to enable torchair graph mode |
+| `use_cached_graph` | bool | `False` | Whether to use cached graph |
+| `graph_batch_sizes` | list[int] | `[]` | The batch sizes for torchair graph cache |
+| `graph_batch_sizes_init` | bool | `False` | Initialize the graph batch sizes dynamically if `graph_batch_sizes` is empty |
+
+**ascend_scheduler_config**
+
+| Name | Type | Default | Description |
+| ---- | ---- | ------- | ----------- |
+| `enabled` | bool | `False` | Whether to enable ascend scheduler for V1 engine |
+
+`ascend_scheduler_config` also supports the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: true` to `ascend_scheduler_config` as well.
+
+### Example
+
+A full example of additional configuration is as follows:
+
+```json
+{
+    "torchair_graph_config": {
+        "enabled": true,
+        "use_cached_graph": true,
+        "graph_batch_sizes": [1, 2, 4, 8],
+        "graph_batch_sizes_init": false
+    },
+    "ascend_scheduler_config": {
+        "enabled": true,
+        "enable_chunked_prefill": true
+    },
+    "expert_tensor_parallel_size": 1
+}
+```
diff --git a/examples/dp_offline/data_parallel.py b/examples/dp_offline/data_parallel.py
index 0299497910..b06c52d8c5 100644
--- a/examples/dp_offline/data_parallel.py
+++ b/examples/dp_offline/data_parallel.py
@@ -62,7 +62,9 @@ def main():
         max_num_seqs=num_seqs,
         additional_config={
             'expert_tensor_parallel_size': etp_size,
-            'enable_graph_mode': False,
+            'torchair_graph_config': {
+                'enabled': False,
+            },
         })

     outputs = llm.generate(prompts, sampling_params)
diff --git a/tests/long_term/spec_decode/e2e/conftest.py b/tests/long_term/spec_decode/e2e/conftest.py
index 67a46e114c..c8c9546400 100644
--- a/tests/long_term/spec_decode/e2e/conftest.py
+++ b/tests/long_term/spec_decode/e2e/conftest.py
@@ -167,17 +167,17 @@ def run_equality_correctness_test(
     # TODO current torchair graph mode needs clean torchair cache.
# if do not clean, it will raise error - additional_config = common_llm_kwargs.get("additional_config") - enable_graph_mode = additional_config.get( - "enable_graph_mode") if additional_config else False + torchair_graph_enabled = common_llm_kwargs.get( + "additional_config", {}).get("torchair_graph_config", + {}).get("enabled", False) with vllm_runner(**org_args) as vllm_model: - if enable_graph_mode: + if torchair_graph_enabled: _clean_torchair_cache() org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) with vllm_runner(**sd_args) as vllm_model: - if enable_graph_mode: + if torchair_graph_enabled: _clean_torchair_cache() if ensure_all_accepted or expected_acceptance_rate is not None: # Force log interval to be 0 to catch all metrics. diff --git a/tests/long_term/spec_decode/e2e/test_mtp_correctness.py b/tests/long_term/spec_decode/e2e/test_mtp_correctness.py index 0a994ed15d..8d7f2612aa 100644 --- a/tests/long_term/spec_decode/e2e/test_mtp_correctness.py +++ b/tests/long_term/spec_decode/e2e/test_mtp_correctness.py @@ -218,7 +218,9 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, "common_llm_kwargs", [{ "additional_config": { - 'enable_graph_mode': True, + 'torchair_graph_config': { + "enabled": True, + }, }, # Print spec metrics. @@ -262,7 +264,9 @@ def test_mtp_e2e_greedy_correctness_torchair_graph( "common_llm_kwargs", [{ "additional_config": { - 'enable_graph_mode': True, + 'torchair_graph_config': { + "enabled": True, + }, }, # Print spec metrics. diff --git a/tests/multicard/test_dynamic_npugraph_batchsize.py b/tests/multicard/test_dynamic_npugraph_batchsize.py index 1424cb9d05..e5c7042b1e 100644 --- a/tests/multicard/test_dynamic_npugraph_batchsize.py +++ b/tests/multicard/test_dynamic_npugraph_batchsize.py @@ -18,8 +18,6 @@ import torch from vllm import LLM, SamplingParams -from vllm_ascend.utils import vllm_version_is - MODELS = [ "Qwen/Qwen2.5-0.5B-Instruct", ] @@ -32,9 +30,6 @@ ] -@pytest.mark.skipif( - (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")), - reason="aclgraph not supported in v0.8.5 and v0.8.5.post1") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("tp_size", TENSOR_PARALLELS) @pytest.mark.parametrize("max_tokens", [64]) diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py index 941055cf72..c60954ba58 100644 --- a/tests/multicard/test_offline_inference_distributed.py +++ b/tests/multicard/test_offline_inference_distributed.py @@ -31,9 +31,7 @@ def test_models_distributed_QwQ(): example_prompts = [ - "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", - "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.", - "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", + "Hello, my name is", ] dtype = "half" max_tokens = 5 @@ -48,9 +46,7 @@ def test_models_distributed_QwQ(): def test_models_distributed_DeepSeek(): example_prompts = [ - "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", - "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.", - "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", + "Hello, my name is", ] dtype = "half" max_tokens = 5 diff --git a/tests/compile/test_aclgraph.py b/tests/singlecard/test_aclgraph.py similarity index 90% rename 
from tests/compile/test_aclgraph.py rename to tests/singlecard/test_aclgraph.py index fad20eb884..e0bfb65cf8 100644 --- a/tests/compile/test_aclgraph.py +++ b/tests/singlecard/test_aclgraph.py @@ -28,16 +28,12 @@ from tests.conftest import VllmRunner from tests.model_utils import check_outputs_equal -from vllm_ascend.utils import vllm_version_is MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", reason="aclgraph only support on v1") -@pytest.mark.skipif( - (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")), - reason="aclgraph not supported in v0.8.5 and v0.8.5.post1") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [32]) def test_models( @@ -88,9 +84,6 @@ def test_models( @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", reason="aclgraph only support on v1") -@pytest.mark.skipif( - (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")), - reason="aclgraph not supported in v0.8.5 and v0.8.5.post1") def test_deepseek_raises_error(monkeypatch: pytest.MonkeyPatch) -> None: with monkeypatch.context() as m: m.setenv("VLLM_USE_MODELSCOPE", "True") diff --git a/tests/singlecard/test_ascend_config.py b/tests/singlecard/test_ascend_config.py new file mode 100644 index 0000000000..2642c0eac0 --- /dev/null +++ b/tests/singlecard/test_ascend_config.py @@ -0,0 +1,118 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from tests.conftest import VllmRunner +from vllm_ascend.ascend_config import clear_ascend_config, get_ascend_config + + +def _clean_up_ascend_config(func): + + def wrapper(*args, **kwargs): + clear_ascend_config() + func(*args, **kwargs) + clear_ascend_config() + + return wrapper + + +@_clean_up_ascend_config +def test_run_without_ascend_config(): + with VllmRunner("facebook/opt-125m"): + ascend_config = get_ascend_config() + + assert not ascend_config.torchair_graph_config.enabled + assert not ascend_config.torchair_graph_config.use_cached_graph + assert ascend_config.torchair_graph_config.graph_batch_sizes == [] + assert not ascend_config.torchair_graph_config.graph_batch_sizes_init + assert not ascend_config.ascend_scheduler_config.enabled + assert ascend_config.expert_tensor_parallel_size == 1 + + +@_clean_up_ascend_config +def test_run_with_ascend_config(): + input_additional_config = { + "torchair_graph_config": { + # torchair graph only works with deepseek. The e2e test should be added + # in multicard test with deepseek models. 
+ "enabled": False, + "use_cached_graph": True, + "graph_batch_sizes": [1, 2, 4, 8], + "graph_batch_sizes_init": False, + }, + "ascend_scheduler_config": { + "enabled": True, + "enable_chunked_prefill": True, + }, + "expert_tensor_parallel_size": 1 + } + with VllmRunner("facebook/opt-125m", + additional_config=input_additional_config): + ascend_config = get_ascend_config() + + assert not ascend_config.torchair_graph_config.enabled + assert ascend_config.torchair_graph_config.use_cached_graph + assert ascend_config.torchair_graph_config.graph_batch_sizes == [ + 1, 2, 4, 8 + ] + assert not ascend_config.torchair_graph_config.graph_batch_sizes_init + assert ascend_config.ascend_scheduler_config.enabled + assert ascend_config.ascend_scheduler_config.enable_chunked_prefill + assert ascend_config.expert_tensor_parallel_size == 1 + + +@_clean_up_ascend_config +def test_ascend_config_init_error(): + # ascend_config should be initialized first + with pytest.raises(RuntimeError): + _ = get_ascend_config() + + +@_clean_up_ascend_config +def test_ascend_config_load_error(): + # graph_batch_sizes should be list. + with pytest.raises(TypeError): + input_additional_config_fake_1 = { + "torchair_graph_config": { + "graph_batch_sizes": "fake_size", + }, + } + with VllmRunner("facebook/opt-125m", + additional_config=input_additional_config_fake_1): + pass + + # graph_batch_sizes_init should not be True when graph_batch_sizes is not empty. + with pytest.raises(ValueError): + input_additional_config_fake_2 = { + "torchair_graph_config": { + "graph_batch_sizes": [1, 2, 4, 8], + "graph_batch_sizes_init": True, + }, + } + with VllmRunner("facebook/opt-125m", + additional_config=input_additional_config_fake_2): + pass + + # torchair graph only works with deepseek. + with pytest.raises(NotImplementedError): + input_additional_config_fake_2 = { + "torchair_graph_config": { + "enabled": True, + }, + } + with VllmRunner("facebook/opt-125m", + additional_config=input_additional_config_fake_2): + pass diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py new file mode 100644 index 0000000000..2463f17591 --- /dev/null +++ b/vllm_ascend/ascend_config.py @@ -0,0 +1,138 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import vllm.envs as envs +from vllm.logger import logger + + +class AscendConfig: + """ + Configuration Object for additional_config from vllm.configs. 
+ """ + + def __init__(self, vllm_config): + additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} + + torchair_graph_config = additional_config.get("torchair_graph_config", + {}) + self.torchair_graph_config = TorchairGraphConfig(torchair_graph_config) + + ascend_scheduler_config = additional_config.get( + "ascend_scheduler_config", {}) + self.ascend_scheduler_config = AscendSchedulerConfig( + ascend_scheduler_config) + + self.expert_tensor_parallel_size = int( + additional_config.get("expert_tensor_parallel_size", 1)) + + +class TorchairGraphConfig: + """ + Configuration Object for torchair_graph_config from additional_config + """ + + def __init__(self, torchair_graph_config): + self.enabled = torchair_graph_config.get("enabled", False) + self.use_cached_graph = torchair_graph_config.get( + "use_cached_graph", False) + self.graph_batch_sizes = torchair_graph_config.get( + "graph_batch_sizes", []) + self.graph_batch_sizes_init = torchair_graph_config.get( + "graph_batch_sizes_init", False) + + if not isinstance(self.graph_batch_sizes, list): + raise TypeError("graph_batch_sizes must be list[int]") + if self.graph_batch_sizes_init and len(self.graph_batch_sizes) > 0: + raise ValueError( + "graph_batch_sizes_init is only valid when graph_batch_sizes is empty" + ) + + +class AscendSchedulerConfig: + """ + Configuration Object for ascend_scheduler_config from additional_config + """ + + def __init__(self, ascend_scheduler_config: dict): + self.enabled = ascend_scheduler_config.get("enabled", False) + # Ascend scheduler is based on vllm v0 scheduler, so we should support + # all vllm v0 scheduler configs as well. + for k, v in ascend_scheduler_config.items(): + if not hasattr(self, k): + setattr(self, k, v) + + +_ASCEND_CONFIG: Optional[AscendConfig] = None + + +def init_ascend_config(vllm_config): + global _ASCEND_CONFIG + if _ASCEND_CONFIG is not None: + return _ASCEND_CONFIG + _ASCEND_CONFIG = AscendConfig(vllm_config) + return _ASCEND_CONFIG + + +def clear_ascend_config(): + global _ASCEND_CONFIG + _ASCEND_CONFIG = None + + +def get_ascend_config(): + global _ASCEND_CONFIG + if _ASCEND_CONFIG is None: + raise RuntimeError( + "Ascend config is not initialized. Please call init_ascend_config first." + ) + return _ASCEND_CONFIG + + +def check_ascend_config(vllm_config, enforce_eager): + ascend_config = get_ascend_config() + + # Both for V0 and V1 Engine, torchair_graph cannot be enabled with eager mode. + if ascend_config.torchair_graph_config.enabled and not enforce_eager: + raise RuntimeError( + "Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode." + ) + + # torchair_graph only work with deepseek model and mla enabled. + if ascend_config.torchair_graph_config.enabled: + if envs.VLLM_MLA_DISABLE: + logger.warning( + "Torchair graph mode is still experimental and not supported for V1 without mla currently, " + "it has been disabled automatically.") + ascend_config.ascend_scheduler_config.enabled = False + if vllm_config.model_config: + model_type = vllm_config.model_config.hf_config.model_type + if "deepseek" not in model_type: + raise NotImplementedError( + "Torchair graph mode only works with deepseek model.") + + # for V1 Engine, aclgraph doesn't work with deepseek model and only qwen model is well tested. 
diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py
index 48a01183a1..8f130e4241 100644
--- a/vllm_ascend/attention/attention.py
+++ b/vllm_ascend/attention/attention.py
@@ -32,9 +32,9 @@
                                            compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
-from vllm.config import get_current_vllm_config
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad

+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ops.cache import concat_and_cache_mla
 from vllm_ascend.platform import CUSTOM_OP_ENABLED
 from vllm_ascend.worker.model_runner import (
@@ -1002,11 +1002,8 @@ def __init__(
         self.w_kc = None
         self.w_vc = None

-        self.enable_graph_mode = False
-        additional_config = get_current_vllm_config().additional_config
-        if additional_config:
-            self.enable_graph_mode = additional_config.get(
-                "enable_graph_mode", False)
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

     def exec_kv(
         self,
@@ -1179,7 +1176,7 @@ def forward(
                 self.num_heads, -1)

         # TODO: Replace the env with more flexible expressions
-        if self.enable_graph_mode:
+        if self.torchair_graph_enabled:
             if len(kv_cache) > 0 and kv_cache[0].numel(
             ) > 0 and attn_metadata.num_prefills > 0:
                 slots = attn_metadata.slot_mapping
@@ -1230,7 +1227,7 @@ def forward(
             )
         elif attn_metadata.decode_metadata:
             assert kv_cache is not None
-            if self.enable_graph_mode:
+            if self.torchair_graph_enabled:
                 # shape of query for npu graph mode should be:
                 # [bs, num_heads_per_rank, seq_len, dim]
                 q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 4a5410fa66..ae3dd6205b 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -8,10 +8,10 @@
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
-from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)

+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla

@@ -443,20 +443,8 @@ def __init__(
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)

-        # Handle the differences between the flash_attn_varlen from flash_attn
-        # and the one from vllm_flash_attn.
The former is used on RoCM and the - # latter has an additional parameter to control FA2 vs FA3 - # self.flash_attn_varlen_func = flash_attn_varlen_func - # if self.vllm_flash_attn_version is not None: - # self.flash_attn_varlen_func = \ - # functools.partial(flash_attn_varlen_func, - # fa_version=self.vllm_flash_attn_version) - - self.enable_graph_mode = False - additional_config = get_current_vllm_config().additional_config - if additional_config: - self.enable_graph_mode = additional_config.get( - "enable_graph_mode", False) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) @@ -713,7 +701,7 @@ def forward( if attn_metadata is None: # Profiling run. return output - self.running_in_graph = self.enable_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly + self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.DecodeOnly num_actual_toks = attn_metadata.num_actual_tokens if k_pe is None and not self.running_in_graph: kv_c, k_pe = self.kv_a_proj_with_mqa( @@ -776,7 +764,7 @@ def forward( .view(-1, self.num_heads, self.qk_head_dim) prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim] - if self.enable_graph_mode: + if self.torchair_graph_enabled: num_tokens = prefill_hs_or_q_c.shape[0] prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads, -1) @@ -801,7 +789,7 @@ def forward( prefill_q_pe.contiguous(), prefill_k_pe, max_seq_len=attn_metadata.prefill.max_seq_lens) - if self.enable_graph_mode: + if self.torchair_graph_enabled: if len(kv_cache) > 0 and kv_cache[0].numel( ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: slots = attn_metadata.slot_mapping diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py index 51e49606a6..4a4131ecd7 100644 --- a/vllm_ascend/core/schedule_config.py +++ b/vllm_ascend/core/schedule_config.py @@ -33,7 +33,7 @@ class AscendSchedulerConfig(SchedulerConfig): def initialize_from_config( cls, vllm_scheduler_config: SchedulerConfig, - ascend_scheduler_config: dict, + ascend_scheduler_config, ): scheduler_config = { field.name: getattr(vllm_scheduler_config, field.name) @@ -45,9 +45,10 @@ def initialize_from_config( scheduler_config["num_scheduler_steps"] = 1 scheduler_config["scheduler_cls"] = ( "vllm_ascend.core.scheduler.AscendScheduler") - # Override params in original SchedulerConfig with params in additional_config.ascend_scheduler_config - for k, v in ascend_scheduler_config.items(): - scheduler_config[k] = v + # Override params in original SchedulerConfig with params in ascend_scheduler_config + for k, _ in scheduler_config.items(): + if hasattr(ascend_scheduler_config, k): + scheduler_config[k] = getattr(ascend_scheduler_config, k) return cls(**scheduler_config) def __post_init__(self) -> None: diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 515ebe1a8e..34894b3697 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -34,8 +34,7 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import (CacheConfig, ModelConfig, VllmConfig, - get_current_vllm_config) +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import (get_pp_group, 
get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce) @@ -67,6 +66,7 @@ from vllm.sequence import IntermediateTensors import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.utils import dispose_tensor @@ -214,11 +214,8 @@ def __init__( self.params_dtype = torch.get_default_dtype() - self.enable_graph_mode = False - additional_config = get_current_vllm_config().additional_config - if additional_config: - self.enable_graph_mode = additional_config.get( - "enable_graph_mode", False) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled def forward( self, @@ -248,7 +245,7 @@ def forward( if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill: chunks = torch.chunk(hidden_states, self.tp_size, dim=0) hidden_states = chunks[self.tp_rank] - elif not self.enable_graph_mode: + elif not self.torchair_graph_enabled: num_padding_tokens = (self.tp_size - num_tokens % self.tp_size) % self.tp_size # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C @@ -272,7 +269,7 @@ def forward( ) * self.routed_scaling_factor if self.tp_size > 1: - if self.enable_graph_mode: + if self.torchair_graph_enabled: if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill: final_hidden_states = torch.zeros( [num_tokens, hidden_size], @@ -423,11 +420,9 @@ def __init__( self.prefix = prefix self.debug_layer_idx = int(self.prefix.split(".")[-2]) - self.enable_graph_mode = False - additional_config = get_current_vllm_config().additional_config - if additional_config: - self.enable_graph_mode = additional_config.get( - "enable_graph_mode", False) + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled def forward( self, @@ -440,7 +435,7 @@ def forward( hidden_states_or_q_c = self.q_a_layernorm(ckq) else: hidden_states_or_q_c = hidden_states - if self.enable_graph_mode: + if self.torchair_graph_enabled: forward_kwargs = {} if envs.VLLM_USE_V1: output_shape = hidden_states.shape diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 4853b27282..f6b5385379 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -32,6 +32,7 @@ QuantizationConfig import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 @@ -587,11 +588,8 @@ def __init__(self, moe: MoEConfig = None): self.global_batch_size = vllm_config.scheduler_config.max_num_seqs self.local_batch_size = self.global_batch_size // self.ep_size - self.enable_graph_mode = False - additional_config = get_current_vllm_config().additional_config - if additional_config: - self.enable_graph_mode = additional_config.get( - "enable_graph_mode", False) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled try: device_group = ep_group.device_group @@ -678,7 +676,7 @@ def apply( top_k=top_k, expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name) - elif self.enable_graph_mode or get_ep_group().world_size == 1: + elif self.torchair_graph_enabled or get_ep_group().world_size == 1: return fused_experts(hidden_states=x, w1=layer.w13_weight, 
w2=layer.w2_weight, @@ -772,11 +770,8 @@ def __init__( self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group - self.enable_graph_mode = False - additional_config = get_current_vllm_config().additional_config - if additional_config: - self.enable_graph_mode = additional_config.get( - "enable_graph_mode", False) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " @@ -818,12 +813,6 @@ def __init__( self.ep_group = get_ep_group() self.quant_method.create_weights(layer=self, **moe_quant_params) - self.enable_graph_mode = False - additional_config = get_current_vllm_config().additional_config - if additional_config: - self.enable_graph_mode = additional_config.get( - "enable_graph_mode", False) - def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -844,13 +833,13 @@ def forward(self, if self.dp_size > 1: if VLLM_ENABLE_MC2 and not is_prefill: ... - elif self.enable_graph_mode: + elif self.torchair_graph_enabled: if USING_LCCL_COM: # type: ignore hidden_states = get_dp_group().all_gather( hidden_states, 0, False) router_logits = get_dp_group().all_gather( router_logits, 0, False) - elif self.enable_graph_mode and not is_prefill: + elif self.torchair_graph_enabled and not is_prefill: hidden_states = get_dp_group().all_gather(hidden_states, 0) router_logits = get_dp_group().all_gather(router_logits, 0) else: @@ -878,14 +867,14 @@ def forward(self, if self.dp_size > 1: if VLLM_ENABLE_MC2 and not is_prefill: ... - elif self.enable_graph_mode: + elif self.torchair_graph_enabled: if USING_LCCL_COM: # type: ignore hidden_states = dist._functional_collectives.reduce_scatter_tensor( hidden_states, "sum", scatter_dim=0, group=get_dp_group().device_group) - elif self.enable_graph_mode and not is_prefill: + elif self.torchair_graph_enabled and not is_prefill: hidden_states = dist._functional_collectives.reduce_scatter_tensor( hidden_states, "sum", diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 413ba6f5c8..647fefbe0e 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -24,6 +24,7 @@ from vllm.logger import logger from vllm.platforms import Platform, PlatformEnum +from vllm_ascend.ascend_config import check_ascend_config, init_ascend_config from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD, update_aclgraph_sizes CUSTOM_OP_ENABLED = False @@ -117,10 +118,12 @@ def mem_get_info(cls) -> Tuple[int, int]: @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + # initialize ascend config from vllm additional_config + ascend_config = init_ascend_config(vllm_config) + from vllm.config import CompilationLevel # noqa: E402 compilation_config = vllm_config.compilation_config model_config = vllm_config.model_config - additional_config = vllm_config.additional_config parallel_config = vllm_config.parallel_config cache_config = vllm_config.cache_config @@ -130,11 +133,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # NOTE: When enable_expert_parallel is True, we follow vLLM convention: # ep_size = world_size, which means expert_tensor_parallel_size must be 1 - if (additional_config - and "expert_tensor_parallel_size" in additional_config - and not parallel_config.enable_expert_parallel): - parallel_config.expert_tensor_parallel_size = int( 
- additional_config["expert_tensor_parallel_size"]) + if ascend_config.expert_tensor_parallel_size > 1 and not parallel_config.enable_expert_parallel: + parallel_config.expert_tensor_parallel_size = ascend_config.expert_tensor_parallel_size # Calculate expert parallel size based on world size parallel_config.expert_parallel_size = ( @@ -148,41 +148,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: else: enforce_eager = getattr(model_config, "enforce_eager", False) - if additional_config is not None: - enable_graph_mode = additional_config.get("enable_graph_mode", - False) - if enable_graph_mode: - if enforce_eager: - raise RuntimeError( - "Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode." - ) - elif envs.VLLM_USE_V1 and envs.VLLM_MLA_DISABLE: - logger.warning( - "NPU graph mode is still experimental and not supported for V1 without mla currently, " - "it has been disabled automatically.") - additional_config["enable_graph_mode"] = False - if model_config: - model_type = model_config.hf_config.model_type - if "deepseek" not in model_type: - raise NotImplementedError( - "enable_graph_mode only works with deepseek model." - ) - # Set compilation level to NO_COMPILATION to disable ACL Graph - compilation_config.level = CompilationLevel.NO_COMPILATION - - elif envs.VLLM_USE_V1 and model_config is not None and not enforce_eager: - model_type = model_config.hf_config.model_type - if "deepseek" in model_type: - raise NotImplementedError( - "ACL Graph does not support deepseek. Please " - "adopt additional_config={'enable_graph_mode': True} " - "to serve deepseek models with NPU graph mode on vllm-ascend with V1 engine." - " Or set `enforce_eager=True` to use eager mode.") - elif "qwen" not in model_type: - logger.warning( - "ACL Graph is currently experimental. Please " - "raise an issue on https://github.com/vllm-project/vllm-ascend/issues" - " if you encourage any Error") + check_ascend_config(vllm_config, enforce_eager) if enforce_eager or compilation_config.level == CompilationLevel.NO_COMPILATION: logger.info("Compilation disabled, using eager mode by default") @@ -192,6 +158,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "NPU does not support %s compilation level. Setting level to NO_COMPILATION", compilation_config.level) compilation_config.level = CompilationLevel.NO_COMPILATION + elif ascend_config.torchair_graph_config.enabled: + logger.info( + "Torchair compilation enabled on NPU. Setting level to NO_COMPILATION" + ) + compilation_config.level = CompilationLevel.NO_COMPILATION else: logger.info( "PIECEWISE compilation enabled on NPU. use_inductor not supported - " @@ -224,17 +195,15 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if envs.VLLM_USE_V1: # Activate custom ops for v1. compilation_config.custom_ops = ["all"] - # If ascend_scheduler_config exists in additional_config, - # extents original scheduler_config to use AscendScheduler. - if additional_config and additional_config.get( - "ascend_scheduler_config", None) is not None: - additional_scheduler_config = additional_config.get( - "ascend_scheduler_config") + # If ascend_scheduler_config is enabled, + # extents original scheduler_config to use AscendScheduler. 
+ if ascend_config.ascend_scheduler_config.enabled: from vllm_ascend.core.schedule_config import \ AscendSchedulerConfig ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config( - vllm_config.scheduler_config, additional_scheduler_config) + vllm_config.scheduler_config, + ascend_config.ascend_scheduler_config) vllm_config.scheduler_config = ascend_scheduler_config @classmethod diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 45580c41dd..9d651fbfc8 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -20,10 +20,10 @@ import torch import torch.distributed as dist import torch_npu -from vllm.config import get_current_vllm_config from vllm.distributed import GroupCoordinator import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import select_experts from vllm_ascend.utils import dispose_tensor @@ -509,11 +509,8 @@ def __init__(self): self.ep_group = get_ep_group() - self.enable_graph_mode = False - additional_config = get_current_vllm_config().additional_config - if additional_config: - self.enable_graph_mode = additional_config.get( - "enable_graph_mode", False) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled try: device_group = self.ep_group.device_group @@ -638,7 +635,7 @@ def apply( top_k=top_k, expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name) - elif self.enable_graph_mode or self.ep_group.world_size == 1: + elif self.torchair_graph_enabled or self.ep_group.world_size == 1: return fused_experts(hidden_states=x, w1=layer.w13_weight, w1_scale=layer.w13_weight_scale, diff --git a/vllm_ascend/worker/cache_engine.py b/vllm_ascend/worker/cache_engine.py index 72de201f1d..d8d9087745 100644 --- a/vllm_ascend/worker/cache_engine.py +++ b/vllm_ascend/worker/cache_engine.py @@ -20,10 +20,11 @@ from typing import Any, List import torch -from vllm.config import get_current_vllm_config from vllm.utils import is_pin_memory_available from vllm.worker.cache_engine import CacheEngine +from vllm_ascend.ascend_config import get_ascend_config + def allocate_kv_cache( self, @@ -36,8 +37,8 @@ def allocate_kv_cache( pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[Any] = [] - additional_config = get_current_vllm_config().additional_config - if additional_config and additional_config.get("enable_graph_mode", False): + ascend_config = get_ascend_config() + if ascend_config.torchair_graph_config.enabled: # Align entries so they are 256 byte aligned for better performance # Primarily targets MLA as this typically only ends up having entries # be 128 byte aligned. 
diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py index a0b6a0930f..2f3b872f0b 100644 --- a/vllm_ascend/worker/model_runner.py +++ b/vllm_ascend/worker/model_runner.py @@ -64,6 +64,8 @@ _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) +from vllm_ascend.ascend_config import get_ascend_config + if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -540,7 +542,7 @@ def build(self) -> ModelInputForNPU: } # Add graph_pad_size here - if self.runner.enable_graph_mode: + if self.runner.torchair_graph_enabled: graph_pad_size = self.runner.scheduler_config.max_num_seqs - len( seq_lens) else: @@ -603,7 +605,7 @@ def build(self) -> ModelInputForNPU: ] multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - if self.runner.enable_graph_mode: + if self.runner.torchair_graph_enabled: torch._dynamo.mark_static(input_tokens_tensor) torch._dynamo.mark_static(input_positions_tensor) torch._dynamo.mark_static(attn_metadata.block_tables) @@ -864,14 +866,9 @@ def __init__( self.max_batchsize_to_capture = \ self.vllm_config.compilation_config.max_capture_size - self.enable_graph_mode = False - self.use_cached_npu_graph = False - additional_config = vllm_config.additional_config - if additional_config: - self.enable_graph_mode = additional_config.get( - "enable_graph_mode", False) - self.use_cached_npu_graph = additional_config.get( - "use_cached_npu_graph", False) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph self.has_inner_state = model_config.has_inner_state @@ -971,7 +968,7 @@ def load_model(self) -> None: self.model = self.lora_manager.create_lora_manager(self.model) # adapter torch compile with npu_backend - if self.enable_graph_mode: + if self.torchair_graph_enabled: import torchair # type: ignore from torchair import patch_for_hcom # type: ignore @@ -1290,7 +1287,7 @@ def execute_model( assert model_input.attn_metadata is not None # TODO(zzzzwwjj): Do we need to do it every time? 
- if self.enable_graph_mode: + if self.torchair_graph_enabled: torch._dynamo.mark_static(model_input.input_tokens) torch._dynamo.mark_static(model_input.input_positions) torch._dynamo.mark_static(model_input.attn_metadata.block_tables) @@ -1305,7 +1302,7 @@ def execute_model( virtual_engine = model_input.virtual_engine prefill_meta = model_input.attn_metadata.prefill_metadata previous_hidden_states = kwargs.get("previous_hidden_states") - if prefill_meta is None and self.enable_graph_mode: + if prefill_meta is None and self.torchair_graph_enabled: model_executable = self.compile_model # Note: graph_batch_size value not same as GPU graph_batch_size = model_input.input_tokens.shape[ # type: ignore @@ -1359,7 +1356,7 @@ def execute_model( "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_inner_state else {} - if self.enable_graph_mode: + if self.torchair_graph_enabled: model_kwargs: Dict[str, Any] = {"inputs_embeds": None} else: model_kwargs = {} @@ -1377,7 +1374,7 @@ def execute_model( self.vllm_config, virtual_engine): if model_input.attn_metadata is not None: model_input.attn_metadata.input_positions = model_input.input_positions - if self.enable_graph_mode: + if self.torchair_graph_enabled: model_kwargs["kv_caches"] = kv_caches model_kwargs["attn_metadata"] = model_input.attn_metadata hidden_or_intermediate_states = model_executable( @@ -1461,7 +1458,7 @@ def execute_model( hidden_states = hidden_or_intermediate_states.index_select( 0, indices) output.prefill_hidden_states = hidden_or_intermediate_states - elif self.enable_graph_mode: + elif self.torchair_graph_enabled: hidden_states = hidden_or_intermediate_states[:len(indices)] else: hidden_states = hidden_or_intermediate_states diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f64336dc35..58c73507a0 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -61,6 +61,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata @@ -137,13 +138,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.max_num_reqs = self.scheduler_config.max_num_seqs - additional_config = vllm_config.additional_config - if additional_config and additional_config.get( - "ascend_scheduler_config", None) is not None: - self.use_v0_scheduler = True - else: - self.use_v0_scheduler = False - self.graph_block_tables = np.zeros( (self.vllm_config.scheduler_config.max_num_seqs, (self.model_config.max_model_len + self.block_size - 1) // @@ -326,25 +320,14 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.attn_mask_len, self.dtype) self.sampler = Sampler() - self.enable_torchair_graph_mode = False - self.use_cached_npu_graph = False - self.torchair_graph_batch_sizes = [] - additional_config = vllm_config.additional_config - if additional_config: - self.enable_torchair_graph_mode = additional_config.get( - "enable_graph_mode", - False) and self.vllm_config.model_config.use_mla - self.use_cached_npu_graph = additional_config.get( - "use_cached_npu_graph", False) - self.torchair_graph_batch_sizes = 
additional_config.get( - "torchair_graph_batch_sizes", []) - if not isinstance(self.torchair_graph_batch_sizes, list): - logger.warning("torchair_graph_batch_sizes must be list[int]") - self.torchair_graph_batch_sizes = [] - if len(self.torchair_graph_batch_sizes - ) == 0 and additional_config.get( - "torchair_graph_batch_sizes_init", False): - self.init_torchair_graph_batch_sizes() + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla + self.torchair_graph_use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph + self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes + + if ascend_config.torchair_graph_config.graph_batch_sizes_init: + self.init_torchair_graph_batch_sizes() if len(self.torchair_graph_batch_sizes) == 0: #If MC2 is enabled, torchair_graph_batch_size should pad to tp_size @@ -628,13 +611,14 @@ def _process_reqs( block_offsets, out=self.slot_mapping_np[:total_num_scheduled_tokens]) + ascend_config = get_ascend_config() if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): attn_state = AscendAttentionState.PrefillNoCache # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache. elif np.all(num_scheduled_tokens == 1): attn_state = AscendAttentionState.DecodeOnly # splitfuse - elif not self.use_v0_scheduler or self.chunked_prefill_enabled: + elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled: attn_state = AscendAttentionState.ChunkedPrefill else: attn_state = AscendAttentionState.PrefillCacheHit @@ -671,7 +655,7 @@ def _process_reqs( extra_builder_kwargs['with_prefill_across_dp'] = with_prefill # Add graph_pad_size here - if envs_ascend.VLLM_ENABLE_MC2 or (self.enable_torchair_graph_mode + if envs_ascend.VLLM_ENABLE_MC2 or (self.torchair_graph_enabled and not with_prefill): batch_size = len(seq_lens) if self.dp_size > 1: @@ -715,7 +699,7 @@ def _process_reqs( input_ids = self.input_ids[:num_input_tokens] if (envs_ascend.VLLM_ENABLE_MC2 - or self.enable_torchair_graph_mode) and not with_prefill: + or self.torchair_graph_enabled) and not with_prefill: input_ids = self.input_ids[:padded_batch_size] positions = self.positions[:padded_batch_size] @@ -724,10 +708,10 @@ def _process_reqs( self.vllm_config, num_tokens=num_input_tokens): model_kwargs = {} - if self.enable_torchair_graph_mode: + if self.torchair_graph_enabled: model_kwargs["kv_caches"] = self.kv_caches model_kwargs["attn_metadata"] = attn_metadata - if self.enable_torchair_graph_mode and not with_prefill: + if self.torchair_graph_enabled and not with_prefill: hidden_states = self.compile_model( input_ids=input_ids, positions=positions, @@ -1170,7 +1154,7 @@ def _dummy_run( with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): - if self.enable_torchair_graph_mode and not with_prefill: + if self.torchair_graph_enabled and not with_prefill: attn_metadata = self.attn_metadata_builder.build_dummy( num_reqs=num_tokens, num_actual_tokens=1) # Only mark static while compiling @@ -1262,7 +1246,7 @@ def load_model(self) -> None: m.consumed_memory / float(2**30)) # adapter torch compile with npu_backend - if self.enable_torchair_graph_mode: + if self.torchair_graph_enabled: import torchair # type: ignore from torchair import patch_for_hcom # type: ignore @@ -1339,7 +1323,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: num_blocks, 
kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype - if self.enable_torchair_graph_mode: + if self.torchair_graph_enabled: layer_kv_cache_nope = torch.zeros( kv_cache_shape[:-1] + (self.model_config.hf_text_config.kv_lora_rank, ), @@ -1417,7 +1401,7 @@ def capture_model(self) -> None: # TODO(NeverRaR): Calling graph_capture(device=self.device) in # torchair graph capture can cause some issues, so now we just # temporarily split the codepath for the two different graph patterns. - if self.enable_torchair_graph_mode: + if self.torchair_graph_enabled: torchair_graph_batch_sizes = self.torchair_graph_batch_sizes graph_num = len(torchair_graph_batch_sizes) logger.info( @@ -1449,10 +1433,7 @@ def capture_model(self) -> None: self._dummy_run(num_tokens) self._dummy_run(num_tokens) else: - logger.warning( - "Skipping NPU graph capture. Please add -O %s to use ACL graphs. " - "Or add --additional_config={'enable_graph_mode': True} to use torchair graphs", - CompilationLevel.PIECEWISE) + logger.info("Skipping NPU graph capture for eager mode.") return end_time = time.perf_counter() end_free_npu_memory = torch.npu.mem_get_info()[0] diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index 84abe039f3..7070ea148d 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -47,6 +47,7 @@ from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) +from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.device_allocator.camem import CaMemAllocator from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform @@ -75,6 +76,9 @@ def __init__(self, # Register ops when worker init. from vllm_ascend import ops # noqa: F401 + # init ascend config + init_ascend_config(vllm_config) + WorkerBase.__init__(self, vllm_config=vllm_config) # Try to import mindie_turbo to accelerate vLLM inference. try_register_lib( diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index ad7440dd19..4d01f84d38 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -42,6 +42,7 @@ from vllm.v1.worker.worker_base import WorkerBase import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import try_register_lib @@ -67,6 +68,8 @@ def __init__( from vllm_ascend import ops ops.register_dummy_fusion_op() _register_atb_extensions() + # init ascend config + init_ascend_config(vllm_config) super().__init__(vllm_config=vllm_config, local_rank=local_rank, @@ -236,7 +239,7 @@ def execute_dummy_batch(self) -> None: if runner.dp_size > 1: max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp( 1, False) - if envs_ascend.VLLM_ENABLE_MC2 or runner.enable_torchair_graph_mode: + if envs_ascend.VLLM_ENABLE_MC2 or runner.torchair_graph_enabled: if not with_prefill: num_tokens = max_num_tokens num_tokens = runner.select_torchair_padded_batch_size(num_tokens)
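For downstream deployments, the flat additional_config keys removed in this diff map onto the new nested schema roughly as follows (a sketch; key names are taken from the code above):

    # Old flat keys (removed in this diff):
    additional_config = {
        "enable_graph_mode": True,
        "use_cached_npu_graph": True,
        "torchair_graph_batch_sizes": [1, 2, 4, 8],
    }

    # New nested schema (introduced in this diff):
    additional_config = {
        "torchair_graph_config": {
            "enabled": True,
            "use_cached_graph": True,
            "graph_batch_sizes": [1, 2, 4, 8],
        },
    }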