diff --git a/.github/workflows/doc_codespell.yaml b/.github/workflows/doc_codespell.yaml index 930603c6b2..156ad71e59 100644 --- a/.github/workflows/doc_codespell.yaml +++ b/.github/workflows/doc_codespell.yaml @@ -28,6 +28,6 @@ jobs: - name: Run codespell check run: | CODESPELL_EXCLUDES=('--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**') - CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn') + CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn,rever') codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}" diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 039c0332c5..7bb6b6ac22 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -86,7 +86,7 @@ jobs: - name: Run codespell check run: | CODESPELL_EXCLUDES=('--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**') - CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn') + CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn,rever') codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}" - name: Analysing the code with ruff diff --git a/docs/source/user_guide/additional_config.md b/docs/source/user_guide/additional_config.md index e755b93796..e1b13bfccc 100644 --- a/docs/source/user_guide/additional_config.md +++ b/docs/source/user_guide/additional_config.md @@ -40,14 +40,14 @@ The details of each config option are as follows: | Name | Type | Default | Description | | ---- | ---- | ------- | ----------- | -| `enabled` | bool | `False` | Whether to enable torchair graph mode | -| `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream | -| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert | +| `enabled` | bool | `False` | Whether to enable torchair graph mode. Currently only DeepSeek series models and PanguProMoE are supported to use torchair graph mode | +| `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream. This option only takes effects on models using MLA (e.g., DeepSeek). | +| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert. This option only takes effects on DeepSeek moe models. | | `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization | | `use_cached_graph` | bool | `False` | Whether to use cached graph | | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache | | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty | -| `enable_kv_nz`| bool | `False` | Whether to enable kvcache NZ layout | +| `enable_kv_nz`| bool | `False` | Whether to enable kvcache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). 
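As a quick illustration of how the `torchair_graph_config` options documented above are consumed, here is a minimal offline-inference sketch. It passes the options through the same `additional_config` argument this file documents; the model name and option values are placeholders for illustration, and the ascend-scheduler setting mirrors the torchair e2e fixtures added later in this PR rather than being a hard requirement stated in the table.

```python
from vllm import LLM, SamplingParams

# Minimal sketch: enable torchair (GE) graph mode via additional_config.
# "deepseek-ai/DeepSeek-V2-Lite" is only a placeholder model; per the table
# above, torchair graph mode currently targets DeepSeek-series models and
# PanguProMoE.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    tensor_parallel_size=4,
    additional_config={
        "torchair_graph_config": {
            "enabled": True,        # turn on torchair graph mode
            "enable_kv_nz": False,  # NZ kv-cache layout; only affects MLA models
        },
        # torchair currently runs without chunked prefill, so the ascend
        # scheduler is enabled alongside it, as in the e2e test fixtures
        # added by this PR.
        "ascend_scheduler_config": {"enabled": True},
    },
)

print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=5)))
```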
| **ascend_scheduler_config** diff --git a/docs/source/user_guide/graph_mode.md b/docs/source/user_guide/graph_mode.md index 77e91dd62a..a390bdbcc9 100644 --- a/docs/source/user_guide/graph_mode.md +++ b/docs/source/user_guide/graph_mode.md @@ -12,7 +12,7 @@ From v0.9.1rc1 with V1 Engine, vLLM Ascend will run models in graph mode by defa There are two kinds for graph mode supported by vLLM Ascend: - **ACLGraph**: This is the default graph mode supported by vLLM Ascend. In v0.9.1rc1, only Qwen series models are well tested. -- **TorchAirGraph**: This is the GE graph mode. In v0.9.1rc1, only DeepSeek series models are supported. +- **TorchAirGraph**: This is the GE graph mode. In v0.9.1rc1, only DeepSeek series models are supported. In v0.9.1rc2, we also support PanguProMoe with torchair. ## Using ACLGraph ACLGraph is enabled by default. Take Qwen series models as an example, just set to use V1 Engine is enough. diff --git a/format.sh b/format.sh index de083c1719..32569e2c78 100755 --- a/format.sh +++ b/format.sh @@ -145,7 +145,7 @@ CODESPELL_EXCLUDES=( ) CODESPELL_IGNORE_WORDS=( - '-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn' + '-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn,rever' ) # check spelling of specified files diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index 04f4488d66..341c5bfc45 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -165,3 +165,20 @@ def test_models_distributed_DeepSeek_W8A8(): quantization="ascend", ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) + + +def test_models_distributed_pangu(): + example_prompts = [ + "Hello, my name is", + ] + max_tokens = 5 + + with VllmRunner( + snapshot_download("vllm-ascend/pangu-pro-moe-pruing"), + max_model_len=8192, + enforce_eager=True, + dtype="auto", + tensor_parallel_size=4, + distributed_executor_backend="mp", + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/e2e/multicard/test_torchair_graph_mode.py b/tests/e2e/multicard/test_torchair_graph_mode.py index a0ae86085f..ce628f9d35 100644 --- a/tests/e2e/multicard/test_torchair_graph_mode.py +++ b/tests/e2e/multicard/test_torchair_graph_mode.py @@ -99,3 +99,63 @@ def test_e2e_deepseekv3_with_torchair_ms_mla(): }, } _deepseek_torchair_test_fixture(additional_config) + + +def _pangu_torchair_test_fixture( + additional_config: Dict, + *, + tensor_parallel_size=4, +): + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # torchair is only work without chunked-prefill now + kwargs = { + "ascend_scheduler_config": { + "enabled": True, + }, + "refresh": True, + } + additional_config.update(**kwargs) + + with VllmRunner( + "vllm-ascend/pangu-pro-moe-pruing", + dtype="half", + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend="mp", + enforce_eager=False, + additional_config=additional_config, + ) as vllm_model: + # use greedy sampler to make sure the generated results are fix + vllm_output = vllm_model.generate_greedy(example_prompts, 5) + + # NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE + # with 2 hidden layers, thus the golden results seems inaccurate. + # This will only change if accuracy changes with the official weights + # of PanguProMoE. 
+ golden_results = [ + 'Hello, my name is Remempondeprecatedmiot忱', + 'The president of the United States is Remem下的一个 rever ceremoni Segnali', + 'The capital of France is Rememvoud administrativ Remem投', + 'The future of AI isotope Segnali Zoeken精细化 supus', + ] + + assert len(golden_results) == len(vllm_output) + for i in range(len(vllm_output)): + assert golden_results[i] == vllm_output[i][1] + print(f"Generated text: {vllm_output[i][1]!r}") + + +@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", + reason="torchair graph is not supported on v0") +def test_e2e_pangu_with_torchair(): + additional_config = { + "torchair_graph_config": { + "enabled": True, + }, + } + _pangu_torchair_test_fixture(additional_config) diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py new file mode 100644 index 0000000000..91c2ad40df --- /dev/null +++ b/tests/ut/ops/test_rotary_embedding.py @@ -0,0 +1,315 @@ +import math +import unittest +from unittest.mock import MagicMock, patch + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled, + native_rope_deepseek_forward, + rope_forward_oot, rotate_half, + yarn_find_correction_dim, + yarn_get_mscale) + + +class TestCustomRotaryEmbeddingEnabled(unittest.TestCase): + + def setUp(self): + # Common setup for tests + self.positions = torch.tensor([1, 2, 3]) + self.query = torch.randn(3, 4, dtype=torch.float16) + self.key = torch.randn(3, 4, dtype=torch.float16) + self.head_size = 32 + self.cos_sin_cache = torch.randn(3, 4) + + # Mock self object for rope_forward_oot + self.mock_self = MagicMock() + self.mock_self.head_size = self.head_size + self.mock_self.cos_sin_cache = self.cos_sin_cache + self.mock_self.is_neox_style = True + self.mock_self.forward_native.return_value = (self.query, self.key) + + def test_custom_rotary_embedding_enabled(self): + # Test when all conditions are True + with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op', + return_value=True): + result = custom_rotary_embedding_enabled(self.query, True, + self.head_size) + self.assertTrue(result) + + # Test when dtype is not float16 + with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op', + return_value=True): + query = self.query.to(torch.float32) + result = custom_rotary_embedding_enabled(query, True, + self.head_size) + self.assertFalse(result) + + # Test when neox_style is False + with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op', + return_value=True): + result = custom_rotary_embedding_enabled(self.query, False, + self.head_size) + self.assertFalse(result) + + # Test when head_size is not divisible by 32 + with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op', + return_value=True): + result = custom_rotary_embedding_enabled(self.query, True, + self.head_size + 1) + self.assertFalse(result) + + # Test when custom op is disabled + with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op', + return_value=False): + result = custom_rotary_embedding_enabled(self.query, True, + self.head_size) + self.assertFalse(result) + + +class TestRopeForwardOot(unittest.TestCase): + + def setUp(self): + # Common setup for tests + self.positions = torch.tensor([1, 2, 3]) + self.query = torch.randn(3, 4, dtype=torch.float16) + self.key = torch.randn(3, 4, dtype=torch.float16) + self.head_size = 32 + self.cos_sin_cache = torch.randn(3, 4) + + # Mock self object for rope_forward_oot + self.mock_self = MagicMock() + self.mock_self.head_size = self.head_size + 
self.mock_self.cos_sin_cache = self.cos_sin_cache + self.mock_self.is_neox_style = True + self.mock_self.forward_native.return_value = (self.query, self.key) + + @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config') + def test_rope_forward_oot_torchair_enabled_base(self, + mock_get_ascend_config): + # Setup mock for torchair enabled + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = True + mock_get_ascend_config.return_value = mock_config + + result_q, result_k = rope_forward_oot(self.mock_self, self.positions, + self.query, self.key) + + self.mock_self.forward_native.assert_called_once_with( + self.positions, self.query, self.key, None) + self.assertTrue(torch.equal(result_q, self.query)) + self.assertTrue(torch.equal(result_k, self.key)) + + @patch('torch.ops._C') + @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config') + @patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False) + @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled', + return_value=True) + @patch('torch.ops._npu_rotary_embedding') + def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding, + mock_custom_enabled, mock_is_310p, + mock_get_ascend_config, mock__c): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + mock_get_ascend_config.return_value = mock_config + + # Setup mock for custom kernel path + + mock__c.rotary_embedding.return_value = self.query, self.key + + result_q, result_k = rope_forward_oot(self.mock_self, self.positions, + self.query, self.key) + + self.assertEqual(result_q.shape, self.query.shape) + self.assertEqual(result_k.shape, self.key.shape) + + @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config') + @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled', + return_value=False) + @patch('torch_npu._npu_rotary_embedding') + def test_rope_forward_oot_contiguous(self, mock_npu_rotary, + mock_custom_enabled, + mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + mock_get_ascend_config.return_value = mock_config + + # Test contiguous path when custom is disabled + non_contig_query = self.query.transpose(0, 1) + non_contig_key = self.key.transpose(0, 1) + + result_q, result_k = rope_forward_oot(self.mock_self, self.positions, + non_contig_query, non_contig_key) + + mock_npu_rotary.assert_called_once() + self.assertEqual(result_q.shape, non_contig_query.shape) + self.assertEqual(result_k.shape, non_contig_key.shape) + + @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config') + def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + mock_get_ascend_config.return_value = mock_config + + # Test that NotImplementedError is raised when offsets is provided + offsets = torch.tensor([1, 2, 3]) + with self.assertRaises(NotImplementedError): + rope_forward_oot(self.mock_self, self.positions, self.query, + self.key, offsets) + + @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config') + @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled', + return_value=False) + @patch('torch_npu._npu_rotary_embedding') + def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary, + mock_custom_enabled, + mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + mock_get_ascend_config.return_value = mock_config + + # Test neox_style override + result_q, result_k = 
rope_forward_oot(self.mock_self, + self.positions, + self.query, + self.key, + is_neox_style_override=False) + + # Check that neox_style=False was passed to the NPU function + args, kwargs = mock_npu_rotary.call_args + self.assertFalse(args[-1]) + + +class MockRopeModule: + + def __init__(self, max_seq_len=2048, is_neox_style=True): + self.max_seq_len = max_seq_len + self.is_neox_style = is_neox_style + self.cos_cached = None + self.sin_cached = None + self.rotary_dim = 1 + self.base = 1 + + +class TestNativeRopeDeepseekForward(TestBase): + + @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot') + def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot): + module = MockRopeModule() + positions = torch.tensor([1, 2, 3]) + query = torch.randn(1, 8, 128) + key = torch.randn(1, 8, 128) + + mock_rope_forward_oot.return_value = (query, key) + + q_pe, k_pe = native_rope_deepseek_forward(module, positions, query, + key) + + assert q_pe.shape == query.shape + assert k_pe.shape == key.shape + + @patch('vllm_ascend.ops.rotary_embedding._set_cos_sin_cache') + @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot') + def test_native_rope_deepseek_forward_cache_handling( + self, mock_rope_forward_oot, mock_set_cache): + # Test cache situation is true + module = MockRopeModule(max_seq_len=1024) + positions = torch.tensor([1, 2, 3]) + query = torch.randn(1, 8, 128) + key = torch.randn(1, 8, 128) + + mock_rope_forward_oot.return_value = (query, key) + + q_pe, k_pe = native_rope_deepseek_forward(module, + positions, + query, + key, + max_seq_len=2048) + + assert q_pe.shape == query.shape + assert k_pe.shape == key.shape + + @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot') + def test_native_rope_deepseek_forward_key_reshaping( + self, mock_rope_forward_oot): + module = MockRopeModule() + positions = torch.tensor([1, 2, 3]) + query = torch.randn(1, 8, 128) + key = torch.randn(1, 128) + + mock_rope_forward_oot.return_value = (query, key) + + q_pe, k_pe = native_rope_deepseek_forward(module, positions, query, + key) + + assert q_pe.shape == query.shape + assert k_pe.shape == (1, 128) + + @patch('vllm_ascend.ops.rotary_embedding.rope_forward_oot') + def test_native_rope_deepseek_forward_non_neox_style( + self, mock_rope_forward_oot): + module = MockRopeModule(is_neox_style=False) + positions = torch.tensor([1, 2, 3]) + query = torch.randn(1, 8, 128) + key = torch.randn(1, 8, 128) + + mock_rope_forward_oot.return_value = (query, key) + + q_pe, k_pe = native_rope_deepseek_forward(module, positions, query, + key) + + assert q_pe.shape == query.shape + assert k_pe.shape == key.shape + + +class TestRotateHalf(unittest.TestCase): + + def test_rotate_half_even_dim(self): + # Test with even dimension + x = torch.tensor([1.0, 2.0, 3.0, 4.0]) + expected = torch.tensor([-3.0, -4.0, 1.0, 2.0]) + result = rotate_half(x) + self.assertTrue(torch.allclose(result, expected)) + + +class TestYarnFindCorrectionDim(unittest.TestCase): + + def test_basic_case(self): + # Test with standard values + num_rotations = 100 + dim = 512 + base = 10000 + max_position_embeddings = 2048 + + result = yarn_find_correction_dim(num_rotations, dim, base, + max_position_embeddings) + + # Calculate expected value manually + expected = (dim * torch.log( + torch.tensor(max_position_embeddings) / + (num_rotations * 2 * torch.pi))) / (2 * + torch.log(torch.tensor(base))) + + self.assertTrue(torch.allclose(result, expected)) + + +class TestYarnGetMscale(unittest.TestCase): + + def test_scale_less_than_or_equal_1(self): + 
self.assertEqual(yarn_get_mscale(scale=0.5), 1.0) + self.assertEqual(yarn_get_mscale(scale=1.0), 1.0) + self.assertEqual(yarn_get_mscale(scale=0.999), 1.0) + + def test_scale_greater_than_1(self): + test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)), + (10.0, 1.0, 1.0 + 0.1 * math.log(10.0)), + (5.0, 2.0, 1.0 + 0.2 * math.log(5.0)), + (math.e, 1.0, 1.0 + 0.1)] + + for scale, mscale, expected in test_cases: + result = yarn_get_mscale(scale, mscale) + self.assertAlmostEqual( + result, + expected, + places=6, + msg=f"Failed for scale={scale}, mscale={mscale}") diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py index 21e0626545..5ce9c34489 100644 --- a/tests/ut/quantization/test_w8a8.py +++ b/tests/ut/quantization/test_w8a8.py @@ -316,6 +316,58 @@ def test_apply_decode_only(self, mock_quant, mock_scatter): self.assertEqual(mock_scatter.call_count, 2) self.assertTrue(torch.equal(result, expected_output)) + @patch('torch_npu.npu_scatter_nd_update_') + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") + def test_apply_attn_metadata_without_decode(self, mock_quant, + mock_scatter): + + num_tokens = 2 + query = torch.randn(num_tokens, + self.layer.num_heads * self.layer.head_size) + key = torch.randn(num_tokens, + self.layer.num_kv_heads * self.layer.head_size) + value = torch.randn(num_tokens, + self.layer.num_kv_heads * self.layer.head_size) + output = torch.empty_like(query) + + attn_metadata = MagicMock(spec=[ + 'attn_state', 'seq_lens', 'block_tables', 'slot_mapping', + 'attn_mask' + ]) + attn_metadata.attn_state = AscendAttentionState.DecodeOnly + attn_metadata.seq_lens = [10, 10] + attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]]) + attn_metadata.slot_mapping = torch.tensor([0, 1]) + attn_metadata.attn_mask = None + + block_size = 16 + key_cache = torch.empty(2, block_size, self.layer.num_kv_heads, + self.layer.head_size) + value_cache = torch.empty(2, block_size, self.layer.num_kv_heads, + self.layer.head_size) + kv_cache = (key_cache, value_cache) + + mock_quant.side_effect = [key, value] + + self.layer.key_antiquant_scale.data = torch.ones( + self.layer.num_kv_heads * self.layer.head_size) + self.layer.value_antiquant_scale.data = torch.ones( + self.layer.num_kv_heads * self.layer.head_size) + self.method.process_weights_after_loading(self.layer) + + expected_output = torch.randn( + num_tokens, self.layer.num_heads * self.layer.head_size) + with patch('torch_npu.npu_incre_flash_attention', + return_value=expected_output): + result = self.method.apply(self.layer, query, key, value, kv_cache, + attn_metadata, + self.attention_type.DECODER, 1.0, + output) + + self.assertEqual(mock_quant.call_count, 2) + self.assertEqual(mock_scatter.call_count, 2) + self.assertTrue(torch.equal(result, expected_output)) + @patch("vllm_ascend.quantization.w8a8.quant_per_tensor") @patch('torch_npu._npu_flash_attention') def test_apply_prefill_no_cache(self, mock_flash, mock_quant): diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index 5ec4dd72cc..6146960b91 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -6,6 +6,7 @@ from vllm.config import ModelConfig, VllmConfig from vllm_ascend.ascend_config import (check_ascend_config, + check_torchair_supported, clear_ascend_config, get_ascend_config, init_ascend_config) @@ -242,3 +243,10 @@ def test_check_ascend_config_wrong_case(self): test_vllm_config.model_config = fake_model_config init_ascend_config(test_vllm_config) 
check_ascend_config(test_vllm_config, False) + + def test_check_torchair_supported(self): + test_cases = [('deepseek_v3', True), ('PanguProMoE', True), + ('qwen', False), ('llama', False)] + for model_type, expected_output in test_cases: + self.assertEqual(check_torchair_supported(model_type), + expected_output) diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index 77aa4f3280..c09964a745 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -292,23 +292,6 @@ def test_check_and_update_config_no_model_config_warning( self.platform.check_and_update_config(self.mock_vllm_config) self.assertTrue("Model config is missing" in cm.output[0]) - @patch("vllm_ascend.utils.is_310p", return_value=False) - @patch("vllm_ascend.ascend_config.check_ascend_config") - @patch("vllm_ascend.ascend_config.init_ascend_config") - @patch("vllm.envs.VLLM_MLA_DISABLE", True) - def test_check_and_update_config_torchair_graph_disabled_when_mla_disabled( - self, mock_init_ascend, mock_check_ascend, mock_is_310p): - self.mock_ascend_config.torchair_graph_config.enabled = True - mock_init_ascend.return_value = self.mock_ascend_config - - from vllm_ascend import platform - - importlib.reload(platform) - - self.platform.check_and_update_config(self.mock_vllm_config) - - self.assertFalse(self.mock_ascend_config.torchair_graph_config.enabled) - @patch("vllm_ascend.utils.is_310p", return_value=False) @patch("vllm_ascend.ascend_config.check_ascend_config") @patch("vllm_ascend.ascend_config.init_ascend_config") @@ -502,7 +485,13 @@ def test_check_and_update_config_ascend_scheduler_config( self.platform.check_and_update_config(self.mock_vllm_config) mock_scheduler.initialize_from_config.assert_called_once() - def test_get_attn_backend_cls_use_v1_and_mla(self): + @patch('vllm_ascend.platform.get_ascend_config') + def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + + mock_get_ascend_config.return_value = mock_config + result = self.platform.get_attn_backend_cls( selected_backend="ascend", head_size=64, @@ -515,7 +504,35 @@ def test_get_attn_backend_cls_use_v1_and_mla(self): self.assertEqual(result, "vllm_ascend.attention.mla_v1.AscendMLABackend") - def test_get_attn_backend_cls_use_v1_only(self): + @patch('vllm_ascend.platform.get_ascend_config') + def test_get_attn_backend_cls_use_v1_and_torchair(self, + mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = True + + mock_get_ascend_config.return_value = mock_config + + result = self.platform.get_attn_backend_cls( + selected_backend="ascend", + head_size=64, + dtype="float16", + kv_cache_dtype="float16", + block_size=64, + use_v1=True, + use_mla=False, + ) + self.assertEqual( + result, + "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend" + ) + + @patch('vllm_ascend.platform.get_ascend_config') + def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + + mock_get_ascend_config.return_value = mock_config + result = self.platform.get_attn_backend_cls( selected_backend="ascend", head_size=64, @@ -529,7 +546,13 @@ def test_get_attn_backend_cls_use_v1_only(self): result, "vllm_ascend.attention.attention_v1.AscendAttentionBackend") - def test_get_attn_backend_cls_use_mla_only(self): + @patch('vllm_ascend.platform.get_ascend_config') + def 
test_get_attn_backend_cls_use_mla_only(self, mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + + mock_get_ascend_config.return_value = mock_config + result = self.platform.get_attn_backend_cls( selected_backend="ascend", head_size=64, @@ -543,7 +566,13 @@ def test_get_attn_backend_cls_use_mla_only(self): result, "vllm_ascend.attention.attention.AscendMLAAttentionBackend") - def test_get_attn_backend_cls_default_case(self): + @patch('vllm_ascend.platform.get_ascend_config') + def test_get_attn_backend_cls_default_case(self, mock_get_ascend_config): + mock_config = MagicMock() + mock_config.torchair_graph_config.enabled = False + + mock_get_ascend_config.return_value = mock_config + result = self.platform.get_attn_backend_cls( selected_backend="ascend", head_size=64, diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index d8b87c6952..c5c4d125d0 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -18,6 +18,15 @@ import vllm.envs as envs from vllm.logger import logger +TORCHAIR_MODEL_LIST = ["deepseek", "pangu"] + + +def check_torchair_supported(model_type: str): + for supported_model in TORCHAIR_MODEL_LIST: + if supported_model in model_type.lower(): + return True + return False + class AscendConfig: """ @@ -141,10 +150,10 @@ def check_ascend_config(vllm_config, enforce_eager): # torchair_graph is supported for deepseek model only currently. if vllm_config.model_config: model_type = vllm_config.model_config.hf_config.model_type - if "deepseek" not in model_type: + if not check_torchair_supported(model_type): raise NotImplementedError( - "Torchair graph mode only works with deepseek model." - ) + "Torchair graph mode only works with following model types:" + f"{TORCHAIR_MODEL_LIST}.") # aclgraph case else: # aclgraph doesn't work with deepseek model and only qwen model is well tested. diff --git a/vllm_ascend/attention/attention_v1_torchair.py b/vllm_ascend/attention/attention_v1_torchair.py new file mode 100644 index 0000000000..ef810ba5d3 --- /dev/null +++ b/vllm_ascend/attention/attention_v1_torchair.py @@ -0,0 +1,506 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import numpy as np +import torch +import torch_npu +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, AttentionType) +from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.worker.gpu_input_batch import InputBatch + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, + nd_to_nz_2d) + + +class AscendAttentionTorchairBackend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ASCEND" + + @staticmethod + def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]: + return AscendAttentionTorchairBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["AscendTorchairMetadata"]: + return AscendTorchairMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]: + return AscendAttentionTorchairMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads * head_size) + + @staticmethod + def get_bsh_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads * head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: List[torch.Tensor], + dst_kv_cache: List[torch.Tensor], + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1] + dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1] + src_indices = src_to_dst[:, 0] + dst_indices = src_to_dst[:, 1] + + dst_key_cache[dst_indices] = src_key_cache[src_indices].to( + dst_key_cache.device) + dst_value_cache[dst_indices] = src_value_cache[src_indices].to( + dst_key_cache.device) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + src_indices = src_to_dists[:, 0] + dst_indices = src_to_dists[:, 1] + + for kv_cache in kv_caches: + key_caches = kv_cache[0] + value_caches = kv_cache[1] + key_caches[dst_indices] = key_caches[src_indices] + value_caches[dst_indices] = value_caches[src_indices] + + +@dataclass +class AscendDecodeMetadata: + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + max_seq_lens: int + seq_lens_list: list[int] + attn_mask: Optional[torch.Tensor] = None + + +@dataclass +class AscendTorchairMetadata: + num_actual_tokens: int # Number of tokens excluding padding. + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + block_tables: torch.Tensor + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. 
+ query_start_loc: torch.Tensor + query_lens: torch.Tensor + seq_lens: torch.Tensor + + # max value of number of tokens across dp group + max_num_tokens_across_dp: int = 0 + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] = None + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor = None + # TODO: Indicates whether there are only prefill requests. + # FlashAttention can be used when there are only prefill requests. + # FlashAttention has better performance than PageAtttention, + # but it does not support decode requests. + is_only_prefill: bool = False + # Current state of this attention run. + attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill + attn_mask: Optional[torch.Tensor] = None + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. + + with_prefill_across_dp: bool = False + + decode: Optional[AscendDecodeMetadata] = None + + +class AscendAttentionTorchairMetadataBuilder: + + def __init__(self, runner): + self.runner = runner + self.torchair_graph_enabled = get_ascend_config( + ).torchair_graph_config.enabled + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + return False + + def _get_graph_runner_block_tables( + self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: + + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + if isinstance(self.runner.graph_block_tables, np.ndarray): + graph_block_tables = torch.zeros((max_batch_size, max_blocks), + dtype=block_tables.dtype, + device=block_tables.device) + else: + graph_block_tables = self.runner.graph_block_tables.to( + device=block_tables.device, dtype=block_tables.dtype) + + num_blocks = block_tables.size(1) + if num_blocks <= max_blocks: + graph_block_tables[:num_seqs, : + num_blocks] = block_tables[:num_seqs, : + num_blocks] + else: + graph_block_tables[:num_seqs, : + max_blocks] = block_tables[:num_seqs, : + max_blocks] + + return graph_block_tables[:num_seqs, :max_blocks] + + def build_dummy(self, num_reqs: int, + num_actual_tokens: int) -> AscendTorchairMetadata: + device = self.runner.device + _, max_blocks = self.runner.graph_block_tables.shape + block_table = torch.zeros((num_reqs, max_blocks), + dtype=torch.int32, + device=device) + block_table = self._get_graph_runner_block_tables( + num_reqs, block_table) + seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device) + input_positions = torch.zeros(num_reqs, + dtype=torch.int32, + device=device).long() + slot_mapping = torch.full((num_reqs, ), + PAD_SLOT_ID, + dtype=torch.int32, + device=device) + query_start_loc = torch.full((num_reqs, ), + -1, + dtype=torch.int32, + device=device) + + decode_metadata = AscendDecodeMetadata(input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens.tolist(), + max_seq_lens=1) + + attn_metadata = AscendTorchairMetadata( + num_actual_tokens=num_actual_tokens, + block_tables=block_table, + query_lens=0, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + slot_mapping=slot_mapping, + attn_state=AscendAttentionState.DecodeOnly, + max_num_tokens_across_dp=num_reqs, + decode=decode_metadata) + return attn_metadata + + def 
build(self, + num_reqs, + num_actual_tokens, + max_query_len, + common_prefix_len, + graph_pad_size: int = -1, + max_num_tokens_across_dp: int = 0, + with_prefill_across_dp: bool = False): + + device = self.runner.device + + block_table = self.runner.input_batch.block_table[0].get_device_tensor( + ) + block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = ( + block_table[:num_reqs]) + + query_lens = self.runner.query_lens + seq_lens = self.runner.seq_lens_cpu[:num_reqs] + slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + self.runner.device, non_blocking=True) + attn_mask = self.runner.attn_mask + + attn_state = self.runner.attn_state + if is_310p() and attn_state == AscendAttentionState.PrefillNoCache: + mask_nz = nd_to_nz_2d(attn_mask) + attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29) + + query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] + query_start_loc = query_start_loc_cpu.to(self.runner.device, + non_blocking=True) + input_positions = self.runner.positions_cpu[:num_actual_tokens].to( + device, non_blocking=True).long() + + decode_metadata = None + use_torchair_graph = graph_pad_size > -1 + if self.runner.attn_state in [ + AscendAttentionState.DecodeOnly, + ]: + max_seq_lens = seq_lens.max().item() + num_seqs = len(seq_lens) + if use_torchair_graph and self.runner.attn_state in [ + AscendAttentionState.DecodeOnly, + ]: + max_num_tokens_across_dp += graph_pad_size + pad_value = 1 + padded_seq_lens = seq_lens.tolist() + [pad_value + ] * graph_pad_size + + seq_lens = torch.from_numpy( + np.array(padded_seq_lens).astype(np.int32)) + padding = torch.full((graph_pad_size, ), + PAD_SLOT_ID, + dtype=slot_mapping.dtype, + device=slot_mapping.device) + slot_mapping = torch.cat([slot_mapping, padding]) + block_table_padding = torch.zeros( + (graph_pad_size, ) + block_table.shape[1:], + dtype=block_table.dtype, + device=block_table.device) + block_table = torch.cat([block_table, block_table_padding], + dim=0) + block_table = self._get_graph_runner_block_tables( + num_seqs + graph_pad_size, block_table) + padding_0 = torch.zeros(graph_pad_size, + dtype=input_positions.dtype, + device=input_positions.device) + input_positions = torch.cat([input_positions, padding_0]) + + decode_metadata = AscendDecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + seq_lens_list=seq_lens.tolist(), + max_seq_lens=max_seq_lens, + attn_mask=None) + + attn_metadata = AscendTorchairMetadata( + decode=decode_metadata, + num_actual_tokens=num_actual_tokens, + block_tables=block_table, + query_start_loc=query_start_loc, + query_lens=query_lens, + seq_lens=seq_lens, + max_query_len=max_query_len, + slot_mapping=slot_mapping, + attn_mask=attn_mask, + attn_state=attn_state, + max_num_tokens_across_dp=max_num_tokens_across_dp, + with_prefill_across_dp=with_prefill_across_dp) + return attn_metadata + + +class AscendAttentionTorchairBackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None 
else num_kv_heads + self.hidden_size = self.num_heads * self.head_size + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = sliding_window + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, + dtype=torch.float32, + device="npu") + self.alibi_slopes = alibi_slopes + self.attn_type = attn_type + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.key_cache = None + self.value_cache = None + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AscendTorchairMetadata, + output: Optional[torch.Tensor] = None, + trace_flag: bool = False, + ) -> torch.Tensor: + """Forward pass with Ascend attention. + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache: shape = [2, num_blocks, block_size, + num_kv_heads, head_size] + key_cache = [num_blocks, block_size, + num_kv_heads, head_size] + value_cache = [num_blocks, block_size, + num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [batch_size * seq_len, num_heads, head_size] + """ + num_tokens = query.shape[0] + use_kv_cache_quant = kv_cache is not None and kv_cache[0].numel( + ) > 0 and kv_cache[0].dtype == torch.int8 + if output is None: + output = torch.empty(num_tokens, + self.num_heads, + self.head_size, + dtype=query.dtype, + device=query.device) + + if hasattr(layer, 'quant_method') and use_kv_cache_quant: + output = layer.quant_method.apply(layer, query, key, value, + kv_cache, attn_metadata, + self.attn_type, self.scale, + output) + return output.view(num_tokens, self.hidden_size) + + if attn_metadata is None: + return output.view(num_tokens, self.hidden_size) + + output = output.view(-1, self.num_heads, self.head_size) + + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + attn_type = self.attn_type + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "AscendAttentionTorchairBackendImpl") + + if kv_cache is not None and kv_cache[0].numel() > 0: + key_cache, value_cache = kv_cache[0], kv_cache[1] + slots = attn_metadata.slot_mapping + + block_size = key_cache.shape[1] + slots_indices = slots.reshape(-1, 1) + block_indices = slots_indices // block_size + slots_indices = slots_indices % block_size + indices = torch.cat((block_indices, slots_indices), dim=1) + torch_npu.npu_scatter_nd_update_(key_cache, indices, key) + torch_npu.npu_scatter_nd_update_(value_cache, indices, value) + + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + assert attn_metadata is not None + assert attn_metadata.attn_mask is not None + mask = attn_metadata.attn_mask + + # View q k v to BSH. 
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if is_310p(): + # align q k v output tensors + query = aligned_16(query) + key = aligned_16(key) + value = aligned_16(value) + output = aligned_16(output) + + # do reformat in case of broadcasted tensors + mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1) + mask = torch_npu.npu_format_cast(mask.contiguous(), + ACL_FORMAT_FRACTAL_NZ) + + torch_npu._npu_flash_attention(query=query, + key=key, + value=value, + mask=mask, + seq_len=attn_metadata.seq_lens, + scale_value=self.scale, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + out=output) + output = output[:num_tokens, :, :] + elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: + assert attn_metadata is not None + assert attn_metadata.attn_mask is not None + compress_mask = attn_metadata.attn_mask + torch_npu._npu_flash_attention_qlens( + query=query, + key_cache=self.key_cache, + value_cache=self.value_cache, + block_table=attn_metadata.block_tables, + mask=compress_mask, + seq_len=attn_metadata.query_lens, + context_lens=attn_metadata.seq_lens, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + out=output) + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + decode_meta = attn_metadata.decode + assert decode_meta is not None + seq_lens = decode_meta.seq_lens_list + block_table = decode_meta.block_table + block_size = key_cache.shape[1] + query = query.view(num_tokens, 1, + self.num_heads * self.head_size).contiguous() + output = torch_npu.npu_incre_flash_attention( + query, + key_cache, + value_cache, + num_key_value_heads=self.num_kv_heads, + num_heads=self.num_heads, + actual_seq_lengths=seq_lens, + scale_value=self.scale, + block_table=block_table, + input_layout='BSH', + block_size=block_size) + else: + raise NotImplementedError( + "Torchair graph mode with non-MLA attention backend is still experimental." + "v1 scheduler(chunked prefill) is not supported at this moment. Please" + "setting 'ascend_scheduler_config':{'enabled':true} in additional_config" + "to use ascend scheduler.") + + return output.view(num_tokens, self.hidden_size) diff --git a/vllm_ascend/models/pangu_moe.py b/vllm_ascend/models/pangu_moe.py index e01e409989..609c86f361 100644 --- a/vllm_ascend/models/pangu_moe.py +++ b/vllm_ascend/models/pangu_moe.py @@ -20,6 +20,7 @@ import torch import torch.distributed as dist import torch.nn.functional as F +import torch_npu from torch import nn from torch.nn import Parameter from transformers import PretrainedConfig @@ -56,8 +57,9 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors +from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_ep_group -from vllm_ascend.utils import is_310p +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p logger = init_logger(__name__) @@ -498,8 +500,8 @@ def forward( global _ROUTER_SCALE _ROUTER_SCALE = self.router_scale if not use_h2p(): - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + final_hidden_states = self.experts.forward_impl( + hidden_states=hidden_states, router_logits=router_logits) else: # TODO: when using h2p, we have to skip communication in vLLM # native FusedMoE. 
here we need to design a better FusedMoE @@ -608,6 +610,9 @@ def __init__( prefix=f"{prefix}.attn", ) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + def forward( self, positions: torch.Tensor, @@ -618,7 +623,19 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) + if self.torchair_graph_enabled: + forward_kwargs = {'trace_flag': False} + output_shape = q.shape + attn_output = torch.empty(output_shape, + dtype=q.dtype, + device=q.device) + forward_kwargs['output'] = attn_output + attn_output = self.attn.impl.forward(self.attn, q, k, v, kv_cache, + attn_metadata, + **forward_kwargs) + else: + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) return output @@ -1097,4 +1114,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) + if is_310p() and "head" in name: + # on 300I Duo platform, ACL_FORMAT_FRACTAL_NZ is much more preferred than + # ACL_FORMAT_FRACTAL_ND by matmul operation. Since lmhead is also implemented + # by linear, we manually cast the format here. + param.data = torch_npu.npu_format_cast(param.data, + ACL_FORMAT_FRACTAL_NZ) return loaded_params diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 27226798af..3dd91ea63f 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) +from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.utils import enable_custom_op, is_310p @@ -38,6 +39,14 @@ def rope_forward_oot( offsets: Optional[torch.Tensor] = None, is_neox_style_override: Optional[bool] = None ) -> Tuple[torch.Tensor, torch.Tensor]: + if get_ascend_config().torchair_graph_config.enabled: + return self.forward_native( + positions, + query, + key, + offsets, + ) + import torch_npu query_shape, key_shape = query.shape, key.shape if self.cos_sin_cache.device != query.device: diff --git a/vllm_ascend/patch/platform/patch_common/patch_distributed.py b/vllm_ascend/patch/platform/patch_common/patch_distributed.py index 91f43a3a8b..d244016076 100644 --- a/vllm_ascend/patch/platform/patch_common/patch_distributed.py +++ b/vllm_ascend/patch/platform/patch_common/patch_distributed.py @@ -132,19 +132,6 @@ def all_reduce( torch.distributed.distributed_c10d.all_reduce = all_reduce_wrapper_310p( torch.distributed.distributed_c10d.all_reduce) - def reduce_scatter_310p(output_tensor, input_tensor, group=None): - rank = torch.distributed.get_rank(group) - world_size = torch.distributed.get_world_size(group) - torch.distributed.all_reduce(input_tensor, - torch.distributed.ReduceOp.SUM, - group, - async_op=False) - interval = input_tensor.shape[0] // world_size - output_tensor[:] = input_tensor[rank * interval:(rank + 1) * interval] - - torch.distributed._reduce_scatter_base = reduce_scatter_310p - torch.distributed.distributed_c10d._reduce_scatter_base = reduce_scatter_310p - if is_310p(): communication_adaptation_310p() diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 4c92abffb5..07fb07fcb6 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -27,7 +27,8 @@ from vllm.logger import logger from vllm.platforms 
import Platform, PlatformEnum -from vllm_ascend.ascend_config import check_ascend_config, init_ascend_config +from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config, + init_ascend_config) from vllm_ascend.utils import (ASCEND_QUATIZATION_METHOD, is_310p, update_aclgraph_sizes) @@ -154,14 +155,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: else: enforce_eager = getattr(model_config, "enforce_eager", False) - if ascend_config.torchair_graph_config.enabled and envs.VLLM_MLA_DISABLE: - # torchair_graph is not supported for V1 without mla currently. - logger.warning( - "Torchair graph mode is still experimental and not supported for V1 without mla currently, " - "Fallback to eager mode.") - ascend_config.torchair_graph_config.enabled = False - enforce_eager = True - check_ascend_config(vllm_config, enforce_eager) if enforce_eager or compilation_config.level == CompilationLevel.NO_COMPILATION: @@ -229,6 +222,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla): if use_v1 and use_mla: return "vllm_ascend.attention.mla_v1.AscendMLABackend" + use_torchair = get_ascend_config().torchair_graph_config.enabled + if use_v1 and use_torchair: + return "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend" if use_v1: return "vllm_ascend.attention.attention_v1.AscendAttentionBackend" if use_mla: diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index 074615040c..853fe6e63f 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -373,10 +373,12 @@ def apply(self, layer, query, key, value, kv_cache, attn_metadata, "implemented for " "PrefillCacheHit") elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # changed attn_metadata.attn_state == AscendAttentionState.DecodeOnly - # torch_air - # decode_meta = attn_metadata.decode - # seq_lens = decode_meta.seq_lens_list - seq_lens = attn_metadata.seq_lens + if hasattr(attn_metadata, "decode"): + # torch_air + decode_meta = attn_metadata.decode + seq_lens = decode_meta.seq_lens_list + else: + seq_lens = attn_metadata.seq_lens block_size = key_cache.shape[1] query = query.view(num_tokens, 1, layer.num_heads * layer.head_size).contiguous() # changed diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 78e05ed426..952a22734a 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2035,9 +2035,19 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int): from torchair import patch_for_hcom # type: ignore patch_for_hcom() + + if is_310p(): + # on 300I Duo platform, we need to patch broadcast. however, this patch will be + # overwritten by patch_for_hcom in torchair. so we need to re-patch it here. + from vllm_ascend.patch.platform.patch_common.patch_distributed import \ + communication_adaptation_310p + communication_adaptation_310p() + config = torchair.CompilerConfig() config.experimental_config.frozen_parameter = True - config.experimental_config.tiling_schedule_optimize = True + # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to + # disable it on 300I Duo platform now. 
+ config.experimental_config.tiling_schedule_optimize = not is_310p() config.experimental_config.enable_view_optimize = \ get_ascend_config().torchair_graph_config.enable_view_optimize torch.npu.set_compile_mode(jit_compile=False) @@ -2135,27 +2145,50 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) if self.torchair_graph_enabled: - layer_kv_cache_nope = torch.zeros( - kv_cache_shape[:-1] + - (self.model_config.hf_text_config.kv_lora_rank, ), - dtype=self.dtype, - pin_memory=True, - device=self.device) - layer_kv_cache_pe = torch.zeros( - kv_cache_shape[:-1] + - (self.model_config.hf_text_config.qk_rope_head_dim, - ), - dtype=self.dtype, - pin_memory=True, - device=self.device) - kv_caches[layer_name] = (layer_kv_cache_nope, - layer_kv_cache_pe) - kv_caches[layer_name] = ( - torch_npu.npu_format_cast(kv_caches[layer_name][0], - acl_format), - torch_npu.npu_format_cast(kv_caches[layer_name][1], - acl_format), - ) + if len(kv_cache_shape) == 3: + # for non MLA attention backend that use torchair, we consider to pass kv_cache layout + # of BSH ([num_blocks, block_size, kv_head_dim * head_size]) to attention. + + kv_caches[layer_name] = ( + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device), + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device)) + # atb reshape_and_cache does not support torchair. + kv_caches[layer_name] = ( + torch_npu.npu_format_cast( + kv_caches[layer_name][0], + ACL_FORMAT_FRACTAL_ND), + torch_npu.npu_format_cast( + kv_caches[layer_name][1], + ACL_FORMAT_FRACTAL_ND), + ) + else: + # for MLA attention backend that use torchair. + layer_kv_cache_nope = torch.zeros( + kv_cache_shape[:-1] + + (self.model_config.hf_text_config.kv_lora_rank, + ), + dtype=self.dtype, + pin_memory=True, + device=self.device) + layer_kv_cache_pe = torch.zeros( + kv_cache_shape[:-1] + + (self.model_config.hf_text_config. + qk_rope_head_dim, ), + dtype=self.dtype, + pin_memory=True, + device=self.device) + kv_caches[layer_name] = (layer_kv_cache_nope, + layer_kv_cache_pe) + kv_caches[layer_name] = ( + torch_npu.npu_format_cast( + kv_caches[layer_name][0], acl_format), + torch_npu.npu_format_cast( + kv_caches[layer_name][1], acl_format), + ) else: kv_caches[layer_name] = torch.zeros( kv_cache_shape,