From f4dcede2f07c10837f3c05b61012dafbe7b52240 Mon Sep 17 00:00:00 2001
From: Lysandre Debut
Date: Thu, 11 Mar 2021 12:56:12 -0500
Subject: [PATCH] W2v2 test require torch (#10665)

* Adds a @require_torch to a test that requires it

* Tokenizer too

* Style
---
 tests/test_feature_extraction_wav2vec2.py | 3 ++-
 tests/test_tokenization_wav2vec2.py       | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/test_feature_extraction_wav2vec2.py b/tests/test_feature_extraction_wav2vec2.py
index 771974a3982179..d55d951ee3ec8d 100644
--- a/tests/test_feature_extraction_wav2vec2.py
+++ b/tests/test_feature_extraction_wav2vec2.py
@@ -21,7 +21,7 @@
 import numpy as np
 
 from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2Config, Wav2Vec2FeatureExtractor
-from transformers.testing_utils import slow
+from transformers.testing_utils import require_torch, slow
 
 from .test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
 
@@ -134,6 +134,7 @@ def _check_zero_mean_unit_variance(input_vector):
         _check_zero_mean_unit_variance(input_values[2])
 
     @slow
+    @require_torch
     def test_pretrained_checkpoints_are_set_correctly(self):
         # this test makes sure that models that are using
         # group norm don't have their feature extractor return the
diff --git a/tests/test_tokenization_wav2vec2.py b/tests/test_tokenization_wav2vec2.py
index f7a5e4da164c1d..002bf4b2256a0a 100644
--- a/tests/test_tokenization_wav2vec2.py
+++ b/tests/test_tokenization_wav2vec2.py
@@ -30,7 +30,7 @@
     Wav2Vec2Tokenizer,
 )
 from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
-from transformers.testing_utils import slow
+from transformers.testing_utils import require_torch, slow
 
 from .test_tokenization_common import TokenizerTesterMixin
 
@@ -340,6 +340,7 @@ def test_return_attention_mask(self):
         self.assertListEqual(processed.attention_mask.sum(-1).tolist(), [800, 1000, 1200])
 
     @slow
+    @require_torch
     def test_pretrained_checkpoints_are_set_correctly(self):
         # this test makes sure that models that are using
         # group norm don't have their tokenizer return the