Skip to content

Commit d35c9a3

Browse files
committed
could it really be the end?
1 parent 9422073 commit d35c9a3

File tree

7 files changed

+13
-13
lines changed

7 files changed

+13
-13
lines changed

tests/models/layoutlmv2/test_tokenization_layoutlmv2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2398,7 +2398,7 @@ def test_layoutlmv2_integration_test(self):
23982398
self.assertDictEqual(dict(encoding_p), expected_results)
23992399
self.assertDictEqual(dict(encoding_r), expected_results)
24002400

2401-
@unittest.skip(reason="Doesn't support another framework than PyTorch")
2401+
@unittest.skip(reason="Doesn't support returning Numpy arrays")
24022402
def test_np_encode_plus_sent_to_model(self):
24032403
pass
24042404

tests/models/layoutlmv3/test_tokenization_layoutlmv3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2322,7 +2322,7 @@ def test_layoutlmv3_integration_test(self):
23222322
self.assertDictEqual(dict(encoding_p), expected_results)
23232323
self.assertDictEqual(dict(encoding_r), expected_results)
23242324

2325-
@unittest.skip(reason="Doesn't support another framework than PyTorch")
2325+
@unittest.skip(reason="Doesn't support returning Numpy arrays")
23262326
def test_np_encode_plus_sent_to_model(self):
23272327
pass
23282328

tests/models/layoutxlm/test_tokenization_layoutxlm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1884,7 +1884,7 @@ def test_layoutxlm_integration_test(self):
18841884
self.assertDictEqual(dict(encoding_p), expected_results)
18851885
self.assertDictEqual(dict(encoding_r), expected_results)
18861886

1887-
@unittest.skip(reason="Doesn't support another framework than PyTorch")
1887+
@unittest.skip(reason="Doesn't support returning Numpy arrays")
18881888
def test_np_encode_plus_sent_to_model(self):
18891889
pass
18901890

tests/models/markuplm/test_tokenization_markuplm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2195,7 +2195,7 @@ def test_markuplm_integration_test(self):
21952195
self.assertDictEqual(dict(encoding_p), expected_results)
21962196
self.assertDictEqual(dict(encoding_r), expected_results)
21972197

2198-
@unittest.skip(reason="Doesn't support another framework than PyTorch")
2198+
@unittest.skip(reason="Doesn't support returning Numpy arrays")
21992199
def test_np_encode_plus_sent_to_model(self):
22002200
pass
22012201

tests/models/tapas/test_tokenization_tapas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1150,7 +1150,7 @@ def test_full_tokenizer(self):
11501150
self.assertListEqual(column_ids.tolist(), expected_results["column_ids"])
11511151
self.assertListEqual(row_ids.tolist(), expected_results["row_ids"])
11521152

1153-
@unittest.skip(reason="Doesn't support another framework than PyTorch")
1153+
@unittest.skip(reason="Doesn't support returning Numpy arrays")
11541154
def test_np_encode_plus_sent_to_model(self):
11551155
pass
11561156

tests/models/udop/test_tokenization_udop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1773,7 +1773,7 @@ def test_udop_integration_test(self):
17731773
self.assertDictEqual(dict(encoding_p), expected_results)
17741774
self.assertDictEqual(dict(encoding_r), expected_results)
17751775

1776-
@unittest.skip(reason="Doesn't support another framework than PyTorch")
1776+
@unittest.skip(reason="Doesn't support returning Numpy arrays")
17771777
def test_np_encode_plus_sent_to_model(self):
17781778
pass
17791779

tests/pipelines/test_pipelines_common.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ def test_load_default_pipelines_pt(self):
564564
# test table in separate test due to more dependencies
565565
continue
566566

567-
self.check_default_pipeline(task, "pt", set_seed_fn, self.check_models_equal_pt)
567+
self.check_default_pipeline(task, set_seed_fn, self.check_models_equal_pt)
568568

569569
# clean-up as much as possible GPU memory occupied by PyTorch
570570
gc.collect()
@@ -576,7 +576,7 @@ def test_load_default_pipelines_pt_table_qa(self):
576576
import torch
577577

578578
set_seed_fn = lambda: torch.manual_seed(0) # noqa: E731
579-
self.check_default_pipeline("table-question-answering", "pt", set_seed_fn, self.check_models_equal_pt)
579+
self.check_default_pipeline("table-question-answering", set_seed_fn, self.check_models_equal_pt)
580580

581581
# clean-up as much as possible GPU memory occupied by PyTorch
582582
gc.collect()
@@ -624,17 +624,17 @@ def test_bc_torch_device(self):
624624
self.assertEqual(k1, k2)
625625
self.assertEqual(v1.dtype, v2.dtype)
626626

627-
def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equal_fn):
627+
def check_default_pipeline(self, task, set_seed_fn, check_models_equal_fn):
628628
from transformers.pipelines import SUPPORTED_TASKS, pipeline
629629

630630
task_dict = SUPPORTED_TASKS[task]
631631
# test to compare pipeline to manually loading the respective model
632632
model = None
633-
relevant_auto_classes = task_dict[framework]
633+
relevant_auto_classes = task_dict["pt"]
634634

635635
if len(relevant_auto_classes) == 0:
636636
# task has no default
637-
self.skipTest(f"{task} in {framework} has no default")
637+
self.skipTest(f"{task} in pytorch has no default")
638638

639639
# by default use first class
640640
auto_model_cls = relevant_auto_classes[0]
@@ -646,14 +646,14 @@ def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equa
646646
revisions = []
647647
tasks = []
648648
for translation_pair in task_dict["default"]:
649-
model_id, revision = task_dict["default"][translation_pair]["model"][framework]
649+
model_id, revision = task_dict["default"][translation_pair]["model"]
650650

651651
model_ids.append(model_id)
652652
revisions.append(revision)
653653
tasks.append(task + f"_{'_to_'.join(translation_pair)}")
654654
else:
655655
# normal case - non-translation pipeline
656-
model_id, revision = task_dict["default"]["model"][framework]
656+
model_id, revision = task_dict["default"]["model"]
657657

658658
model_ids = [model_id]
659659
revisions = [revision]

0 commit comments

Comments (0)