@@ -292,23 +292,6 @@ def test_check_and_update_config_no_model_config_warning(
292292 self .platform .check_and_update_config (self .mock_vllm_config )
293293 self .assertTrue ("Model config is missing" in cm .output [0 ])
294294
295- @patch ("vllm_ascend.utils.is_310p" , return_value = False )
296- @patch ("vllm_ascend.ascend_config.check_ascend_config" )
297- @patch ("vllm_ascend.ascend_config.init_ascend_config" )
298- @patch ("vllm.envs.VLLM_MLA_DISABLE" , True )
299- def test_check_and_update_config_torchair_graph_disabled_when_mla_disabled (
300- self , mock_init_ascend , mock_check_ascend , mock_is_310p ):
301- self .mock_ascend_config .torchair_graph_config .enabled = True
302- mock_init_ascend .return_value = self .mock_ascend_config
303-
304- from vllm_ascend import platform
305-
306- importlib .reload (platform )
307-
308- self .platform .check_and_update_config (self .mock_vllm_config )
309-
310- self .assertFalse (self .mock_ascend_config .torchair_graph_config .enabled )
311-
312295 @patch ("vllm_ascend.utils.is_310p" , return_value = False )
313296 @patch ("vllm_ascend.ascend_config.check_ascend_config" )
314297 @patch ("vllm_ascend.ascend_config.init_ascend_config" )
@@ -502,7 +485,13 @@ def test_check_and_update_config_ascend_scheduler_config(
502485 self .platform .check_and_update_config (self .mock_vllm_config )
503486 mock_scheduler .initialize_from_config .assert_called_once ()
504487
505- def test_get_attn_backend_cls_use_v1_and_mla (self ):
488+ @patch ('vllm_ascend.platform.get_ascend_config' )
489+ def test_get_attn_backend_cls_use_v1_and_mla (self , mock_get_ascend_config ):
490+ mock_config = MagicMock ()
491+ mock_config .torchair_graph_config .enabled = False
492+
493+ mock_get_ascend_config .return_value = mock_config
494+
506495 result = self .platform .get_attn_backend_cls (
507496 selected_backend = "ascend" ,
508497 head_size = 64 ,
@@ -514,8 +503,34 @@ def test_get_attn_backend_cls_use_v1_and_mla(self):
514503 )
515504 self .assertEqual (result ,
516505 "vllm_ascend.attention.mla_v1.AscendMLABackend" )
506+
507+ @patch ('vllm_ascend.platform.get_ascend_config' )
508+ def test_get_attn_backend_cls_use_v1_and_torchair (self , mock_get_ascend_config ):
509+ mock_config = MagicMock ()
510+ mock_config .torchair_graph_config .enabled = True
511+
512+ mock_get_ascend_config .return_value = mock_config
513+
514+ result = self .platform .get_attn_backend_cls (
515+ selected_backend = "ascend" ,
516+ head_size = 64 ,
517+ dtype = "float16" ,
518+ kv_cache_dtype = "float16" ,
519+ block_size = 64 ,
520+ use_v1 = True ,
521+ use_mla = False ,
522+ )
523+ self .assertEqual (
524+ result ,
525+ "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend" )
526+
527+ @patch ('vllm_ascend.platform.get_ascend_config' )
528+ def test_get_attn_backend_cls_use_v1_only (self , mock_get_ascend_config ):
529+ mock_config = MagicMock ()
530+ mock_config .torchair_graph_config .enabled = False
531+
532+ mock_get_ascend_config .return_value = mock_config
517533
518- def test_get_attn_backend_cls_use_v1_only (self ):
519534 result = self .platform .get_attn_backend_cls (
520535 selected_backend = "ascend" ,
521536 head_size = 64 ,
@@ -529,7 +544,13 @@ def test_get_attn_backend_cls_use_v1_only(self):
529544 result ,
530545 "vllm_ascend.attention.attention_v1.AscendAttentionBackend" )
531546
532- def test_get_attn_backend_cls_use_mla_only (self ):
547+ @patch ('vllm_ascend.platform.get_ascend_config' )
548+ def test_get_attn_backend_cls_use_mla_only (self , mock_get_ascend_config ):
549+ mock_config = MagicMock ()
550+ mock_config .torchair_graph_config .enabled = False
551+
552+ mock_get_ascend_config .return_value = mock_config
553+
533554 result = self .platform .get_attn_backend_cls (
534555 selected_backend = "ascend" ,
535556 head_size = 64 ,
@@ -543,7 +564,13 @@ def test_get_attn_backend_cls_use_mla_only(self):
543564 result ,
544565 "vllm_ascend.attention.attention.AscendMLAAttentionBackend" )
545566
546- def test_get_attn_backend_cls_default_case (self ):
567+ @patch ('vllm_ascend.platform.get_ascend_config' )
568+ def test_get_attn_backend_cls_default_case (self , mock_get_ascend_config ):
569+ mock_config = MagicMock ()
570+ mock_config .torchair_graph_config .enabled = False
571+
572+ mock_get_ascend_config .return_value = mock_config
573+
547574 result = self .platform .get_attn_backend_cls (
548575 selected_backend = "ascend" ,
549576 head_size = 64 ,
0 commit comments