From 5840911f6a1b0e0889f88a6a157b6b8ae8fac820 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 16 Feb 2022 17:33:33 +0100
Subject: [PATCH] [Wav2Vec2ProcessorWithLM] Fix auto processor with lm (#15683)

---
 .../processing_wav2vec2_with_lm.py            |  2 ++
 tests/test_processor_wav2vec2_with_lm.py      | 20 +++++++++++++++++++
 2 files changed, 22 insertions(+)

diff --git a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
index 1c0fb5d0bfe109..ca59a948ff3bed 100644
--- a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
+++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
@@ -138,6 +138,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         else:
             # BeamSearchDecoderCTC has no auto class
             kwargs.pop("_from_auto", None)
+            # snapshot_download has no `trust_remote_code` flag
+            kwargs.pop("trust_remote_code", None)
 
             # make sure that only relevant filenames are downloaded
             language_model_filenames = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
diff --git a/tests/test_processor_wav2vec2_with_lm.py b/tests/test_processor_wav2vec2_with_lm.py
index b3f4cb3cc77c0d..1d31b37c47d997 100644
--- a/tests/test_processor_wav2vec2_with_lm.py
+++ b/tests/test_processor_wav2vec2_with_lm.py
@@ -22,6 +22,7 @@
 
 import numpy as np
 
+from transformers import AutoProcessor
 from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available
 from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
 from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
@@ -330,3 +331,22 @@ def test_decoder_local_files(self):
 
         # test that both decoder form hub and local files in cache are the same
         self.assertListEqual(local_decoder_files, expected_decoder_files)
+
+    def test_processor_from_auto_processor(self):
+        processor_wav2vec2 = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
+        processor_auto = AutoProcessor.from_pretrained("hf-internal-testing/processor_with_lm")
+
+        raw_speech = floats_list((3, 1000))
+
+        input_wav2vec2 = processor_wav2vec2(raw_speech, return_tensors="np")
+        input_auto = processor_auto(raw_speech, return_tensors="np")
+
+        for key in input_wav2vec2.keys():
+            self.assertAlmostEqual(input_wav2vec2[key].sum(), input_auto[key].sum(), delta=1e-2)
+
+        logits = self._get_dummy_logits()
+
+        decoded_wav2vec2 = processor_wav2vec2.batch_decode(logits)
+        decoded_auto = processor_auto.batch_decode(logits)
+
+        self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text)