Skip to content

Commit 41618ed

Browse files
author
angazenn
committed
fix ut
Signed-off-by: angazenn <zengyanjia@huawei.com>
1 parent 0c4f231 commit 41618ed

File tree

1 file changed

+48
-21
lines changed

1 file changed

+48
-21
lines changed

tests/ut/test_platform.py

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -292,23 +292,6 @@ def test_check_and_update_config_no_model_config_warning(
292292
self.platform.check_and_update_config(self.mock_vllm_config)
293293
self.assertTrue("Model config is missing" in cm.output[0])
294294

295-
@patch("vllm_ascend.utils.is_310p", return_value=False)
296-
@patch("vllm_ascend.ascend_config.check_ascend_config")
297-
@patch("vllm_ascend.ascend_config.init_ascend_config")
298-
@patch("vllm.envs.VLLM_MLA_DISABLE", True)
299-
def test_check_and_update_config_torchair_graph_disabled_when_mla_disabled(
300-
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
301-
self.mock_ascend_config.torchair_graph_config.enabled = True
302-
mock_init_ascend.return_value = self.mock_ascend_config
303-
304-
from vllm_ascend import platform
305-
306-
importlib.reload(platform)
307-
308-
self.platform.check_and_update_config(self.mock_vllm_config)
309-
310-
self.assertFalse(self.mock_ascend_config.torchair_graph_config.enabled)
311-
312295
@patch("vllm_ascend.utils.is_310p", return_value=False)
313296
@patch("vllm_ascend.ascend_config.check_ascend_config")
314297
@patch("vllm_ascend.ascend_config.init_ascend_config")
@@ -502,7 +485,13 @@ def test_check_and_update_config_ascend_scheduler_config(
502485
self.platform.check_and_update_config(self.mock_vllm_config)
503486
mock_scheduler.initialize_from_config.assert_called_once()
504487

505-
def test_get_attn_backend_cls_use_v1_and_mla(self):
488+
@patch('vllm_ascend.platform.get_ascend_config')
489+
def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config):
490+
mock_config = MagicMock()
491+
mock_config.torchair_graph_config.enabled = False
492+
493+
mock_get_ascend_config.return_value = mock_config
494+
506495
result = self.platform.get_attn_backend_cls(
507496
selected_backend="ascend",
508497
head_size=64,
@@ -514,8 +503,34 @@ def test_get_attn_backend_cls_use_v1_and_mla(self):
514503
)
515504
self.assertEqual(result,
516505
"vllm_ascend.attention.mla_v1.AscendMLABackend")
506+
507+
@patch('vllm_ascend.platform.get_ascend_config')
508+
def test_get_attn_backend_cls_use_v1_and_torchair(self, mock_get_ascend_config):
509+
mock_config = MagicMock()
510+
mock_config.torchair_graph_config.enabled = True
511+
512+
mock_get_ascend_config.return_value = mock_config
513+
514+
result = self.platform.get_attn_backend_cls(
515+
selected_backend="ascend",
516+
head_size=64,
517+
dtype="float16",
518+
kv_cache_dtype="float16",
519+
block_size=64,
520+
use_v1=True,
521+
use_mla=False,
522+
)
523+
self.assertEqual(
524+
result,
525+
"vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend")
526+
527+
@patch('vllm_ascend.platform.get_ascend_config')
528+
def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config):
529+
mock_config = MagicMock()
530+
mock_config.torchair_graph_config.enabled = False
531+
532+
mock_get_ascend_config.return_value = mock_config
517533

518-
def test_get_attn_backend_cls_use_v1_only(self):
519534
result = self.platform.get_attn_backend_cls(
520535
selected_backend="ascend",
521536
head_size=64,
@@ -529,7 +544,13 @@ def test_get_attn_backend_cls_use_v1_only(self):
529544
result,
530545
"vllm_ascend.attention.attention_v1.AscendAttentionBackend")
531546

532-
def test_get_attn_backend_cls_use_mla_only(self):
547+
@patch('vllm_ascend.platform.get_ascend_config')
548+
def test_get_attn_backend_cls_use_mla_only(self, mock_get_ascend_config):
549+
mock_config = MagicMock()
550+
mock_config.torchair_graph_config.enabled = False
551+
552+
mock_get_ascend_config.return_value = mock_config
553+
533554
result = self.platform.get_attn_backend_cls(
534555
selected_backend="ascend",
535556
head_size=64,
@@ -543,7 +564,13 @@ def test_get_attn_backend_cls_use_mla_only(self):
543564
result,
544565
"vllm_ascend.attention.attention.AscendMLAAttentionBackend")
545566

546-
def test_get_attn_backend_cls_default_case(self):
567+
@patch('vllm_ascend.platform.get_ascend_config')
568+
def test_get_attn_backend_cls_default_case(self, mock_get_ascend_config):
569+
mock_config = MagicMock()
570+
mock_config.torchair_graph_config.enabled = False
571+
572+
mock_get_ascend_config.return_value = mock_config
573+
547574
result = self.platform.get_attn_backend_cls(
548575
selected_backend="ascend",
549576
head_size=64,

0 commit comments

Comments
 (0)