@@ -564,7 +564,7 @@ def test_load_default_pipelines_pt(self):
564564 # test table in separate test due to more dependencies
565565 continue
566566
567- self .check_default_pipeline (task , "pt" , set_seed_fn , self .check_models_equal_pt )
567+ self .check_default_pipeline (task , set_seed_fn , self .check_models_equal_pt )
568568
569569 # clean-up as much as possible GPU memory occupied by PyTorch
570570 gc .collect ()
@@ -576,7 +576,7 @@ def test_load_default_pipelines_pt_table_qa(self):
576576 import torch
577577
578578 set_seed_fn = lambda : torch .manual_seed (0 ) # noqa: E731
579- self .check_default_pipeline ("table-question-answering" , "pt" , set_seed_fn , self .check_models_equal_pt )
579+ self .check_default_pipeline ("table-question-answering" , set_seed_fn , self .check_models_equal_pt )
580580
581581 # clean-up as much as possible GPU memory occupied by PyTorch
582582 gc .collect ()
@@ -624,17 +624,17 @@ def test_bc_torch_device(self):
624624 self .assertEqual (k1 , k2 )
625625 self .assertEqual (v1 .dtype , v2 .dtype )
626626
627- def check_default_pipeline (self , task , framework , set_seed_fn , check_models_equal_fn ):
627+ def check_default_pipeline (self , task , set_seed_fn , check_models_equal_fn ):
628628 from transformers .pipelines import SUPPORTED_TASKS , pipeline
629629
630630 task_dict = SUPPORTED_TASKS [task ]
631631 # test to compare pipeline to manually loading the respective model
632632 model = None
633- relevant_auto_classes = task_dict [framework ]
633+ relevant_auto_classes = task_dict ["pt" ]
634634
635635 if len (relevant_auto_classes ) == 0 :
636636 # task has no default
637- self .skipTest (f"{ task } in { framework } has no default" )
637+ self .skipTest (f"{ task } in pytorch has no default" )
638638
639639 # by default use first class
640640 auto_model_cls = relevant_auto_classes [0 ]
@@ -646,14 +646,14 @@ def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equa
646646 revisions = []
647647 tasks = []
648648 for translation_pair in task_dict ["default" ]:
649- model_id , revision = task_dict ["default" ][translation_pair ]["model" ][ framework ]
649+ model_id , revision = task_dict ["default" ][translation_pair ]["model" ]
650650
651651 model_ids .append (model_id )
652652 revisions .append (revision )
653653 tasks .append (task + f"_{ '_to_' .join (translation_pair )} " )
654654 else :
655655 # normal case - non-translation pipeline
656- model_id , revision = task_dict ["default" ]["model" ][ framework ]
656+ model_id , revision = task_dict ["default" ]["model" ]
657657
658658 model_ids = [model_id ]
659659 revisions = [revision ]
0 commit comments