From 0f65a8728aaee0012f5c22cc21f420717cdd7a02 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 3 May 2021 11:53:30 +0200
Subject: [PATCH] [Wav2Vec2] Fix convert (#11562)

* push

* small change

* correct other typo
---
 ...rt_wav2vec2_original_pytorch_checkpoint_to_pytorch.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
index cc902ee3bc9171..2ba66c70be89a4 100644
--- a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
@@ -178,9 +178,11 @@ def convert_wav2vec2_checkpoint(
 
     if dict_path:
         target_dict = Dictionary.load(dict_path)
-        config.bos_token_id = target_dict.bos_index
+        # important change bos & pad token id since CTC symbol is <pad> and
+        # not <s> as in fairseq
+        config.bos_token_id = target_dict.pad_index
+        config.pad_token_id = target_dict.bos_index
         config.eos_token_id = target_dict.eos_index
-        config.pad_token_id = target_dict.pad_index
         config.vocab_size = len(target_dict.symbols)
         vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
         if not os.path.isdir(pytorch_dump_folder_path):
@@ -214,9 +216,8 @@ def convert_wav2vec2_checkpoint(
         hf_wav2vec = Wav2Vec2Model(config)
 
     if is_finetuned:
-        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
-            [checkpoint_path], arg_overrides={"data": dict_path}
+        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
+            [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
         )
     else:
         model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])