Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

W2v2 test require torch #10665

Merged
merged 3 commits into from
Mar 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/test_feature_extraction_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np

from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2Config, Wav2Vec2FeatureExtractor
from transformers.testing_utils import slow
from transformers.testing_utils import require_torch, slow

from .test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin

Expand Down Expand Up @@ -134,6 +134,7 @@ def _check_zero_mean_unit_variance(input_vector):
_check_zero_mean_unit_variance(input_values[2])

@slow
@require_torch
def test_pretrained_checkpoints_are_set_correctly(self):
# this test makes sure that models that are using
# group norm don't have their feature extractor return the
Expand Down
3 changes: 2 additions & 1 deletion tests/test_tokenization_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
Wav2Vec2Tokenizer,
)
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
from transformers.testing_utils import slow
from transformers.testing_utils import require_torch, slow

from .test_tokenization_common import TokenizerTesterMixin

Expand Down Expand Up @@ -340,6 +340,7 @@ def test_return_attention_mask(self):
self.assertListEqual(processed.attention_mask.sum(-1).tolist(), [800, 1000, 1200])

@slow
@require_torch
def test_pretrained_checkpoints_are_set_correctly(self):
# this test makes sure that models that are using
# group norm don't have their tokenizer return the
Expand Down