From f71c9ccf592730fd9c733da56915569e2e8753aa Mon Sep 17 00:00:00 2001 From: YQ Date: Mon, 23 Oct 2023 18:33:05 +0800 Subject: [PATCH 01/38] fix logit-to-multi-hot conversion in example (#26936) * fix logit to multi-hot converstion * add comments * typo --- examples/pytorch/text-classification/run_classification.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/text-classification/run_classification.py b/examples/pytorch/text-classification/run_classification.py index 3033a61404e839..1bc4dbe5fa3712 100755 --- a/examples/pytorch/text-classification/run_classification.py +++ b/examples/pytorch/text-classification/run_classification.py @@ -655,7 +655,7 @@ def compute_metrics(p: EvalPrediction): preds = np.squeeze(preds) result = metric.compute(predictions=preds, references=p.label_ids) elif is_multi_label: - preds = np.array([np.where(p > 0.5, 1, 0) for p in preds]) + preds = np.array([np.where(p > 0, 1, 0) for p in preds]) # convert logits to multi-hot encoding # Micro F1 is commonly used in multi-label classification result = metric.compute(predictions=preds, references=p.label_ids, average="micro") else: @@ -721,7 +721,10 @@ def compute_metrics(p: EvalPrediction): if is_regression: predictions = np.squeeze(predictions) elif is_multi_label: - predictions = np.array([np.where(p > 0.5, 1, 0) for p in predictions]) + # Convert logits to multi-hot encoding. We compare the logits to 0 instead of 0.5, because the sigmoid is not applied. + # You can also pass `preprocess_logits_for_metrics=lambda logits, labels: nn.functional.sigmoid(logits)` to the Trainer + # and set p > 0.5 below (less efficient in this case) + predictions = np.array([np.where(p > 0, 1, 0) for p in predictions]) else: predictions = np.argmax(predictions, axis=1) output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt") From 700329493dba077654cd771044ec303a265a8d79 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 23 Oct 2023 03:34:21 -0700 Subject: [PATCH 02/38] Limit to inferior fsspec version (#27010) Pin fsspec --- .circleci/create_circleci_config.py | 1 + setup.py | 1 + src/transformers/dependency_versions_table.py | 1 + 3 files changed, 3 insertions(+) diff --git a/.circleci/create_circleci_config.py b/.circleci/create_circleci_config.py index 0829207a185d0d..6dda676126c1db 100644 --- a/.circleci/create_circleci_config.py +++ b/.circleci/create_circleci_config.py @@ -127,6 +127,7 @@ def to_dict(self): }, ] steps.extend([{"run": l} for l in self.install_steps]) + steps.extend([{"run": 'pip install "fsspec>=2023.5.0,<2023.10.0"'}]) steps.extend([{"run": "pip install pytest-subtests"}]) steps.append( { diff --git a/setup.py b/setup.py index c737058cb8695e..6f43fae9ca574f 100644 --- a/setup.py +++ b/setup.py @@ -113,6 +113,7 @@ "fastapi", "filelock", "flax>=0.4.1,<=0.7.0", + "fsspec<2023.10.0", "ftfy", "fugashi>=1.0", "GitPython<3.1.19", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 4b3e32bc04b32f..1dbedc3ea6de2b 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -20,6 +20,7 @@ "fastapi": "fastapi", "filelock": "filelock", "flax": "flax>=0.4.1,<=0.7.0", + "fsspec": "fsspec<2023.10.0", "ftfy": "ftfy", "fugashi": "fugashi>=1.0", "GitPython": "GitPython<3.1.19", From 45425660d0911c35d0dd65555f798567de1d920e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gema=20Parre=C3=B1o?= Date: Mon, 23 Oct 2023 12:51:35 +0200 Subject: [PATCH 03/38] python falcon doc-string example typo (#26995) git python falcon typo --- src/transformers/models/falcon/configuration_falcon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/falcon/configuration_falcon.py b/src/transformers/models/falcon/configuration_falcon.py index fce21b146cf97f..e43340275771c6 100644 --- a/src/transformers/models/falcon/configuration_falcon.py +++ b/src/transformers/models/falcon/configuration_falcon.py @@ -92,7 +92,7 @@ class FalconConfig(PretrainedConfig): Example: - ```pytho + ```python >>> from transformers import FalconModel, FalconConfig >>> # Initializing a small (2-layer) Falcon configuration From ef978d0a7bb6455eff5c126cd6e4f10de0158004 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Mon, 23 Oct 2023 12:52:05 +0200 Subject: [PATCH 04/38] skip two tests (#27013) * skip two tests * skip torch as well * fixup --- .../blenderbot_small/test_modeling_blenderbot_small.py | 4 +++- .../test_modeling_tf_blenderbot_small.py | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py index 249a8a799a8307..2397b6fee9727d 100644 --- a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py +++ b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py @@ -244,7 +244,9 @@ def is_pipeline_test_to_skip( ): if pipeline_test_casse_name == "TextGenerationPipelineTests": return True - + # TODO @Rocketnight1 to fix + if pipeline_test_casse_name == "ConversationalPipelineTests": + return True return False def setUp(self): diff --git a/tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py index 021f171789e4cb..6fde705082d156 100644 --- a/tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py +++ b/tests/models/blenderbot_small/test_modeling_tf_blenderbot_small.py @@ -209,6 +209,15 @@ def test_decoder_model_past_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs) + # TODO: Fix the failed tests when this model gets more usage + def is_pipeline_test_to_skip( + self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name + ): + # TODO @Rocketnight1 to fix + if pipeline_test_casse_name == "ConversationalPipelineTests": + return True + return False + @require_tokenizers @require_tf From d33d3131920a1681f2194f3ae2a39ced9f073d64 Mon Sep 17 00:00:00 2001 From: Omar Sanseviero Date: Mon, 23 Oct 2023 14:19:59 +0200 Subject: [PATCH 05/38] Nits in Llama2 docstring (#26996) Update llama2.md --- docs/source/en/model_doc/llama2.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/model_doc/llama2.md b/docs/source/en/model_doc/llama2.md index 11cb26f0dcf6ee..0ff1e38f16a2a0 100644 --- a/docs/source/en/model_doc/llama2.md +++ b/docs/source/en/model_doc/llama2.md @@ -28,12 +28,12 @@ Checkout all Llama2 models [here](https://huggingface.co/models?search=llama2) -The `Llama2` models were trained using `bfloat16`, but the original inference uses `float16. The checkpoints uploaded on the hub use `torch_dtype = 'float16'` which will be +The `Llama2` models were trained using `bfloat16`, but the original inference uses `float16`. The checkpoints uploaded on the Hub use `torch_dtype = 'float16'`, which will be used by the `AutoModel` API to cast the checkpoints from `torch.float32` to `torch.float16`. -The `dtype` of the online weights is mostly irrelevant, unless you are using `torch_dtype="auto"` when initializing a model using `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype = "auto")`. The reason is that the model will first be downloaded ( using the `dtype` of the checkpoints online) then it will be casted to the default `dtype` of `torch` (becomes `torch.float32`) and finally, if there is a `torch_dtype` provided in the config, it will be used. +The `dtype` of the online weights is mostly irrelevant unless you are using `torch_dtype="auto"` when initializing a model using `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype = "auto")`. The reason is that the model will first be downloaded ( using the `dtype` of the checkpoints online), then it will be casted to the default `dtype` of `torch` (becomes `torch.float32`), and finally, if there is a `torch_dtype` provided in the config, it will be used. -Training the model in `float16` is not recommended and known to produce `nan`, as such the model should be trained in `bfloat16`. +Training the model in `float16` is not recommended and is known to produce `nan`; as such, the model should be trained in `bfloat16`. From 50d0cf4f6b86afa7b2439b0a7d7384740fae38d1 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 23 Oct 2023 14:25:48 +0200 Subject: [PATCH 06/38] Change default `max_shard_size` to smaller value (#26942) * Update modeling_utils.py * fixup * let's change it to 5GB * fix --- src/transformers/modeling_utils.py | 6 ++++-- src/transformers/utils/hub.py | 7 ++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d567b9438a0d00..0317695f2096de 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1867,7 +1867,7 @@ def save_pretrained( state_dict: Optional[dict] = None, save_function: Callable = torch.save, push_to_hub: bool = False, - max_shard_size: Union[int, str] = "10GB", + max_shard_size: Union[int, str] = "5GB", safe_serialization: bool = False, variant: Optional[str] = None, token: Optional[Union[str, bool]] = None, @@ -1896,9 +1896,11 @@ def save_pretrained( Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). - max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`): The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + We default it to 5GB in order for models to be able to run easily on free-tier google colab instances + without CPU OOM issues. diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index b900311003b85a..a3ed744f467daf 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -790,7 +790,7 @@ def push_to_hub( commit_message: Optional[str] = None, private: Optional[bool] = None, token: Optional[Union[bool, str]] = None, - max_shard_size: Optional[Union[int, str]] = "10GB", + max_shard_size: Optional[Union[int, str]] = "5GB", create_pr: bool = False, safe_serialization: bool = False, revision: str = None, @@ -814,10 +814,11 @@ def push_to_hub( The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url` is not specified. - max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`): Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size lower than this size. If expressed as a string, needs to be digits followed - by a unit (like `"5MB"`). + by a unit (like `"5MB"`). We default it to `"5GB"` so that users can easily load models on free-tier + Google Colab instances without any CPU OOM issues. create_pr (`bool`, *optional*, defaults to `False`): Whether or not to create a PR with the uploaded files or directly commit. safe_serialization (`bool`, *optional*, defaults to `False`): From cb45f71c4dfb28613f8348716de821df7db68799 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Mon, 23 Oct 2023 14:49:48 +0200 Subject: [PATCH 07/38] Add Seamless M4T model (#25693) * first raw commit * still POC * tentative convert script * almost working speech encoder conversion scripts * intermediate code for encoder/decoders * add modeling code * first version of speech encoder * make style * add new adapter layer architecture * add adapter block * add first tentative config * add working speech encoder conversion * base model convert works now * make style * remove unnecessary classes * remove unecessary functions * add modeling code speech encoder * rework logics * forward pass of sub components work * add modeling codes * some config modifs and modeling code modifs * save WIP * new edits * same output speech encoder * correct attention mask * correct attention mask * fix generation * new generation logics * erase comments * make style * fix typo * add some descriptions * new state * clean imports * add tests * make style * make beam search and num_return_sequences>1 works * correct edge case issue * correct SeamlessM4TConformerSamePadLayer copied from * replace ACT2FN relu by nn.relu * remove unecessary return variable * move back a class * change name conformer_attention_mask ->conv_attention_mask * better nit code * add some Copied from statements * small nits * small nit in dict.get * rename t2u model -> conditionalgeneration * ongoing refactoring of structure * update models architecture * remove SeamlessM4TMultiModal classes * add tests * adapt tests * some non-working code for vocoder * add seamlessM4T vocoder * remove buggy line * fix some hifigan related bugs * remove hifigan specifc config * change * add WIP tokenization * add seamlessM4T working tokenzier * update tokenization * add tentative feature extractor * Update converting script * update working FE * refactor input_values -> input_features * update FE * changes in generation, tokenizer and modeling * make style and add t2u_decoder_input_ids * add intermediate outputs for ToSpeech models * add vocoder to speech models * update valueerror * update FE with languages * add vocoder convert * update config docstrings and names * update generation code and configuration * remove todos and update config.pad_token_id to generation_config.pad_token_id * move block vocoder * remove unecessary code and uniformize tospeech code * add feature extractor import * make style and fix some copies from * correct consistency + make fix-copies * add processor code * remove comments * add fast tokenizer support * correct pad_token_id in M4TModel * correct config * update tests and codes + make style * make some suggested correstion - correct comments and change naming * rename some attributes * rename some attributes * remove unecessary sequential * remove option to use dur predictor * nit * refactor hifigan * replace normalize_mean and normalize_var with do_normalize + save lang ids to generation config * add tests * change tgt_lang logic * update generation ToSpeech * add support import SeamlessM4TProcessor * fix generate * make tests * update integration tests, add option to only return text and update tokenizer fast * fix wrong function call * update import and convert script * update integration tests + update repo id * correct paths and add first test * update how new attention masks are computed * update tests * take first care of batching in vocoder code * add batching with the vocoder * add waveform lengths to model outputs * make style * add generate kwargs + forward kwargs of M4TModel * add docstrings forward methods * reformate docstrings * add docstrings t2u model * add another round of modeling docstrings + reformate speaker_id -> spkr_id * make style * fix check_repo * make style * add seamlessm4t to toctree * correct check_config_attributes * write config docstrings + some modifs * make style * add docstrings tokenizer * add docstrings to processor, fe and tokenizers * make style * write first version of model docs * fix FE + correct FE test * fix tokenizer + add correct integration tests * fix most tokenization tests * make style * correct most processor test * add generation tests and fix num_return_sequences > 1 * correct integration tests -still one left * make style * correct position embedding * change numbeams to 1 * refactor some modeling code and correct one test * make style * correct typo * refactor intermediate fnn * refactor feedforward conformer * make style * remove comments * make style * fix tokenizer tests * make style * correct processor tests * make style * correct S2TT integration * Apply suggestions from Sanchit code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * correct typo * replace torch.nn->nn + make style * change Output naming (waveforms -> waveform) and ordering * nit renaming and formating * remove return None when not necessary * refactor SeamlessM4TConformerFeedForward * nit typo * remove almost copied from comments * add a copied from comment and remove an unecessary dropout * remove inputs_embeds from speechencoder * remove backward compatibiliy function * reformate class docstrings for a few components * remove unecessary methods * split over 2 lines smthg hard to read * make style * replace two steps offset by one step as suggested * nice typo * move warnings * remove useless lines from processor * make generation non-standard test more robusts * remove torch.inference_mode from tests * split integration tests * enrich md * rename control_symbol_vocoder_offset->vocoder_offset * clean convert file * remove tgt_lang and src_lang from FE * change generate docstring of ToText models * update generate docstring of tospeech models * unify how to deal withtext_decoder_input_ids * add default spkr_id * unify tgt_lang for t2u_model * simplify tgt_lang verification * remove a todo * change config docstring * make style * simplify t2u_tgt_lang_id * make style * enrich/correct comments * enrich .md * correct typo in docstrings * add torchaudio dependency * update tokenizer * make style and fix copies * modify SeamlessM4TConverter with new tokenizer behaviour * make style * correct small typo docs * fix import * update docs and add requirement to tests * add convert_fairseq2_to_hf in utils/not_doctested.txt * update FE * fix imports and make style * remove torchaudio in FE test * add seamless_m4t.md to utils/not_doctested.txt * nits and change the way docstring dataset is loaded * move checkpoints from ylacombe/ to facebook/ orga * refactor warning/error to be in the 119 line width limit * round overly precised floats * add stereo audio behaviour * refactor .md and make style * enrich docs with more precised architecture description * readd undocumented models * make fix-copies * apply some suggestions * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * correct bug from previous commit * refactor a parameter allowing to clean the code + some small nits * clean tokenizer * make style and fix * make style * clean tokenizers arguments * add precisions for some tests * move docs from not_tested to slow * modify tokenizer according to last comments * add copied from statements in tests * correct convert script * correct parameter docstring style * correct tokenization * correct multi gpus * make style * clean modeling code * make style * add copied from statements * add copied statements * add support with ASR pipeline * remove file added inadvertently * fix docstrings seamlessM4TModel * add seamlessM4TConfig to OBJECTS_TO_IGNORE due of unconventional markdown * add seamlessm4t to assisted generation ignored models --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- README.md | 1 + README_es.md | 1 + README_hd.md | 1 + README_ja.md | 1 + README_ko.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/seamless_m4t.md | 218 + docs/source/en/tasks/summarization.md | 2 +- docs/source/en/tasks/translation.md | 2 +- src/transformers/__init__.py | 49 +- src/transformers/convert_slow_tokenizer.py | 26 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + .../models/auto/feature_extraction_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 4 + .../models/auto/processing_auto.py | 1 + .../models/auto/tokenization_auto.py | 7 + .../models/seamless_m4t/__init__.py | 111 + .../configuration_seamless_m4t.py | 417 ++ .../seamless_m4t/convert_fairseq2_to_hf.py | 410 ++ .../feature_extraction_seamless_m4t.py | 305 ++ .../seamless_m4t/modeling_seamless_m4t.py | 4607 +++++++++++++++++ .../seamless_m4t/processing_seamless_m4t.py | 116 + .../seamless_m4t/tokenization_seamless_m4t.py | 565 ++ .../tokenization_seamless_m4t_fast.py | 459 ++ src/transformers/utils/dummy_pt_objects.py | 73 + .../utils/dummy_sentencepiece_objects.py | 7 + .../utils/dummy_tokenizers_objects.py | 7 + tests/generation/test_utils.py | 2 +- tests/models/seamless_m4t/__init__.py | 0 .../test_feature_extraction_seamless_m4t.py | 223 + .../test_modeling_seamless_m4t.py | 1110 ++++ .../test_processor_seamless_m4t.py | 126 + .../test_tokenization_seamless_m4t.py | 672 +++ utils/check_config_attributes.py | 12 + utils/check_docstrings.py | 1 + utils/check_repo.py | 7 + utils/not_doctested.txt | 1 + utils/slow_documentation_tests.txt | 1 + 42 files changed, 9551 insertions(+), 5 deletions(-) create mode 100644 docs/source/en/model_doc/seamless_m4t.md create mode 100644 src/transformers/models/seamless_m4t/__init__.py create mode 100644 src/transformers/models/seamless_m4t/configuration_seamless_m4t.py create mode 100644 src/transformers/models/seamless_m4t/convert_fairseq2_to_hf.py create mode 100644 src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py create mode 100755 src/transformers/models/seamless_m4t/modeling_seamless_m4t.py create mode 100644 src/transformers/models/seamless_m4t/processing_seamless_m4t.py create mode 100644 src/transformers/models/seamless_m4t/tokenization_seamless_m4t.py create mode 100644 src/transformers/models/seamless_m4t/tokenization_seamless_m4t_fast.py create mode 100644 tests/models/seamless_m4t/__init__.py create mode 100644 tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py create mode 100644 tests/models/seamless_m4t/test_modeling_seamless_m4t.py create mode 100644 tests/models/seamless_m4t/test_processor_seamless_m4t.py create mode 100644 tests/models/seamless_m4t/test_tokenization_seamless_m4t.py diff --git a/README.md b/README.md index 3e3d261226dd52..758675e29716b1 100644 --- a/README.md +++ b/README.md @@ -459,6 +459,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. 1. **[RWKV](https://huggingface.co/docs/transformers/model_doc/rwkv)** (from Bo Peng), released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng. +1. **[SeamlessM4T](https://huggingface.co/docs/transformers/main/model_doc/seamless_m4t)** (from Meta AI) released with the paper [SeamlessM4T — Massively Multilingual & Multimodal Machine Translation](https://dl.fbaipublicfiles.com/seamless/seamless_m4t_paper.pdf) by the Seamless Communication team. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. 1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi. diff --git a/README_es.md b/README_es.md index dee98fa4aee0ee..cadc5ae019e6cb 100644 --- a/README_es.md +++ b/README_es.md @@ -434,6 +434,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. 1. **[RWKV](https://huggingface.co/docs/transformers/model_doc/rwkv)** (from Bo Peng) released with the paper [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng. +1. **[SeamlessM4T](https://huggingface.co/docs/transformers/main/model_doc/seamless_m4t)** (from Meta AI) released with the paper [SeamlessM4T — Massively Multilingual & Multimodal Machine Translation](https://dl.fbaipublicfiles.com/seamless/seamless_m4t_paper.pdf) by the Seamless Communication team. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. 1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi. diff --git a/README_hd.md b/README_hd.md index 7cc0f199b989fb..ba837e4a1860d6 100644 --- a/README_hd.md +++ b/README_hd.md @@ -408,6 +408,7 @@ conda install -c huggingface transformers 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (झुईई टेक्नोलॉजी से), साथ में पेपर [रोफॉर्मर: रोटरी पोजिशन एंबेडिंग के साथ एन्हांस्ड ट्रांसफॉर्मर] (https://arxiv.org/pdf/2104.09864v1.pdf) जियानलिन सु और यू लू और शेंगफेंग पैन और बो वेन और युनफेंग लियू द्वारा प्रकाशित। 1. **[RWKV](https://huggingface.co/docs/transformers/model_doc/rwkv)** (Bo Peng से) Bo Peng. द्वाराअनुसंधान पत्र [this repo](https://github.com/BlinkDL/RWKV-LM) के साथ जारी किया गया +1. **[SeamlessM4T](https://huggingface.co/docs/transformers/main/model_doc/seamless_m4t)** (from Meta AI) released with the paper [SeamlessM4T — Massively Multilingual & Multimodal Machine Translation](https://dl.fbaipublicfiles.com/seamless/seamless_m4t_paper.pdf) by the Seamless Communication team. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. 1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (Meta AI से) Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. द्वाराअनुसंधान पत्र [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) के साथ जारी किया गया 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (ASAPP से) साथ देने वाला पेपर [भाषण पहचान के लिए अनसुपरवाइज्ड प्री-ट्रेनिंग में परफॉर्मेंस-एफिशिएंसी ट्रेड-ऑफ्स](https ://arxiv.org/abs/2109.06870) फेलिक्स वू, क्वांगयुन किम, जिंग पैन, क्यू हान, किलियन क्यू. वेनबर्गर, योव आर्टज़ी द्वारा। diff --git a/README_ja.md b/README_ja.md index 0944a9a36b7d9e..6f25f1261b70e9 100644 --- a/README_ja.md +++ b/README_ja.md @@ -468,6 +468,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (WeChatAI から) HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou から公開された研究論文: [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (ZhuiyiTechnology から), Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu から公開された研究論文: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) 1. **[RWKV](https://huggingface.co/docs/transformers/model_doc/rwkv)** (Bo Peng から) Bo Peng. から公開された研究論文 [this repo](https://github.com/BlinkDL/RWKV-LM) +1. **[SeamlessM4T](https://huggingface.co/docs/transformers/main/model_doc/seamless_m4t)** (from Meta AI) released with the paper [SeamlessM4T — Massively Multilingual & Multimodal Machine Translation](https://dl.fbaipublicfiles.com/seamless/seamless_m4t_paper.pdf) by the Seamless Communication team. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (NVIDIA から) Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo から公開された研究論文: [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) 1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (Meta AI から) Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. から公開された研究論文 [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (ASAPP から) Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi から公開された研究論文: [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) diff --git a/README_ko.md b/README_ko.md index 2e325be61c30bc..77829ffbf1cb99 100644 --- a/README_ko.md +++ b/README_ko.md @@ -383,6 +383,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (WeChatAI 에서) HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou 의 [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) 논문과 함께 발표했습니다. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (ZhuiyiTechnology 에서) Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu 의 a [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) 논문과 함께 발표했습니다. 1. **[RWKV](https://huggingface.co/docs/transformers/model_doc/rwkv)** (Bo Peng 에서 제공)은 Bo Peng.의 [this repo](https://github.com/BlinkDL/RWKV-LM)논문과 함께 발표했습니다. +1. **[SeamlessM4T](https://huggingface.co/docs/transformers/main/model_doc/seamless_m4t)** (from Meta AI) released with the paper [SeamlessM4T — Massively Multilingual & Multimodal Machine Translation](https://dl.fbaipublicfiles.com/seamless/seamless_m4t_paper.pdf) by the Seamless Communication team. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (NVIDIA 에서) Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo 의 [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) 논문과 함께 발표했습니다. 1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (Meta AI 에서 제공)은 Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.의 [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf)논문과 함께 발표했습니다. 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (ASAPP 에서) Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi 의 [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) 논문과 함께 발표했습니다. diff --git a/README_zh-hans.md b/README_zh-hans.md index 53bef31231c4cd..4198883bff33fc 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -407,6 +407,7 @@ conda install -c huggingface transformers 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (来自 WeChatAI), 伴随论文 [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) 由 HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou 发布。 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (来自 ZhuiyiTechnology), 伴随论文 [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) 由 Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu 发布。 1. **[RWKV](https://huggingface.co/docs/transformers/model_doc/rwkv)** (来自 Bo Peng) 伴随论文 [this repo](https://github.com/BlinkDL/RWKV-LM) 由 Bo Peng 发布。 +1. **[SeamlessM4T](https://huggingface.co/docs/transformers/main/model_doc/seamless_m4t)** (from Meta AI) released with the paper [SeamlessM4T — Massively Multilingual & Multimodal Machine Translation](https://dl.fbaipublicfiles.com/seamless/seamless_m4t_paper.pdf) by the Seamless Communication team. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (来自 NVIDIA) 伴随论文 [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) 由 Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo 发布。 1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (来自 Meta AI) 伴随论文 [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) 由 Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick 发布。 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (来自 ASAPP) 伴随论文 [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) 由 Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 5ba42b77ef9739..0f9782ef2112cb 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -419,6 +419,7 @@ conda install -c huggingface transformers 1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper a [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. 1. **[RWKV](https://huggingface.co/docs/transformers/model_doc/rwkv)** (from Bo Peng) released with the paper [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng. +1. **[SeamlessM4T](https://huggingface.co/docs/transformers/main/model_doc/seamless_m4t)** (from Meta AI) released with the paper [SeamlessM4T — Massively Multilingual & Multimodal Machine Translation](https://dl.fbaipublicfiles.com/seamless/seamless_m4t_paper.pdf) by the Seamless Communication team. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. 1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. 1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 11fa956db89f22..84e306a786f9c6 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -614,6 +614,8 @@ title: MusicGen - local: model_doc/pop2piano title: Pop2Piano + - local: model_doc/seamless_m4t + title: Seamless-M4T - local: model_doc/sew title: SEW - local: model_doc/sew-d diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 97fdcf55e1ace7..9a9692f35d9bbd 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -236,6 +236,7 @@ Flax), PyTorch, and/or TensorFlow. | [RoFormer](model_doc/roformer) | ✅ | ✅ | ✅ | | [RWKV](model_doc/rwkv) | ✅ | ❌ | ❌ | | [SAM](model_doc/sam) | ✅ | ✅ | ❌ | +| [SeamlessM4T](model_doc/seamless_m4t) | ✅ | ❌ | ❌ | | [SegFormer](model_doc/segformer) | ✅ | ✅ | ❌ | | [SEW](model_doc/sew) | ✅ | ❌ | ❌ | | [SEW-D](model_doc/sew-d) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/seamless_m4t.md b/docs/source/en/model_doc/seamless_m4t.md new file mode 100644 index 00000000000000..3a67a58a08a967 --- /dev/null +++ b/docs/source/en/model_doc/seamless_m4t.md @@ -0,0 +1,218 @@ + + +# SeamlessM4T + +## Overview + +The SeamlessM4T model was proposed in [SeamlessM4T — Massively Multilingual & Multimodal Machine Translation](https://dl.fbaipublicfiles.com/seamless/seamless_m4t_paper.pdf) by the Seamless Communication team from Meta AI. + +SeamlessM4T is a collection of models designed to provide high quality translation, allowing people from different linguistic communities to communicate effortlessly through speech and text. + +SeamlessM4T enables multiple tasks without relying on separate models: + +- Speech-to-speech translation (S2ST) +- Speech-to-text translation (S2TT) +- Text-to-speech translation (T2ST) +- Text-to-text translation (T2TT) +- Automatic speech recognition (ASR) + +[`SeamlessM4TModel`] can perform all the above tasks, but each task also has its own dedicated sub-model. + +The abstract from the paper is the following: + +*What does it take to create the Babel Fish, a tool that can help individuals translate speech between any two languages? While recent breakthroughs in text-based models have pushed machine translation coverage beyond 200 languages, unified speech-to-speech translation models have yet to achieve similar strides. More specifically, conventional speech-to-speech translation systems rely on cascaded systems that perform translation progressively, putting high-performing unified systems out of reach. To address these gaps, we introduce SeamlessM4T, a single model that supports speech-to-speech translation, speech-to-text translation, text-to-speech translation, text-to-text translation, and automatic speech recognition for up to 100 languages. To build this, we used 1 million hours of open speech audio data to learn self-supervised speech representations with w2v-BERT 2.0. Subsequently, we created a multimodal corpus of automatically aligned speech translations. Filtered and combined with human-labeled and pseudo-labeled data, we developed the first multilingual system capable of translating from and into English for both speech and text. On FLEURS, SeamlessM4T sets a new standard for translations into multiple target languages, achieving an improvement of 20% BLEU over the previous SOTA in direct speech-to-text translation. Compared to strong cascaded models, SeamlessM4T improves the quality of into-English translation by 1.3 BLEU points in speech-to-text and by 2.6 ASR-BLEU points in speech-to-speech. Tested for robustness, our system performs better against background noises and speaker variations in speech-to-text tasks compared to the current SOTA model. Critically, we evaluated SeamlessM4T on gender bias and added toxicity to assess translation safety. Finally, all contributions in this work are open-sourced and accessible at https://github.com/facebookresearch/seamless_communication* + +## Usage + +First, load the processor and a checkpoint of the model: + +```python +>>> from transformers import AutoProcessor, SeamlessM4TModel + +>>> processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium") +>>> model = SeamlessM4TModel.from_pretrained("facebook/hf-seamless-m4t-medium") +``` + +You can seamlessly use this model on text or on audio, to generated either translated text or translated audio. + +Here is how to use the processor to process text and audio: + +```python +>>> # let's load an audio sample from an Arabic speech corpus +>>> from datasets import load_dataset +>>> dataset = load_dataset("arabic_speech_corpus", split="test", streaming=True) +>>> audio_sample = next(iter(dataset))["audio"] + +>>> # now, process it +>>> audio_inputs = processor(audios=audio_sample["array"], return_tensors="pt") + +>>> # now, process some English test as well +>>> text_inputs = processor(text = "Hello, my dog is cute", src_lang="eng", return_tensors="pt") +``` + + +### Speech + +[`SeamlessM4TModel`] can *seamlessly* generate text or speech with few or no changes. Let's target Russian voice translation: + +```python +>>> audio_array_from_text = model.generate(**text_inputs, tgt_lang="rus")[0].cpu().numpy().squeeze() +>>> audio_array_from_audio = model.generate(**audio_inputs, tgt_lang="rus")[0].cpu().numpy().squeeze() +``` + +With basically the same code, I've translated English text and Arabic speech to Russian speech samples. + +### Text + +Similarly, you can generate translated text from audio files or from text with the same model. You only have to pass `generate_speech=False` to [`SeamlessM4TModel.generate`]. +This time, let's translate to French. + +```python +>>> # from audio +>>> output_tokens = model.generate(**audio_inputs, tgt_lang="fra", generate_speech=False) +>>> translated_text_from_audio = processor.decode(output_tokens[0].tolist()[0], skip_special_tokens=True) + +>>> # from text +>>> output_tokens = model.generate(**text_inputs, tgt_lang="fra", generate_speech=False) +>>> translated_text_from_text = processor.decode(output_tokens[0].tolist()[0], skip_special_tokens=True) +``` + +### Tips + + +#### 1. Use dedicated models + +[`SeamlessM4TModel`] is transformers top level model to generate speech and text, but you can also use dedicated models that perform the task without additional components, thus reducing the memory footprint. +For example, you can replace the audio-to-audio generation snippet with the model dedicated to the S2ST task, the rest is exactly the same code: + +```python +>>> from transformers import SeamlessM4TForSpeechToSpeech +>>> model = SeamlessM4TForSpeechToSpeech.from_pretrained("facebook/hf-seamless-m4t-medium") +``` + +Or you can replace the text-to-text generation snippet with the model dedicated to the T2TT task, you only have to remove `generate_speech=False`. + +```python +>>> from transformers import SeamlessM4TForTextToText +>>> model = SeamlessM4TForTextToText.from_pretrained("facebook/hf-seamless-m4t-medium") +``` + +Feel free to try out [`SeamlessM4TForSpeechToText`] and [`SeamlessM4TForTextToSpeech`] as well. + +#### 2. Change the speaker identity + +You have the possibility to change the speaker used for speech synthesis with the `spkr_id` argument. Some `spkr_id` works better than other for some languages! + +#### 3. Change the generation strategy + +You can use different [generation strategies](./generation_strategies) for speech and text generation, e.g `.generate(input_ids=input_ids, text_num_beams=4, speech_do_sample=True)` which will successively perform beam-search decoding on the text model, and multinomial sampling on the speech model. + +#### 4. Generate speech and text at the same time + +Use `return_intermediate_token_ids=True` with [`SeamlessM4TModel`] to return both speech and text ! + +## Model architecture + + +SeamlessM4T features a versatile architecture that smoothly handles the sequential generation of text and speech. This setup comprises two sequence-to-sequence (seq2seq) models. The first model translates the input modality into translated text, while the second model generates speech tokens, known as "unit tokens," from the translated text. + +Each modality has its own dedicated encoder with a unique architecture. Additionally, for speech output, a vocoder inspired by the [HiFi-GAN](https://arxiv.org/abs/2010.05646) architecture is placed on top of the second seq2seq model. + +Here's how the generation process works: + +- Input text or speech is processed through its specific encoder. +- A decoder creates text tokens in the desired language. +- If speech generation is required, the second seq2seq model, following a standard encoder-decoder structure, generates unit tokens. +- These unit tokens are then passed through the final vocoder to produce the actual speech. + + +This model was contributed by [ylacombe](https://huggingface.co/ylacombe). The original code can be found [here](https://github.com/facebookresearch/seamless_communication). + +## SeamlessM4TModel + +[[autodoc]] SeamlessM4TModel + - generate + + +## SeamlessM4TForTextToSpeech + +[[autodoc]] SeamlessM4TForTextToSpeech + - generate + + +## SeamlessM4TForSpeechToSpeech + +[[autodoc]] SeamlessM4TForSpeechToSpeech + - generate + + +## SeamlessM4TForTextToText + +[[autodoc]] transformers.SeamlessM4TForTextToText + - forward + - generate + +## SeamlessM4TForSpeechToText + +[[autodoc]] transformers.SeamlessM4TForSpeechToText + - forward + - generate + +## SeamlessM4TConfig + +[[autodoc]] SeamlessM4TConfig + + +## SeamlessM4TTokenizer + +[[autodoc]] SeamlessM4TTokenizer + - __call__ + - build_inputs_with_special_tokens + - get_special_tokens_mask + - create_token_type_ids_from_sequences + - save_vocabulary + + +## SeamlessM4TTokenizerFast + +[[autodoc]] SeamlessM4TTokenizerFast + - __call__ + +## SeamlessM4TFeatureExtractor + +[[autodoc]] SeamlessM4TFeatureExtractor + - __call__ + +## SeamlessM4TProcessor + +[[autodoc]] SeamlessM4TProcessor + - __call__ + +## SeamlessM4TCodeHifiGan + +[[autodoc]] SeamlessM4TCodeHifiGan + + +## SeamlessM4THifiGan + +[[autodoc]] SeamlessM4THifiGan + +## SeamlessM4TTextToUnitModel + +[[autodoc]] SeamlessM4TTextToUnitModel + +## SeamlessM4TTextToUnitForConditionalGeneration + +[[autodoc]] SeamlessM4TTextToUnitForConditionalGeneration + + diff --git a/docs/source/en/tasks/summarization.md b/docs/source/en/tasks/summarization.md index ecdf37ce6efbba..a7d6a446c0c3e4 100644 --- a/docs/source/en/tasks/summarization.md +++ b/docs/source/en/tasks/summarization.md @@ -35,7 +35,7 @@ The task illustrated in this tutorial is supported by the following model archit -[BART](../model_doc/bart), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [Encoder decoder](../model_doc/encoder-decoder), [FairSeq Machine-Translation](../model_doc/fsmt), [GPTSAN-japanese](../model_doc/gptsan-japanese), [LED](../model_doc/led), [LongT5](../model_doc/longt5), [M2M100](../model_doc/m2m_100), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [NLLB](../model_doc/nllb), [NLLB-MOE](../model_doc/nllb-moe), [Pegasus](../model_doc/pegasus), [PEGASUS-X](../model_doc/pegasus_x), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [SwitchTransformers](../model_doc/switch_transformers), [T5](../model_doc/t5), [UMT5](../model_doc/umt5), [XLM-ProphetNet](../model_doc/xlm-prophetnet) +[BART](../model_doc/bart), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [Encoder decoder](../model_doc/encoder-decoder), [FairSeq Machine-Translation](../model_doc/fsmt), [GPTSAN-japanese](../model_doc/gptsan-japanese), [LED](../model_doc/led), [LongT5](../model_doc/longt5), [M2M100](../model_doc/m2m_100), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [NLLB](../model_doc/nllb), [NLLB-MOE](../model_doc/nllb-moe), [Pegasus](../model_doc/pegasus), [PEGASUS-X](../model_doc/pegasus_x), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [SeamlessM4T](../model_doc/seamless_m4t), [SwitchTransformers](../model_doc/switch_transformers), [T5](../model_doc/t5), [UMT5](../model_doc/umt5), [XLM-ProphetNet](../model_doc/xlm-prophetnet) diff --git a/docs/source/en/tasks/translation.md b/docs/source/en/tasks/translation.md index d5394caef838a2..c17e3db23f2e8a 100644 --- a/docs/source/en/tasks/translation.md +++ b/docs/source/en/tasks/translation.md @@ -32,7 +32,7 @@ The task illustrated in this tutorial is supported by the following model archit -[BART](../model_doc/bart), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [Encoder decoder](../model_doc/encoder-decoder), [FairSeq Machine-Translation](../model_doc/fsmt), [GPTSAN-japanese](../model_doc/gptsan-japanese), [LED](../model_doc/led), [LongT5](../model_doc/longt5), [M2M100](../model_doc/m2m_100), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [NLLB](../model_doc/nllb), [NLLB-MOE](../model_doc/nllb-moe), [Pegasus](../model_doc/pegasus), [PEGASUS-X](../model_doc/pegasus_x), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [SwitchTransformers](../model_doc/switch_transformers), [T5](../model_doc/t5), [UMT5](../model_doc/umt5), [XLM-ProphetNet](../model_doc/xlm-prophetnet) +[BART](../model_doc/bart), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [Encoder decoder](../model_doc/encoder-decoder), [FairSeq Machine-Translation](../model_doc/fsmt), [GPTSAN-japanese](../model_doc/gptsan-japanese), [LED](../model_doc/led), [LongT5](../model_doc/longt5), [M2M100](../model_doc/m2m_100), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [NLLB](../model_doc/nllb), [NLLB-MOE](../model_doc/nllb-moe), [Pegasus](../model_doc/pegasus), [PEGASUS-X](../model_doc/pegasus_x), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [SeamlessM4T](../model_doc/seamless_m4t), [SwitchTransformers](../model_doc/switch_transformers), [T5](../model_doc/t5), [UMT5](../model_doc/umt5), [XLM-ProphetNet](../model_doc/xlm-prophetnet) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3325a2c77b008c..8001bb33c81408 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -517,6 +517,12 @@ "SamPromptEncoderConfig", "SamVisionConfig", ], + "models.seamless_m4t": [ + "SEAMLESS_M4T_PRETRAINED_CONFIG_ARCHIVE_MAP", + "SeamlessM4TConfig", + "SeamlessM4TFeatureExtractor", + "SeamlessM4TProcessor", + ], "models.segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig"], "models.sew": ["SEW_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWConfig"], "models.sew_d": ["SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWDConfig"], @@ -805,6 +811,7 @@ _import_structure["models.plbart"].append("PLBartTokenizer") _import_structure["models.reformer"].append("ReformerTokenizer") _import_structure["models.rembert"].append("RemBertTokenizer") + _import_structure["models.seamless_m4t"].append("SeamlessM4TTokenizer") _import_structure["models.speech_to_text"].append("Speech2TextTokenizer") _import_structure["models.speecht5"].append("SpeechT5Tokenizer") _import_structure["models.t5"].append("T5Tokenizer") @@ -877,6 +884,7 @@ _import_structure["models.rembert"].append("RemBertTokenizerFast") _import_structure["models.roberta"].append("RobertaTokenizerFast") _import_structure["models.roformer"].append("RoFormerTokenizerFast") + _import_structure["models.seamless_m4t"].append("SeamlessM4TTokenizerFast") _import_structure["models.splinter"].append("SplinterTokenizerFast") _import_structure["models.squeezebert"].append("SqueezeBertTokenizerFast") _import_structure["models.t5"].append("T5TokenizerFast") @@ -1082,6 +1090,7 @@ _import_structure["modeling_utils"] = ["PreTrainedModel"] # PyTorch models structure + _import_structure["models.albert"].extend( [ "ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2683,6 +2692,21 @@ "SamPreTrainedModel", ] ) + _import_structure["models.seamless_m4t"].extend( + [ + "SEAMLESS_M4T_PRETRAINED_MODEL_ARCHIVE_LIST", + "SeamlessM4TCodeHifiGan", + "SeamlessM4TForSpeechToSpeech", + "SeamlessM4TForSpeechToText", + "SeamlessM4TForTextToSpeech", + "SeamlessM4TForTextToText", + "SeamlessM4THifiGan", + "SeamlessM4TModel", + "SeamlessM4TPreTrainedModel", + "SeamlessM4TTextToUnitForConditionalGeneration", + "SeamlessM4TTextToUnitModel", + ] + ) _import_structure["models.segformer"].extend( [ "SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -4658,6 +4682,12 @@ SamPromptEncoderConfig, SamVisionConfig, ) + from .models.seamless_m4t import ( + SEAMLESS_M4T_PRETRAINED_CONFIG_ARCHIVE_MAP, + SeamlessM4TConfig, + SeamlessM4TFeatureExtractor, + SeamlessM4TProcessor, + ) from .models.segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig from .models.sew import SEW_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWConfig from .models.sew_d import SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWDConfig @@ -4925,6 +4955,7 @@ from .models.plbart import PLBartTokenizer from .models.reformer import ReformerTokenizer from .models.rembert import RemBertTokenizer + from .models.seamless_m4t import SeamlessM4TTokenizer from .models.speech_to_text import Speech2TextTokenizer from .models.speecht5 import SpeechT5Tokenizer from .models.t5 import T5Tokenizer @@ -4990,6 +5021,7 @@ from .models.rembert import RemBertTokenizerFast from .models.roberta import RobertaTokenizerFast from .models.roformer import RoFormerTokenizerFast + from .models.seamless_m4t import SeamlessM4TTokenizerFast from .models.splinter import SplinterTokenizerFast from .models.squeezebert import SqueezeBertTokenizerFast from .models.t5 import T5TokenizerFast @@ -5157,8 +5189,6 @@ top_k_top_p_filtering, ) from .modeling_utils import PreTrainedModel - - # PyTorch model imports from .models.albert import ( ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST, AlbertForMaskedLM, @@ -6485,6 +6515,21 @@ SamModel, SamPreTrainedModel, ) + + # PyTorch model imports + from .models.seamless_m4t import ( + SEAMLESS_M4T_PRETRAINED_MODEL_ARCHIVE_LIST, + SeamlessM4TCodeHifiGan, + SeamlessM4TForSpeechToSpeech, + SeamlessM4TForSpeechToText, + SeamlessM4TForTextToSpeech, + SeamlessM4TForTextToText, + SeamlessM4THifiGan, + SeamlessM4TModel, + SeamlessM4TPreTrainedModel, + SeamlessM4TTextToUnitForConditionalGeneration, + SeamlessM4TTextToUnitModel, + ) from .models.segformer import ( SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, SegformerDecodeHead, diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index 2697d04730efc6..04f3e7e1c58776 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -775,6 +775,31 @@ def post_processor(self): ) +class SeamlessM4TConverter(SpmConverter): + def vocab(self, proto): + vocab = [ + ("", 0.0), + ("", 0.0), + ("", 0.0), + ("", 0.0), + ] + vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] + return vocab + + def unk_id(self, proto): + return self.original_tokenizer.unk_token_id + + def post_processor(self): + return processors.TemplateProcessing( + single="__eng__ $A ", + pair="__eng__ $A $B ", + special_tokens=[ + ("__eng__", self.original_tokenizer.convert_tokens_to_ids("__eng__")), + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ], + ) + + class XLMRobertaConverter(SpmConverter): def vocab(self, proto): vocab = [ @@ -1278,6 +1303,7 @@ def converted(self) -> Tokenizer: "RetriBertTokenizer": BertConverter, "RobertaTokenizer": RobertaConverter, "RoFormerTokenizer": RoFormerConverter, + "SeamlessM4TTokenizer": SeamlessM4TConverter, "SqueezeBertTokenizer": BertConverter, "T5Tokenizer": T5Converter, "WhisperTokenizer": WhisperConverter, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 074bade6929cdd..4093ff819e8026 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -180,6 +180,7 @@ roformer, rwkv, sam, + seamless_m4t, segformer, sew, sew_d, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 0328426d67468b..7308dd27a65aca 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -186,6 +186,7 @@ ("roformer", "RoFormerConfig"), ("rwkv", "RwkvConfig"), ("sam", "SamConfig"), + ("seamless_m4t", "SeamlessM4TConfig"), ("segformer", "SegformerConfig"), ("sew", "SEWConfig"), ("sew-d", "SEWDConfig"), @@ -392,6 +393,7 @@ ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("rwkv", "RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("sam", "SAM_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("seamless_m4t", "SEAMLESS_M4T_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("sew", "SEW_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("sew-d", "SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -622,6 +624,7 @@ ("roformer", "RoFormer"), ("rwkv", "RWKV"), ("sam", "SAM"), + ("seamless_m4t", "SeamlessM4T"), ("segformer", "SegFormer"), ("sew", "SEW"), ("sew-d", "SEW-D"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index befca6a64b81b7..40cdec46bcf642 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -76,6 +76,7 @@ ("pop2piano", "Pop2PianoFeatureExtractor"), ("regnet", "ConvNextFeatureExtractor"), ("resnet", "ConvNextFeatureExtractor"), + ("seamless_m4t", "SeamlessM4TFeatureExtractor"), ("segformer", "SegformerFeatureExtractor"), ("sew", "Wav2Vec2FeatureExtractor"), ("sew-d", "Wav2Vec2FeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index c6fc1dce1b5833..f236b8fc5f58f3 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -175,6 +175,7 @@ ("roformer", "RoFormerModel"), ("rwkv", "RwkvModel"), ("sam", "SamModel"), + ("seamless_m4t", "SeamlessM4TModel"), ("segformer", "SegformerModel"), ("sew", "SEWModel"), ("sew-d", "SEWDModel"), @@ -674,6 +675,7 @@ ("pegasus_x", "PegasusXForConditionalGeneration"), ("plbart", "PLBartForConditionalGeneration"), ("prophetnet", "ProphetNetForConditionalGeneration"), + ("seamless_m4t", "SeamlessM4TForTextToText"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("umt5", "UMT5ForConditionalGeneration"), @@ -684,6 +686,7 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( [ ("pop2piano", "Pop2PianoForConditionalGeneration"), + ("seamless_m4t", "SeamlessM4TForSpeechToText"), ("speech-encoder-decoder", "SpeechEncoderDecoderModel"), ("speech_to_text", "Speech2TextForConditionalGeneration"), ("speecht5", "SpeechT5ForSpeechToText"), @@ -1047,6 +1050,7 @@ # Model for Text-To-Waveform mapping ("bark", "BarkModel"), ("musicgen", "MusicgenForConditionalGeneration"), + ("seamless_m4t", "SeamlessM4TForTextToSpeech"), ("vits", "VitsModel"), ] ) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index e7cb9aa11a7330..03f177ac4b0325 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -71,6 +71,7 @@ ("pix2struct", "Pix2StructProcessor"), ("pop2piano", "Pop2PianoProcessor"), ("sam", "SamProcessor"), + ("seamless_m4t", "SeamlessM4TProcessor"), ("sew", "Wav2Vec2Processor"), ("sew-d", "Wav2Vec2Processor"), ("speech_to_text", "Speech2TextProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 78a363f8ab20c8..dff28af0906eb1 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -329,6 +329,13 @@ ("roc_bert", ("RoCBertTokenizer", None)), ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), ("rwkv", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ( + "seamless_m4t", + ( + "SeamlessM4TTokenizer" if is_sentencepiece_available() else None, + "SeamlessM4TTokenizerFast" if is_tokenizers_available() else None, + ), + ), ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)), ("speech_to_text_2", ("Speech2Text2Tokenizer", None)), ("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)), diff --git a/src/transformers/models/seamless_m4t/__init__.py b/src/transformers/models/seamless_m4t/__init__.py new file mode 100644 index 00000000000000..3167311a5a6ef7 --- /dev/null +++ b/src/transformers/models/seamless_m4t/__init__.py @@ -0,0 +1,111 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_sentencepiece_available, + is_tokenizers_available, + is_torch_available, +) + + +_import_structure = { + "configuration_seamless_m4t": ["SEAMLESS_M4T_PRETRAINED_CONFIG_ARCHIVE_MAP", "SeamlessM4TConfig"], + "feature_extraction_seamless_m4t": ["SeamlessM4TFeatureExtractor"], + "processing_seamless_m4t": ["SeamlessM4TProcessor"], +} + +try: + if not is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_seamless_m4t"] = ["SeamlessM4TTokenizer"] + +try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_seamless_m4t_fast"] = ["SeamlessM4TTokenizerFast"] + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_seamless_m4t"] = [ + "SEAMLESS_M4T_PRETRAINED_MODEL_ARCHIVE_LIST", + "SeamlessM4TForTextToSpeech", + "SeamlessM4TForSpeechToSpeech", + "SeamlessM4TForTextToText", + "SeamlessM4TForSpeechToText", + "SeamlessM4TModel", + "SeamlessM4TPreTrainedModel", + "SeamlessM4TCodeHifiGan", + "SeamlessM4THifiGan", + "SeamlessM4TTextToUnitForConditionalGeneration", + "SeamlessM4TTextToUnitModel", + ] + +if TYPE_CHECKING: + from .configuration_seamless_m4t import SEAMLESS_M4T_PRETRAINED_CONFIG_ARCHIVE_MAP, SeamlessM4TConfig + from .feature_extraction_seamless_m4t import SeamlessM4TFeatureExtractor + from .processing_seamless_m4t import SeamlessM4TProcessor + + try: + if not is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tokenization_seamless_m4t import SeamlessM4TTokenizer + + try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tokenization_seamless_m4t_fast import SeamlessM4TTokenizerFast + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_seamless_m4t import ( + SEAMLESS_M4T_PRETRAINED_MODEL_ARCHIVE_LIST, + SeamlessM4TCodeHifiGan, + SeamlessM4TForSpeechToSpeech, + SeamlessM4TForSpeechToText, + SeamlessM4TForTextToSpeech, + SeamlessM4TForTextToText, + SeamlessM4THifiGan, + SeamlessM4TModel, + SeamlessM4TPreTrainedModel, + SeamlessM4TTextToUnitForConditionalGeneration, + SeamlessM4TTextToUnitModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/seamless_m4t/configuration_seamless_m4t.py b/src/transformers/models/seamless_m4t/configuration_seamless_m4t.py new file mode 100644 index 00000000000000..c3296cfabc76d0 --- /dev/null +++ b/src/transformers/models/seamless_m4t/configuration_seamless_m4t.py @@ -0,0 +1,417 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" SeamlessM4T model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +SEAMLESS_M4T_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "facebook/hf-seamless-m4t-medium": "https://huggingface.co/facebook/hf-seamless-m4t-medium/resolve/main/config.json", + # See all SeamlessM4T models at https://huggingface.co/models?filter=seamless_m4t +} + + +class SeamlessM4TConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~SeamlessM4TModel`]. It is used to instantiate an + SeamlessM4T model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the SeamlessM4T + ["facebook/hf-seamless-m4t-medium"](https://huggingface.co/"facebook/hf-seamless-m4t-medium") architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 256102): + Vocabulary size of the SeamlessM4T model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`~SeamlessM4TModel`], [`~SeamlessM4TForTextToSpeech`] or + [`~SeamlessM4TForTextToText`]. + t2u_vocab_size (`int`, *optional*, defaults to 10082): + Unit vocabulary size of the SeamlessM4T model. Defines the number of different unit tokens that can be + represented by the `inputs_ids` passed when calling the Text-To-Units sub-model of [`~SeamlessM4TModel`], + [`~SeamlessM4TForSpeechToSpeech`] or [`~SeamlessM4TForTextToSpeech`]. + + > Parameters shared across sub-models + + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the "intermediate" layers in the architecture. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + max_position_embeddings (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model text encoder and decoder might ever be used with. Typically set + this to something large just in case (e.g., 512 or 1024 or 2048). + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether the model is used as an encoder/decoder or not. + encoder_layerdrop (`float`, *optional*, defaults to 0.05): + The LayerDrop probability for the encoders. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + decoder_layerdrop (`float`, *optional*, defaults to 0.05): + The LayerDrop probability for the decoders. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + activation_function (`str` or `function`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the decoder and feed-forward layers. If string, + `"gelu"`, `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, decoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all attention layers. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all activation layers in the model. + scale_embedding (`bool`, *optional*, defaults to `True`): + Scale embeddings by diving by sqrt(d_model). + + > Text encoder and text decoder specific parameters + + encoder_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer text encoder. + encoder_ffn_dim (`int`, *optional*, defaults to 8192): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text encoder. + encoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer text encoder. + decoder_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer text decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 8192): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text decoder. + decoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer text decoder. + decoder_start_token_id (`int`, *optional*, defaults to 3): + If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token. Only + applied in the text decoder. + max_new_tokens (`int`, *optional*, defaults to 256): + The maximum numbers of text tokens to generate, ignoring the number of tokens in the prompt. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the _padding_ text token. Only applied to the text-decoder model. + bos_token_id (`int`, *optional*, defaults to 2): + The id of the _beginning-of-stream_ text token. Only applied to the text-decoder model. + eos_token_id (`int`, *optional*, defaults to 3): + The id of the _end-of-stream_ text token. Only applied to the text-decoder model. + + > Speech encoder specific parameters + + speech_encoder_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer speech encoder. + speech_encoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer speech encoder. + speech_encoder_intermediate_size (`int`, *optional*, defaults to 4096): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer speech encoder. + speech_encoder_hidden_act (`str` or `function`, *optional*, defaults to `"swish"`): + The non-linear activation function (function or string) in the speech encoder. If string, `"gelu"`, + `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported. + speech_encoder_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all layers in the speech encoder. + add_adapter (`bool`, *optional*, defaults to `True`): + Add an adapter layer on top of the speech encoder. + speech_encoder_layerdrop (`float`, *optional*, defaults to 0.1): + The LayerDrop probability for the speech encoder. See the [LayerDrop paper](see + https://arxiv.org/abs/1909.11556) for more details. + feature_projection_input_dim (`int`, *optional*, defaults to 160): + Input dimension of the input feature projection of the speech encoder, i.e the dimension after processing + input audios with [`SeamlessM4TFeatureExtractor`]. + num_conv_pos_embeddings (`int`, *optional*, defaults to 128): + Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional + embeddings layer of the speech encoder. + num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16): + Number of groups of 1D convolutional positional embeddings layer of the speech encoder. + adaptor_kernel_size (`int`, *optional*, defaults to 8): + Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`. + adaptor_stride (`int`, *optional*, defaults to 8): + Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`. + adaptor_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all layers in the speech adapter. + num_adapter_layers (`int`, *optional*, defaults to 1): + Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is + True`. + position_embeddings_type (`str`, *optional*, defaults to `"relative"`): + Can be specified to `relative` or `rotary` for relative or rotary position embeddings respectively. If left + `None` no relative position embedding is applied. Only applied to the speech encoder. + rotary_embedding_base (`int`, *optional*, defaults to 10000): + If `"rotary"` position embeddings are used, defines the size of the embedding base. Only applied to the + speech encoder. + max_source_positions (`int`, *optional*, defaults to 4096): + if `"relative"` position embeddings are used, defines the maximum source input positions. Only applied to + the speech encoder. + conv_depthwise_kernel_size (`int`, *optional*, defaults to 31): + Kernel size of convolutional depthwise 1D layer in Conformer blocks. Only applied to the speech encoder. + + > Text-To-Unit (t2u) model specific parameters + + t2u_bos_token_id (`int`, *optional*, defaults to 0): + The id of the _beginning-of-stream_ unit token. Only applied to the text-to-unit seq2seq model. + t2u_pad_token_id (`int`, *optional*, defaults to 1): + The id of the _padding_ unit token. Only applied to the text-to-unit seq2seq model. + t2u_eos_token_id (`int`, *optional*, defaults to 2): + The id of the _end-of-stream_ unit token. Only applied to the text-to-unit seq2seq model. + t2u_decoder_start_token_id (`int`, *optional*, defaults to 2): + If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token. Only + applied to the text-to-unit seq2seq model. + t2u_max_new_tokens (`int`, *optional*, defaults to 1024): + The maximum numbers of unit tokens to generate, ignoring the number of tokens in the prompt. Only applied + to the text-to-unit seq2seq model. + t2u_encoder_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer text-to-unit encoder. + t2u_encoder_ffn_dim (`int`, *optional*, defaults to 8192): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit encoder. + t2u_encoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer text-to-unit encoder. + t2u_decoder_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer text-to-unit decoder. + t2u_decoder_ffn_dim (`int`, *optional*, defaults to 8192): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer text-to-unit decoder. + t2u_decoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer text-to-unit decoder. + t2u_max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model text-to-unit component might ever be used with. Typically set + this to something large just in case (e.g., 512 or 1024 or 2048). + + > Hifi-Gan Vocoder specific parameters + + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the output audio will be generated, expressed in hertz (Hz). + upsample_initial_channel (`int`, *optional*, defaults to 512): + The number of input channels into the hifi-gan upsampling network. Applies to the vocoder only. + upsample_rates (`Tuple[int]` or `List[int]`, *optional*, defaults to `[5, 4, 4, 2, 2]`): + A tuple of integers defining the stride of each 1D convolutional layer in the vocoder upsampling network. + The length of *upsample_rates* defines the number of convolutional layers and has to match the length of + *upsample_kernel_sizes*. Applies to the vocoder only. + upsample_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[11, 8, 8, 4, 4]`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the vocoder upsampling + network. The length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match + the length of *upsample_rates*. Applies to the vocoder only. + resblock_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[3, 7, 11]`): + A tuple of integers defining the kernel sizes of the vocoder 1D convolutional layers in the multi-receptive + field fusion (MRF) module. Applies to the vocoder only. + resblock_dilation_sizes (`Tuple[Tuple[int]]` or `List[List[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`): + A nested tuple of integers defining the dilation rates of the vocoder dilated 1D convolutional layers in + the multi-receptive field fusion (MRF) module. Applies to the vocoder only. + leaky_relu_slope (`float`, *optional*, defaults to 0.1): + The angle of the negative slope used by the leaky ReLU activation in the vocoder. Applies to the vocoder + only. + unit_hifi_gan_vocab_size (`int`, *optional*, defaults to 10000): + Vocabulary size of the SeamlessM4T vocoder. Defines the number of different unit tokens that can be + represented by the `inputs_ids` passed when calling the vocoder of [`~SeamlessM4TModel`], + [`~SeamlessM4TForSpeechToSpeech`] or [`~SeamlessM4TForTextToSpeech`]. + unit_embed_dim (`int`, *optional*, defaults to 1280): + The projection dimension of the input ids given to the hifi-gan vocoder. Applies to the vocoder only. + lang_embed_dim (`int`, *optional*, defaults to 256): + The projection dimension of the target language given to the hifi-gan vocoder. Applies to the vocoder only. + spkr_embed_dim (`int`, *optional*, defaults to 256): + The projection dimension of the speaker id given to the hifi-gan vocoder. Applies to the vocoder only. + vocoder_num_langs (`int`, *optional*, defaults to 36): + Number of langs supported by the vocoder. Might be different from `t2u_num_langs`. + vocoder_num_spkrs (`int`, *optional*, defaults to 200): + Number of speakers supported by the vocoder. + variance_predictor_kernel_size (`int`, *optional*, defaults to 3): + Kernel size of the duration predictor. Applies to the vocoder only. + var_pred_dropout (`float`, *optional*, defaults to 0.5): + The dropout probabilitiy of the duration predictor. Applies to the vocoder only. + vocoder_offset (`int`, *optional*, defaults to 4): + Offset the unit token ids by this number to account for symbol tokens. Applies to the vocoder only. + + ```python + >>> from transformers import SeamlessM4TModel, SeamlessM4TConfig + + >>> # Initializing a SeamlessM4T "facebook/hf-seamless-m4t-medium" style configuration + >>> configuration = SeamlessM4TConfig() + + >>> # Initializing a model from the "facebook/hf-seamless-m4t-medium" style configuration + >>> model = SeamlessM4TModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "seamless_m4t" + + def __init__( + self, + vocab_size=256102, + t2u_vocab_size=10082, + # shared config + hidden_size=1024, + initializer_range=0.02, + layer_norm_eps=1e-5, + use_cache=True, + max_position_embeddings=1024, + is_encoder_decoder=True, + encoder_layerdrop=0.05, + decoder_layerdrop=0.05, + activation_function="relu", + dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.0, + scale_embedding=True, + # text encoder|decoder + encoder_layers=24, + encoder_ffn_dim=8192, + encoder_attention_heads=16, + decoder_layers=24, + decoder_ffn_dim=8192, + decoder_attention_heads=16, + decoder_start_token_id=3, + max_new_tokens=256, + pad_token_id=0, + bos_token_id=2, + eos_token_id=3, + # speech_encoder + speech_encoder_layers=24, + speech_encoder_attention_heads=16, + speech_encoder_intermediate_size=4096, + speech_encoder_hidden_act="swish", + speech_encoder_dropout=0.0, + add_adapter=True, + speech_encoder_layerdrop=0.1, + feature_projection_input_dim=160, + num_conv_pos_embeddings=128, + num_conv_pos_embedding_groups=16, + adaptor_kernel_size=8, + adaptor_stride=8, + adaptor_dropout=0.1, + num_adapter_layers=1, + position_embeddings_type="relative", + rotary_embedding_base=10000, + max_source_positions=4096, + conv_depthwise_kernel_size=31, + # t2u config + t2u_bos_token_id=0, + t2u_pad_token_id=1, + t2u_eos_token_id=2, + t2u_decoder_start_token_id=2, + t2u_max_new_tokens=1024, + t2u_encoder_layers=6, + t2u_encoder_ffn_dim=8192, + t2u_encoder_attention_heads=16, + t2u_decoder_layers=6, + t2u_decoder_ffn_dim=8192, + t2u_decoder_attention_heads=16, + t2u_max_position_embeddings=2048, + # hifi-gan vocoder config + sampling_rate=16000, + upsample_initial_channel=512, + upsample_rates=[5, 4, 4, 2, 2], + upsample_kernel_sizes=[11, 8, 8, 4, 4], + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + leaky_relu_slope=0.1, + # specific to Code Hifi-Gan + unit_hifi_gan_vocab_size=10000, + unit_embed_dim=1280, + lang_embed_dim=256, + spkr_embed_dim=256, + vocoder_num_langs=36, + vocoder_num_spkrs=200, + variance_predictor_kernel_size=3, + var_pred_dropout=0.5, + vocoder_offset=4, + **kwargs, + ): + # overall_config + self.vocab_size = vocab_size + self.t2u_vocab_size = t2u_vocab_size + self.hidden_size = hidden_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.max_position_embeddings = max_position_embeddings + self.use_cache = use_cache + self.max_new_tokens = max_new_tokens + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.activation_function = activation_function + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.scale_embedding = scale_embedding + # for proper config init + self.num_attention_heads = decoder_attention_heads + self.num_hidden_layers = decoder_layers + + # text|unit encoder|decoder + self.encoder_layers = encoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_attention_heads = decoder_attention_heads + + # speech_encoder + self.speech_encoder_layers = speech_encoder_layers + self.speech_encoder_hidden_act = speech_encoder_hidden_act + self.speech_encoder_dropout = speech_encoder_dropout + self.speech_encoder_attention_heads = speech_encoder_attention_heads + self.speech_encoder_layerdrop = speech_encoder_layerdrop + self.speech_encoder_intermediate_size = speech_encoder_intermediate_size + self.feature_projection_input_dim = feature_projection_input_dim + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.adaptor_kernel_size = adaptor_kernel_size + self.adaptor_stride = adaptor_stride + self.adaptor_dropout = adaptor_dropout + self.num_adapter_layers = num_adapter_layers + self.position_embeddings_type = position_embeddings_type + self.rotary_embedding_base = rotary_embedding_base + self.max_source_positions = max_source_positions + self.conv_depthwise_kernel_size = conv_depthwise_kernel_size + self.add_adapter = add_adapter + + # t2u config + self.t2u_bos_token_id = t2u_bos_token_id + self.t2u_pad_token_id = t2u_pad_token_id + self.t2u_eos_token_id = t2u_eos_token_id + self.t2u_decoder_start_token_id = t2u_decoder_start_token_id + self.t2u_max_new_tokens = t2u_max_new_tokens + self.t2u_encoder_layers = t2u_encoder_layers + self.t2u_encoder_ffn_dim = t2u_encoder_ffn_dim + self.t2u_encoder_attention_heads = t2u_encoder_attention_heads + self.t2u_decoder_layers = t2u_decoder_layers + self.t2u_decoder_ffn_dim = t2u_decoder_ffn_dim + self.t2u_decoder_attention_heads = t2u_decoder_attention_heads + self.t2u_max_position_embeddings = t2u_max_position_embeddings + + # hifi-gan vocoder config + # original parameters specific to Hifi-Gan + self.sampling_rate = sampling_rate + self.upsample_initial_channel = upsample_initial_channel + self.upsample_rates = upsample_rates + self.upsample_kernel_sizes = upsample_kernel_sizes + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.leaky_relu_slope = leaky_relu_slope + + # specific to Code Hifi-Gan + self.unit_hifi_gan_vocab_size = unit_hifi_gan_vocab_size + self.unit_embed_dim = unit_embed_dim + self.lang_embed_dim = lang_embed_dim + self.spkr_embed_dim = spkr_embed_dim + self.vocoder_num_langs = vocoder_num_langs + self.vocoder_num_spkrs = vocoder_num_spkrs + self.variance_predictor_kernel_size = variance_predictor_kernel_size + self.var_pred_dropout = var_pred_dropout + self.vocoder_offset = vocoder_offset + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + decoder_start_token_id=decoder_start_token_id, + is_encoder_decoder=is_encoder_decoder, + max_position_embeddings=max_position_embeddings, + **kwargs, + ) diff --git a/src/transformers/models/seamless_m4t/convert_fairseq2_to_hf.py b/src/transformers/models/seamless_m4t/convert_fairseq2_to_hf.py new file mode 100644 index 00000000000000..e97ca3046a9817 --- /dev/null +++ b/src/transformers/models/seamless_m4t/convert_fairseq2_to_hf.py @@ -0,0 +1,410 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Converting Meta SeamlessM4T checkpoints from seamless_communication to HF.""" + + +import argparse +import os +from pathlib import Path + +import torch +from accelerate.utils.modeling import find_tied_parameters +from seamless_communication.models.inference.translator import Translator + +from transformers import ( + SeamlessM4TConfig, + SeamlessM4TFeatureExtractor, + SeamlessM4TModel, + SeamlessM4TProcessor, + SeamlessM4TTokenizer, +) +from transformers.utils import logging + + +# fmt: off +UNIT_SUPPORTED_LANGUAGES = ["__arb__", "__ben__", "__cat__", "__ces__", "__cmn__", "__cym__", "__dan__", "__deu__", "__eng__", "__est__", "__fin__", "__fra__", "__hin__", "__ind__", "__ita__", "__jpn__", "__kan__", "__kor__", "__mlt__", "__nld__", "__pes__", "__pol__", "__por__", "__ron__", "__rus__", "__slk__", "__spa__", "__swe__", "__swh__", "__tam__", "__tel__", "__tgl__", "__tha__", "__tur__", "__ukr__", "__urd__", "__uzn__", "__vie__", ] +# fmt: on + +# fmt: off +VOCODER_SUPPORTED_LANGUAGES = ["__arb__", "__ben__", "__cat__", "__ces__", "__cmn__", "__cym__", "__dan__", "__deu__", "__eng__", "__est__", "__fin__", "__fra__", "__hin__", "__ind__", "__ita__", "__jpn__", "__kor__", "__mlt__", "__nld__", "__pes__", "__pol__", "__por__", "__ron__", "__rus__", "__slk__", "__spa__", "__swe__", "__swh__", "__tel__", "__tgl__", "__tha__", "__tur__", "__ukr__", "__urd__", "__uzn__", "__vie__",] +# fmt: on + + +# fmt: off +MEDIUM_SUPPORTED_LANGUAGES = ["ace","ace_Latn","acm","acq","aeb","afr","ajp","aka","amh","apc","arb","ars","ary","arz","asm","ast","awa","ayr","azb","azj","bak","bam","ban","bel","bem","ben","bho","bjn","bjn_Latn","bod","bos","bug","bul","cat","ceb","ces","cjk","ckb","crh","cym","dan","deu","dik","dyu","dzo","ell","eng","epo","est","eus","ewe","fao","pes","fij","fin","fon","fra","fur","fuv","gla","gle","glg","grn","guj","hat","hau","heb","hin","hne","hrv","hun","hye","ibo","ilo","ind","isl","ita","jav","jpn","kab","kac","kam","kan","kas","kas_Deva","kat","knc","knc_Latn","kaz","kbp","kea","khm","kik","kin","kir","kmb","kon","kor","kmr","lao","lvs","lij","lim","lin","lit","lmo","ltg","ltz","lua","lug","luo","lus","mag","mai","mal","mar","min","mkd","plt","mlt","mni","khk","mos","mri","zsm","mya","nld","nno","nob","npi","nso","nus","nya","oci","gaz","ory","pag","pan","pap","pol","por","prs","pbt","quy","ron","run","rus","sag","san","sat","scn","shn","sin","slk","slv","smo","sna","snd","som","sot","spa","als","srd","srp","ssw","sun","swe","swh","szl","tam","tat","tel","tgk","tgl","tha","tir","taq","taq_Tfng","tpi","tsn","tso","tuk","tum","tur","twi","tzm","uig","ukr","umb","urd","uzn","vec","vie","war","wol","xho","ydd","yor","yue","cmn","cmn_Hant","zul",] +# fmt: on + + +# fmt: off +LARGE_SUPPORTED_LANGUAGES = ["afr","amh","arb","ary","arz","asm","azj","bel","ben","bos","bul","cat","ceb","ces","ckb","cmn","cmn_Hant","cym","dan","deu","ell","eng","est","eus","fin","fra","fuv","gaz","gle","glg","guj","heb","hin","hrv","hun","hye","ibo","ind","isl","ita","jav","jpn","kan","kat","kaz","khk","khm","kir","kor","lao","lit","lug","luo","lvs","mai","mal","mar","mkd","mlt","mni","mya","nld","nno","nob","npi","nya","ory","pan","pbt","pes","pol","por","ron","rus","sat","slk","slv","sna","snd","som","spa","srp","swe","swh","tam","tel","tgk","tgl","tha","tur","ukr","urd","uzn","vie","yor","yue","zlm","zul",] +# fmt: on + + +def assert_param_count(model_1, model_2): + count_1 = sum(p[1].numel() for p in model_1.named_parameters() if "final_proj" not in p[0]) + count_2 = sum(p[1].numel() for p in model_2.named_parameters() if "final_proj" not in p[0]) + assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}" + + +def param_count(model): + return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0]) + + +def _grab_best_device(use_gpu=True): + if torch.cuda.device_count() > 0 and use_gpu: + device = "cuda" + else: + device = "cpu" + return torch.device(device) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +vocoder_convert_list = [ + ("ups", "hifi_gan.upsampler"), + ("conv_pre", "hifi_gan.conv_pre"), + ("resblocks", "hifi_gan.resblocks"), + ("conv_post", "hifi_gan.conv_post"), + ("lang", "language_embedding"), + ("spkr", "speaker_embedding"), + ("dict.", "unit_embedding."), + ("dur_predictor.conv1.0", "dur_predictor.conv1"), + ("dur_predictor.conv2.0", "dur_predictor.conv2"), +] + +# order is important +wav2vec_convert_list = [ + ("speech_encoder_frontend.model_dim_proj", "feature_projection.projection"), + ("speech_encoder_frontend.post_extract_layer_norm", "feature_projection.layer_norm"), + ("speech_encoder_frontend.pos_encoder.conv", "encoder.pos_conv_embed.conv"), + ("speech_encoder.inner.layers", "encoder.layers"), + ("speech_encoder.inner_layer_norm", "encoder.layer_norm"), + ("speech_encoder.adaptor_layers", "adapter.layers"), + ("inner_proj", "intermediate_dense"), + ("self_attn.output_proj", "self_attn.linear_out"), + ("output_proj", "output_dense"), + ("self_attn.k_proj", "self_attn.linear_k"), + ("self_attn.v_proj", "self_attn.linear_v"), + ("self_attn.q_proj", "self_attn.linear_q"), + ("self_attn.sdpa.u_bias", "self_attn.pos_bias_u"), + ("self_attn.sdpa.v_bias", "self_attn.pos_bias_v"), + ("self_attn.sdpa.r_proj", "self_attn.linear_pos"), + ("conv.pointwise_conv1", "conv_module.pointwise_conv1"), + ("conv.pointwise_conv2", "conv_module.pointwise_conv2"), + ("conv.depthwise_conv", "conv_module.depthwise_conv"), + ("conv.batch_norm", "conv_module.batch_norm"), + ("conv_layer_norm", "conv_module.layer_norm"), + ("speech_encoder.proj1", "intermediate_ffn.intermediate_dense"), + ("speech_encoder.proj2", "intermediate_ffn.output_dense"), + ("speech_encoder.layer_norm", "inner_layer_norm"), +] + +t2u_convert_list = [ + ("t2u_model.final_proj", "lm_head"), + ("t2u_model.", "model."), + ("encoder_decoder_attn_layer_norm", "cross_attention_layer_norm"), + ("encoder_decoder_attn", "cross_attention"), + ("linear_k", "k_proj"), + ("linear_v", "v_proj"), + ("linear_q", "q_proj"), + ("ffn.inner_proj", "ffn.fc1"), + ("ffn.output_proj", "ffn.fc2"), + ("output_proj", "out_proj"), + ("decoder_frontend.embed", "decoder.embed_tokens"), +] + +text_convert_list = [ + ("text_encoder.", ""), + ("text_decoder.", ""), + ("text_encoder_frontend.embed", "embed_tokens"), + ("text_decoder_frontend.embed", "embed_tokens"), + ("encoder_decoder_attn_layer_norm", "cross_attention_layer_norm"), + ("encoder_decoder_attn", "cross_attention"), + ("linear_k", "k_proj"), + ("linear_v", "v_proj"), + ("linear_q", "q_proj"), + ("ffn.inner_proj", "ffn.fc1"), + ("ffn.output_proj", "ffn.fc2"), + ("output_proj", "out_proj"), + ("final_proj", "lm_head"), +] + +CUR_PATH = os.path.dirname(os.path.abspath(__file__)) +default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache") +CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "huggingface", "hub") + + +def _load_hf_config(model_type="medium"): + if model_type == "medium": + kwargs = { + "vocab_size": 256206, + "t2u_vocab_size": 10082, + "hidden_size": 1024, + "max_position_embeddings": 4096, + "encoder_layers": 12, + "decoder_layers": 12, + "encoder_ffn_dim": 4096, + "decoder_ffn_dim": 4096, + "t2u_encoder_layers": 4, + "t2u_decoder_layers": 4, + "speech_encoder_layers": 12, + } + return SeamlessM4TConfig(**kwargs) + else: + return SeamlessM4TConfig() + + +def _convert_model( + original_model, + hf_model, + convert_list, + device, + unwanted_prefix="model.", + filter_state_dict="speech", + exclude_state_dict=None, +): + state_dict = original_model.state_dict() + + # filter func + if isinstance(filter_state_dict, str): + + def filter_func(x): + return filter_state_dict in x[0] + + else: + + def filter_func(item): + if exclude_state_dict is not None and exclude_state_dict in item[0]: + return False + for filter_el in filter_state_dict: + if filter_el in item[0]: + return True + + return False + + state_dict = dict(filter(filter_func, state_dict.items())) + + for k, v in list(state_dict.items()): + new_k = k[len(unwanted_prefix) :] + for old_layer_name, new_layer_name in convert_list: + if old_layer_name in new_k: + new_k = new_k.replace(old_layer_name, new_layer_name) + + # must do it by hand + if ".layer_norm" in new_k and new_k.split(".layer_norm")[0][-1].isnumeric(): + new_k = new_k.replace("layer_norm", "final_layer_norm") + + state_dict[new_k] = state_dict.pop(k) + + extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys()) + extra_keys = set(extra_keys) + missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys()) + missing_keys = set({k for k in missing_keys if "final_logits_bias" not in k}) + if len(extra_keys) != 0: + raise ValueError(f"extra keys found: {extra_keys}") + if len(missing_keys) != 0: + raise ValueError(f"missing keys: {missing_keys}") + hf_model.load_state_dict(state_dict, strict=False) + n_params = param_count(hf_model) + + logger.info(f"model loaded: {round(n_params/1e6,1)}M params") + + hf_model.eval() + hf_model.to(device) + del state_dict + + return hf_model + + +def load_model(save_dir, model_type, repo_id): + """ + Meta SeamlessM4T is made of 8 main components: + - speech_encoder (#1) and speech_encoder_frontend (#2) + - t2u_model (#3) + - text_encoder (#4) and text_encoder_frontend (#5) + - text_decoder (#6) [and text_decoder_frontend (#5) = equals to text_encoder_frontend] + - final_proj (#7) + - vocoder (#8) + """ + device = _grab_best_device() + if model_type == "medium": + name = "seamlessM4T_medium" + else: + name = "seamlessM4T_large" + + original_model = Translator(name, "vocoder_36langs", device, torch.float32) + + ######### TOKENIZER + + langs = MEDIUM_SUPPORTED_LANGUAGES if model_type == "medium" else LARGE_SUPPORTED_LANGUAGES + langs = [f"__{lang}__" for lang in langs] + vocab_file = os.path.join(os.path.expanduser("~"), "tokenizer", model_type, "tokenizer.model") + + save_dir = os.path.join(save_dir, name) + Path(save_dir).mkdir(exist_ok=True) + + tokenizer = SeamlessM4TTokenizer(vocab_file, additional_special_tokens=langs) + + sanity_check_lang_id = tokenizer.convert_tokens_to_ids("__fra__") + + tokenizer.save_pretrained(save_dir) + tokenizer = SeamlessM4TTokenizer.from_pretrained(save_dir) + + if sanity_check_lang_id != tokenizer.convert_tokens_to_ids("__fra__"): + raise ValueError( + f"Error in tokenizer saving/loading - __fra__ lang id is not coherent: {sanity_check_lang_id} vs {tokenizer.convert_tokens_to_ids('__fra__')}" + ) + + ####### get language to ids dict + text_decoder_lang_code_to_id = {lang.replace("__", ""): tokenizer.convert_tokens_to_ids(lang) for lang in langs} + # offset: vocoder unit vocab size + 5 (for EOS/PAD/BOS/UNK/MSK) + len(supported_languages) + t2u_lang_code_to_id = { + code.replace("__", ""): i + 10005 + len(UNIT_SUPPORTED_LANGUAGES) + for i, code in enumerate(UNIT_SUPPORTED_LANGUAGES) + } + vocoder_lang_code_to_id = {code.replace("__", ""): i for i, code in enumerate(VOCODER_SUPPORTED_LANGUAGES)} + + ######### FE + + fe = SeamlessM4TFeatureExtractor(language_code=langs) + + fe.save_pretrained(save_dir) + fe = SeamlessM4TFeatureExtractor.from_pretrained(save_dir) + + processor = SeamlessM4TProcessor(feature_extractor=fe, tokenizer=tokenizer) + processor.save_pretrained(save_dir) + processor.push_to_hub(repo_id=repo_id, create_pr=True) + + processor = SeamlessM4TProcessor.from_pretrained(save_dir) + + ######## Model + + # init model + hf_config = _load_hf_config(model_type) + hf_model = SeamlessM4TModel(hf_config) + + hf_model.generation_config.__setattr__("text_decoder_lang_to_code_id", text_decoder_lang_code_to_id) + hf_model.generation_config.__setattr__("t2u_lang_code_to_id", t2u_lang_code_to_id) + hf_model.generation_config.__setattr__("vocoder_lang_code_to_id", vocoder_lang_code_to_id) + + # -1. take care of vocoder + # similarly to speech T5 must apply and remove weight norm + hf_model.vocoder.apply_weight_norm() + hf_model.vocoder = _convert_model( + original_model, + hf_model.vocoder, + vocoder_convert_list, + device, + unwanted_prefix="vocoder.code_generator.", + filter_state_dict="vocoder", + ) + hf_model.vocoder.remove_weight_norm() + + # 1. take care of speech encoder + wav2vec = hf_model.speech_encoder + hf_model.speech_encoder = _convert_model( + original_model, wav2vec, wav2vec_convert_list, device, unwanted_prefix="model.", filter_state_dict="speech" + ) + + # 2. take care of t2u + + hf_model.t2u_model = _convert_model( + original_model, + hf_model.t2u_model, + t2u_convert_list, + device, + unwanted_prefix="model.", + filter_state_dict="t2u_model", + ) + + # 3. take care of text encoder + hf_model.text_encoder = _convert_model( + original_model, + hf_model.text_encoder, + text_convert_list, + device, + unwanted_prefix="model.", + filter_state_dict=["model.text_encoder"], + exclude_state_dict="t2u_model", + ) + + # 4. take care of text decoder + hf_model.text_decoder = _convert_model( + original_model, + hf_model.text_decoder, + text_convert_list, + device, + unwanted_prefix="model.", + filter_state_dict=["model.text_decoder"], + exclude_state_dict="t2u_model", + ) + + # 5. take care of final proj + hf_model.lm_head = _convert_model( + original_model, + hf_model.lm_head, + [("final_proj.", "")], + device, + unwanted_prefix="model.", + filter_state_dict=["model.final_proj"], + exclude_state_dict="t2u_model", + ) + + # sanity check + print(find_tied_parameters(hf_model)) + + count_1 = param_count(hf_model) + count_2 = param_count(original_model) + + print(f"HF MODEL:{count_1}, ORIGINAL_MODEL: {count_2}, diff:{count_1 - count_2}") + print(f"HF MODEL excluding embeddings:{hf_model.num_parameters(exclude_embeddings=True)}") + + del original_model + + hf_model.generation_config._from_model_config = False + hf_model.save_pretrained(save_dir) + hf_model.push_to_hub(repo_id=repo_id, create_pr=True) + hf_model = SeamlessM4TModel.from_pretrained(save_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + + parser.add_argument( + "--model_type", + default="medium", + type=str, + help="Model type.", + ) + + parser.add_argument( + "--save_dir", + default="/home/ubuntu/weights", + type=str, + help="Path to the output PyTorch model.", + ) + + parser.add_argument( + "--repo_id", + default="facebook/hf-seamless-m4t-medium", + type=str, + help="Repo ID.", + ) + + args = parser.parse_args() + + load_model(args.save_dir, args.model_type, args.repo_id) diff --git a/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py new file mode 100644 index 00000000000000..852be5c73a1d4a --- /dev/null +++ b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py @@ -0,0 +1,305 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Feature extractor class for SeamlessM4T +""" + +import copy +from typing import List, Optional, Union + +import numpy as np + +from ...audio_utils import mel_filter_bank, spectrogram, window_function +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import PaddingStrategy, TensorType, logging + + +logger = logging.get_logger(__name__) + + +class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a SeamlessM4T feature extractor. + + This feature extractor inherits from [`SequenceFeatureExtractor`] which contains most of the main methods. Users + should refer to this superclass for more information regarding those methods. + + This class extracts mel-filter bank features from raw speech. + + Args: + feature_size (`int`, *optional*, defaults to 80): + The feature dimension of the extracted features. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). + num_mel_bins (`int`, *optional*, defaults to 80): + Number of Mel-frequency bins. + padding_value (`float`, *optional*, defaults to 0.0): + The value that is used to fill the padding vectors. + stride (`int`, *optional*, defaults to 2): + Stride used to reshape audios from shape (batch_size,num_frames,num_mel_bins) to + (batch_size,num_frames//stride,num_mel_bins*stride). + """ + + model_input_names = ["input_features", "attention_mask"] + + def __init__( + self, + feature_size=80, + sampling_rate=16000, + num_mel_bins=80, + padding_value=0.0, + stride=2, + **kwargs, + ): + self.num_mel_bins = num_mel_bins + self.return_attention_mask = True + self.stride = stride + + mel_filters = mel_filter_bank( + num_frequency_bins=256, + num_mel_filters=self.num_mel_bins, + min_frequency=20, + max_frequency=sampling_rate // 2, + sampling_rate=sampling_rate, + norm=None, + mel_scale="kaldi", + triangularize_in_mel_space=True, + ) + + self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0))) + self.window = window_function(400, "povey", periodic=False) + + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + + @staticmethod + # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm + def zero_mean_unit_var_norm( + input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0 + ) -> List[np.ndarray]: + """ + Every array in the list is normalized to have zero mean and unit variance + """ + if attention_mask is not None: + attention_mask = np.array(attention_mask, np.int32) + normed_input_values = [] + + for vector, length in zip(input_values, attention_mask.sum(-1)): + normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7) + if length < normed_slice.shape[0]: + normed_slice[length:] = padding_value + + normed_input_values.append(normed_slice) + else: + normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values] + + return normed_input_values + + def _extract_fbank_features( + self, + waveform: np.ndarray, + ) -> np.ndarray: + """ + Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs + and hence the waveform should not be normalized before feature extraction. + """ + # by default, it extracts the left channel if stereo + if len(waveform.shape) == 2: + waveform = waveform[0] + + waveform = np.squeeze(waveform) * (2**15) # Kaldi compliance: 16-bit signed integers + features = spectrogram( + waveform, + self.window, + frame_length=400, + hop_length=160, + fft_length=512, + power=2.0, + center=False, + preemphasis=0.97, + mel_filters=self.mel_filters, + log_mel="log", + mel_floor=1.192092955078125e-07, + remove_dc_offset=True, + ).T + return features + + def __call__( + self, + raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]], + padding: Union[bool, str, PaddingStrategy] = True, + pad_to_multiple_of: Optional[int] = 2, + max_length: Optional[int] = None, + truncation: bool = False, + return_tensors: Optional[Union[str, TensorType]] = None, + sampling_rate: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + do_normalize_per_mel_bins: Optional[bool] = True, + **kwargs, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). + + Args: + raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, `List[List[List[float]]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays, a list of list of float values or a list of a list of list of float + values. If `raw_speech` is a one-dimensional `np.ndarray` or a `List[float]`, `raw_speech` is + considered a single-channel, single-sample sound. In all other cases, the first dimension of + `raw_speech`, whether from an `np.ndarray` or a `List[...]`, corresponds to the number of samples in + the batch, and the number of channels (i.e. mono or stereo character) is derived from the other + dimensions (1D -> single-channel waveform batches; 2D-> stereo-channel waveform batches). + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + pad_to_multiple_of (`int`, *optional*, defaults to 2): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`): + Activates truncation to cut input sequences longer than *max_length* to *max_length*. + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific feature_extractor's default. + + [What are attention masks?](../glossary#attention-mask) + + + + For SeamlessM4T models, `attention_mask` should always be passed for batched inference, to avoid subtle + bugs. + + + + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + sampling_rate (`int`, *optional*): + The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors. + do_normalize_per_mel_bins (`bool`, *optional*, defaults to `True`): + Whether or not to zero-mean unit-variance normalize the input per mel-channel. + kwargs (*optional*): + Remaining dictionary of keyword arguments that will be passed to the tokenizer or the feature + extractor. + """ + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" + f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with" + f" {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + "It is strongly recommended to pass the `sampling_rate` argument to this function. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 3: + raise ValueError(f"Only mono-channel or stereo-channel audio is supported for input to {self}") + + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float32) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + raw_speech = raw_speech.astype(np.float32) + + # always return batch + if not is_batched: + raw_speech = [raw_speech] + + # extract fbank features + features = [self._extract_fbank_features(waveform) for waveform in raw_speech] + + if do_normalize_per_mel_bins: + # torch defaults to ddof=1, and numpy defaults to ddof=0 + features = [ + (x - np.expand_dims(x.mean(0), 0)) / np.sqrt(np.expand_dims(x.var(0, ddof=1), 0) + 1e-7) + for x in features + ] + + # convert into correct format for padding + encoded_inputs = BatchFeature({"input_features": features}) + + padded_inputs = self.pad( + encoded_inputs, + padding=padding, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + return_tensors="np", + ) + + # SeamlessM4T needs to process extracted features + input_features = padded_inputs.get("input_features") + attention_mask = padded_inputs.get("attention_mask") + + batch_size, num_frames, num_channels = input_features.shape + + remainder = num_frames % self.stride + if remainder != 0: + input_features = input_features[:, :num_frames, :] + attention_mask = attention_mask[:, :num_frames] + + input_features = np.reshape( + input_features, (batch_size, num_frames // self.stride, num_channels * self.stride) + ) + + indices = np.arange(0, num_frames) + attention_mask = attention_mask[:, indices % self.stride == 1] + + padded_inputs["input_features"] = input_features + padded_inputs["attention_mask"] = attention_mask + + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + output = copy.deepcopy(self.__dict__) + output["feature_extractor_type"] = self.__class__.__name__ + if "mel_filters" in output: + del output["mel_filters"] + if "window" in output: + del output["window"] + return output diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py new file mode 100755 index 00000000000000..7df6fcd98907ca --- /dev/null +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -0,0 +1,4607 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch SeamlessM4T model.""" + + +import copy +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import Tensor, nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...deepspeed import is_deepspeed_zero3_enabled +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Wav2Vec2BaseModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_seamless_m4t import SeamlessM4TConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/hf-seamless-m4t-medium" +_CONFIG_FOR_DOC = "SeamlessM4TConfig" + +SEAMLESS_M4T_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/hf-seamless-m4t-medium", + # See all SeamlessM4T models at https://huggingface.co/models?filter=seamless_m4t +] + +SPEECHT5_PRETRAINED_HIFIGAN_CONFIG_ARCHIVE_MAP = { + "microsoft/speecht5_hifigan": "https://huggingface.co/microsoft/speecht5_hifigan/resolve/main/config.json", +} + + +@dataclass +class SeamlessM4TGenerationOutput(ModelOutput): + """ + Class defining the generated outputs from [`SeamlessM4TModel`], [`SeamlessM4TForTextToText`], + [`SeamlessM4TForTextToSpeech`], [`SeamlessM4TForSpeechToSpeech`] and [`SeamlessM4TForTextToSpeech`]. + + Args: + waveform (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + The final audio waveform predicted by the model. + waveform_lengths (`torch.IntTensor` of shape `(batch_size,)`, *optional*): + The length in samples of each element in the `waveform` batch. + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + The generated translated sequences. This is the output of the text-to-text or the speech-to-text models. + The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished + early due to the `eos_token_id`. + unit_sequences (`torch.LongTensor` of shape `(batch_size, unit_sequence_length)`, *optional*): + The generated translated unit sequences. This is the output of the text-to-units model. The second + dimension (unit_sequence_length) is either equal to `t2u_max_length` or shorter if all batches finished + early due to the `t2u_eos_token_id`. + """ + + waveform: Optional[torch.FloatTensor] = None + waveform_lengths: Optional[torch.IntTensor] = None + sequences: Optional[Tuple[torch.FloatTensor]] = None + unit_sequences: Optional[Tuple[torch.FloatTensor]] = None + + +SEAMLESS_M4T_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`~SeamlessM4TConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SEAMLESS_M4T_INPUTS_DOCSTRING_FIRST_PART = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`): + Input audio features. This should be returnes by the [`SeamlessM4TFeatureExtractor`] class or the + [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details. + """ + +SEAMLESS_M4T_INPUTS_DOCSTRING_TEXT_PART = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + """ + +SEAMLESS_M4T_INPUTS_DOCSTRING_SPEECH_PART = r""" + Args: + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`): + Input audio features. This should be returnes by the [`SeamlessM4TFeatureExtractor`] class or the + [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details. + """ + +SEAMLESS_M4T_INPUTS_DOCSTRING_LAST_PART = r""" + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape`(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +M4T_MODEL_INPUTS_DOCSTRING = SEAMLESS_M4T_INPUTS_DOCSTRING_FIRST_PART + SEAMLESS_M4T_INPUTS_DOCSTRING_LAST_PART + +M4T_TEXT_INPUTS_DOCSTRING = SEAMLESS_M4T_INPUTS_DOCSTRING_TEXT_PART + SEAMLESS_M4T_INPUTS_DOCSTRING_LAST_PART + +M4T_SPEECH_INPUTS_DOCSTRING = SEAMLESS_M4T_INPUTS_DOCSTRING_SPEECH_PART + SEAMLESS_M4T_INPUTS_DOCSTRING_LAST_PART + + +############ UTILS ################ + + +# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor): + """ + Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that + stops at the corresponding element in `seq_lens`. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, *)`): + The sequences to mask, where `*` is any number of sequence-specific dimensions including none. + seq_lens (`torch.Tensor` of shape `(batch)`: + Each element represents the length of the sequence at the same index in `hidden_states` + + Returns: + `torch.FloatTensor`: The float attention mask of shape `(batch, seq_len)` + """ + batch_size, mask_seq_len = hidden_states.shape[:2] + + indices = torch.arange(mask_seq_len, device=seq_lens.device).expand(batch_size, -1) + + bool_mask = indices >= seq_lens.unsqueeze(1).expand(-1, mask_seq_len) + + mask = hidden_states.new_ones((batch_size, mask_seq_len)) + + mask = mask.masked_fill(bool_mask, 0) + + return mask + + +def format_speech_generation_kwargs(kwargs): + """ + Format kwargs for SeamlessM4T models that generate speech, attribute kwargs to either the text generation or the + speech generation models. + + Args: + kwargs (`dict`)`: + Keyword arguments are of two types: + + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model, + except for `decoder_input_ids` which will only be passed through the text components. + - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the + text model and speech model respectively. It has the priority over the keywords without a prefix. + + This means you can, for example, specify a generation strategy for one generation but not for the + other. + """ + # attribute kwargs to models + kwargs_text = {} + kwargs_speech = {} + for key, value in kwargs.items(): + if key.startswith("text_"): + key = key[len("text_") :] + kwargs_text[key] = value + elif key.startswith("speech_"): + key = key[len("speech_") :] + kwargs_speech[key] = value + else: + # If the key is already in a specific config, then it's been set with a + # submodules specific value and we don't override + if key not in kwargs_text: + kwargs_text[key] = value + if key not in kwargs_speech: + kwargs_speech[key] = value + return kwargs_text, kwargs_speech + + +############ SPEECH ENCODER related code ################ + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->SeamlessM4TConformer, feat_extract_activation->speech_encoder_hidden_act +class SeamlessM4TConformerPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + deepspeed.zero.register_external_parameter(self, self.conv.weight_v) + deepspeed.zero.register_external_parameter(self, self.conv.weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.padding = SeamlessM4TConformerSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.speech_encoder_hidden_act] + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRotaryPositionalEmbedding with Wav2Vec2->SeamlessM4T, num_attention_heads->speech_encoder_attention_heads +class SeamlessM4TConformerRotaryPositionalEmbedding(nn.Module): + """Rotary positional embedding + Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf + """ + + def __init__(self, config): + super().__init__() + dim = config.hidden_size // config.speech_encoder_attention_heads + base = config.rotary_embedding_base + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.cached_sequence_length = None + self.cached_rotary_positional_embedding = None + + def forward(self, hidden_states): + sequence_length = hidden_states.shape[1] + + if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None: + return self.cached_rotary_positional_embedding + + self.cached_sequence_length = sequence_length + # Embeddings are computed in the dtype of the inv_freq constant + time_stamps = torch.arange(sequence_length).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq) + embeddings = torch.cat((freqs, freqs), dim=-1) + + cos_embeddings = embeddings.cos()[:, None, None, :] + sin_embeddings = embeddings.sin()[:, None, None, :] + # Computed embeddings are cast to the dtype of the hidden state inputs + self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states) + return self.cached_rotary_positional_embedding + + +# Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerRelPositionalEmbedding with Wav2Vec2->SeamlessM4T +class SeamlessM4TConformerRelPositionalEmbedding(nn.Module): + """Relative positional encoding module.""" + + def __init__(self, config): + super().__init__() + self.max_len = config.max_source_positions + self.d_model = config.hidden_size + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)) + + def extend_pe(self, x): + # Reset the positional encodings + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` is the position of query vector and `j` is the + # position of key vector. We use positive relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (iSeamlessM4T +class SeamlessM4TConformerSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +class SeamlessM4TConformerFeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = nn.LayerNorm(config.feature_projection_input_dim, eps=config.layer_norm_eps) + self.projection = nn.Linear(config.feature_projection_input_dim, config.hidden_size) + self.dropout = nn.Dropout(config.speech_encoder_dropout) + + def forward(self, hidden_states): + # non-projected hidden states are needed for quantization + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class SeamlessM4TConformerFeedForward(nn.Module): + def __init__(self, config, act_fn=None, dropout=None): + super().__init__() + dropout = dropout if dropout is not None else config.speech_encoder_dropout + act_fn = act_fn if act_fn is not None else config.speech_encoder_hidden_act + + self.intermediate_dropout = nn.Dropout(dropout) + self.intermediate_dense = nn.Linear(config.hidden_size, config.speech_encoder_intermediate_size) + self.intermediate_act_fn = ACT2FN[act_fn] if isinstance(act_fn, str) else act_fn + + self.output_dense = nn.Linear(config.speech_encoder_intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class SeamlessM4TConformerConvolutionModule(nn.Module): + """Convolution block used in the conformer block""" + + def __init__(self, config): + super().__init__() + if (config.conv_depthwise_kernel_size - 1) % 2 == 1: + raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding") + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.pointwise_conv1 = nn.Conv1d( + config.hidden_size, + 2 * config.hidden_size, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.glu = nn.GLU(dim=1) + self.depthwise_conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + config.conv_depthwise_kernel_size, + stride=1, + padding="same", + groups=config.hidden_size, + bias=False, + ) + self.batch_norm = nn.BatchNorm1d(config.hidden_size) + self.activation = ACT2FN[config.speech_encoder_hidden_act] + self.pointwise_conv2 = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.dropout = nn.Dropout(config.speech_encoder_dropout) + + def forward(self, hidden_states, attention_mask=None): + hidden_states = self.layer_norm(hidden_states) + + # Ensure that we do not leak padded positions in depthwise convolution. + # Put 0 where necessary + if attention_mask is not None: + hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0) + + # exchange the temporal dimension and the feature dimension + hidden_states = hidden_states.transpose(1, 2) + + # GLU mechanism + # => (batch, 2*channel, dim) + hidden_states = self.pointwise_conv1(hidden_states) + # => (batch, channel, dim) + hidden_states = self.glu(hidden_states) + + # 1D Depthwise Conv + hidden_states = self.depthwise_conv(hidden_states) + hidden_states = self.batch_norm(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.pointwise_conv2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class SeamlessM4TConformerSelfAttention(nn.Module): + """Construct a SeamlessM4TConformerSelfAttention object. + Can be enhanced with rotary or relative position embeddings. + """ + + def __init__(self, config, use_position_embeddings=True): + super().__init__() + + self.head_size = config.hidden_size // config.speech_encoder_attention_heads + self.num_heads = config.speech_encoder_attention_heads + self.position_embeddings_type = config.position_embeddings_type if use_position_embeddings else None + + self.linear_q = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_k = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_v = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_out = nn.Linear(config.hidden_size, config.hidden_size) + + self.dropout = nn.Dropout(p=config.speech_encoder_dropout) + + if self.position_embeddings_type == "relative": + # linear transformation for positional encoding + self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # self-attention mechanism + batch_size, sequence_length, hidden_size = hidden_states.size() + + # make sure query/key states can be != value states + query_key_states = hidden_states + value_states = hidden_states + + if self.position_embeddings_type == "rotary": + if relative_position_embeddings is None: + raise ValueError( + "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'" + ) + query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings) + + # project query_key_states and value_states + query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size) + + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + if self.position_embeddings_type == "relative": + if relative_position_embeddings is None: + raise ValueError( + "`relative_position_embeddings` has to be defined when `self.position_embeddings_type ==" + " 'relative'" + ) + # apply relative_position_embeddings to qk scores + # as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860 + scores = self._apply_relative_embeddings( + query=query, key=key, relative_position_embeddings=relative_position_embeddings + ) + else: + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size) + + # apply attention_mask if necessary + if attention_mask is not None: + scores = scores + attention_mask + + # => (batch, head, time1, time2) + probs = torch.softmax(scores, dim=-1) + probs = self.dropout(probs) + + # => (batch, head, time1, d_k) + hidden_states = torch.matmul(probs, value) + + # => (batch, time1, hidden_size) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size) + hidden_states = self.linear_out(hidden_states) + + return hidden_states, probs + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_rotary_embedding + def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings): + batch_size, sequence_length, hidden_size = hidden_states.size() + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size) + + cos = relative_position_embeddings[0, :sequence_length, ...] + sin = relative_position_embeddings[1, :sequence_length, ...] + + # rotate hidden_states with rotary embeddings + hidden_states = hidden_states.transpose(0, 1) + rotated_states_begin = hidden_states[..., : self.head_size // 2] + rotated_states_end = hidden_states[..., self.head_size // 2 :] + rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1) + hidden_states = (hidden_states * cos) + (rotated_states * sin) + hidden_states = hidden_states.transpose(0, 1) + + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size) + + return hidden_states + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerSelfAttention._apply_relative_embeddings + def _apply_relative_embeddings(self, query, key, relative_position_embeddings): + # 1. project positional embeddings + # => (batch, head, 2*time1-1, d_k) + proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings) + proj_relative_position_embeddings = proj_relative_position_embeddings.view( + relative_position_embeddings.size(0), -1, self.num_heads, self.head_size + ) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3) + + # 2. Add bias to query + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) + + # 3. attention score: first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # => (batch, head, time1, time2) + scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) + + # 4. then compute matrix b and matrix d + # => (batch, head, time1, 2*time1-1) + scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings) + + # 5. shift matrix b and matrix d + zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype) + scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1) + scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2]) + scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape) + scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd) + scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1] + + # 6. sum matrices + # => (batch, head, time1, time2) + scores = (scores_ac + scores_bd) / math.sqrt(self.head_size) + + return scores + + +class SeamlessM4TConformerEncoderLayer(nn.Module): + """Conformer block based on https://arxiv.org/abs/2005.08100.""" + + # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerEncoderLayer.__init__ with Wav2Vec2->SeamlessM4T, attention_dropout->speech_encoder_dropout, torch.nn->nn + def __init__(self, config): + super().__init__() + embed_dim = config.hidden_size + dropout = config.speech_encoder_dropout + + # Feed-forward 1 + self.ffn1_layer_norm = nn.LayerNorm(embed_dim) + self.ffn1 = SeamlessM4TConformerFeedForward(config) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim) + self.self_attn_dropout = nn.Dropout(dropout) + self.self_attn = SeamlessM4TConformerSelfAttention(config) + + # Conformer Convolution + self.conv_module = SeamlessM4TConformerConvolutionModule(config) + + # Feed-forward 2 + self.ffn2_layer_norm = nn.LayerNorm(embed_dim) + self.ffn2 = SeamlessM4TConformerFeedForward(config) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + conv_attention_mask: Optional[torch.Tensor] = None, + ): + hidden_states = hidden_states + + # 1. Feed-Forward 1 layer + residual = hidden_states + hidden_states = self.ffn1_layer_norm(hidden_states) + hidden_states = self.ffn1(hidden_states) + hidden_states = hidden_states * 0.5 + residual + residual = hidden_states + + # 2. Self-Attention layer + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weigts = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_dropout(hidden_states) + hidden_states = hidden_states + residual + + # 3. Convolutional Layer + residual = hidden_states + hidden_states = self.conv_module(hidden_states, attention_mask=conv_attention_mask) + hidden_states = residual + hidden_states + + # 4. Feed-Forward 2 Layer + residual = hidden_states + hidden_states = self.ffn2_layer_norm(hidden_states) + hidden_states = self.ffn2(hidden_states) + hidden_states = hidden_states * 0.5 + residual + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states, attn_weigts + + +class SeamlessM4TConformerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + if config.position_embeddings_type == "relative": + self.embed_positions = SeamlessM4TConformerRelPositionalEmbedding(config) + elif config.position_embeddings_type == "rotary": + self.embed_positions = SeamlessM4TConformerRotaryPositionalEmbedding(config) + else: + self.embed_positions = None + + self.dropout = nn.Dropout(config.speech_encoder_dropout) + self.layers = nn.ModuleList( + [SeamlessM4TConformerEncoderLayer(config) for _ in range(config.speech_encoder_layers)] + ) + + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + conv_attention_mask = attention_mask + if attention_mask is not None: + # make sure padded tokens output 0 + hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0) + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + hidden_states = self.dropout(hidden_states) + + if self.embed_positions is not None: + relative_position_embeddings = self.embed_positions(hidden_states) + else: + relative_position_embeddings = None + + deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = ( + True if self.training and (dropout_probability < self.config.speech_encoder_layerdrop) else False + ) + if not skip_the_layer or deepspeed_zero3_is_enabled: + # under deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + # create gradient checkpointing function + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + hidden_states, + attention_mask, + relative_position_embeddings, + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class SeamlessM4TConformerAdapterLayer(nn.Module): + def __init__(self, config): + super().__init__() + embed_dim = config.hidden_size + dropout = config.adaptor_dropout + + self.kernel_size = config.adaptor_kernel_size + self.stride = config.adaptor_stride + + # 1. residual convolution + self.residual_layer_norm = nn.LayerNorm(embed_dim) + self.residual_conv = nn.Conv1d( + embed_dim, + 2 * embed_dim, + self.kernel_size, + stride=self.stride, + padding=self.stride // 2, + ) + self.activation = nn.GLU(dim=1) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim) + self.self_attn_conv = nn.Conv1d( + embed_dim, + 2 * embed_dim, + self.kernel_size, + stride=self.stride, + padding=self.stride // 2, + ) + self.self_attn = SeamlessM4TConformerSelfAttention(config, use_position_embeddings=False) + self.self_attn_dropout = nn.Dropout(dropout) + + # Feed-forward + self.ffn_layer_norm = nn.LayerNorm(embed_dim) + self.ffn = SeamlessM4TConformerFeedForward(config, act_fn="relu", dropout=dropout) + + def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): + pad = self.kernel_size // 2 + seq_lens = attention_mask.size(1) - (1 - attention_mask.int()).sum(1) + + seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1 + + return seq_lens.floor() + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + residual = self.residual_layer_norm(hidden_states) + + # Apply pooling to the residual to match the sequence length of the + # multi-head attention output. + # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len) + residual = residual.transpose(1, 2) + residual = self.residual_conv(residual) + residual = self.activation(residual) + # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim) + residual = residual.transpose(1, 2) + + hidden_states = self.self_attn_layer_norm(hidden_states) + # Apply pooling before feeding to the multihead-attention layer. + # (batch, seq_len, feature_dim) -> (batch, feature_dim, seq_len) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.self_attn_conv(hidden_states) + hidden_states = self.activation(hidden_states) + # (batch, feature_dim, seq_len) -> (batch, seq_len, feature_dim) + hidden_states = hidden_states.transpose(1, 2) + + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + hidden_states.device + ) + attention_mask = _compute_new_attention_mask(hidden_states=hidden_states, seq_lens=sub_sampled_lengths) + attention_mask = _expand_mask( + attention_mask, + hidden_states.dtype, + ) + + # The rest of the computation is identical to a vanilla Transformer + # encoder layer. + hidden_states, attn_weigths = self.self_attn( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_dropout(hidden_states) + hidden_states = hidden_states + residual + + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + hidden_states = self.ffn(hidden_states) + residual + + return hidden_states + + +class SeamlessM4TConformerAdapter(nn.Module): + def __init__(self, config): + super().__init__() + + self.layers = nn.ModuleList(SeamlessM4TConformerAdapterLayer(config) for _ in range(config.num_adapter_layers)) + + def forward(self, hidden_states, attention_mask): + # down project hidden_states if necessary + + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask) + + return hidden_states + + +############ TEXT / UNITS related code ################ + + +# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding +class SeamlessM4TSinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__() + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) + + def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) + if hasattr(self, "weights"): + # in forward put the weights on the correct dtype and device of the param + emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) + + self.register_buffer("weights", emb_weights, persistent=False) + + @staticmethod + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + """ + Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of + "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + + return emb.to(torch.get_default_dtype()) + + @torch.no_grad() + def forward( + self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0 + ): + if input_ids is not None: + bsz, seq_len = input_ids.size() + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( + input_ids.device + ) + else: + bsz, seq_len = inputs_embeds.size()[:-1] + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + past_key_values_length + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach() + + def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->SeamlessM4T,key_value_states->encoder_hidden_states +class SeamlessM4TAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if encoder_hidden_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = encoder_hidden_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == encoder_hidden_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `encoder_hidden_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == encoder_hidden_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +# Copied from transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeDenseActDense with NllbMoe->SeamlessM4T,DenseActDense->FeedForwardNetwork, d_model->hidden_size +class SeamlessM4TFeedForwardNetwork(nn.Module): + def __init__(self, config: SeamlessM4TConfig, ffn_dim: int): + super().__init__() + self.fc1 = nn.Linear(config.hidden_size, ffn_dim) + self.fc2 = nn.Linear(ffn_dim, config.hidden_size) + self.dropout = nn.Dropout(config.activation_dropout) + self.act = ACT2FN[config.activation_function] + + def forward(self, hidden_states): + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.fc2.weight, torch.Tensor) + and hidden_states.dtype != self.fc2.weight.dtype + and self.fc2.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.fc2.weight.dtype) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SeamlessM4TEncoderLayer(nn.Module): + def __init__(self, config: SeamlessM4TConfig, encoder_ffn_dim=None, encoder_attention_heads=None): + super().__init__() + encoder_ffn_dim = config.encoder_ffn_dim if encoder_ffn_dim is None else encoder_ffn_dim + encoder_attention_heads = ( + config.encoder_attention_heads if encoder_attention_heads is None else encoder_attention_heads + ) + + self.embed_dim = config.hidden_size + self.self_attn = SeamlessM4TAttention( + embed_dim=self.embed_dim, + num_heads=encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.attn_dropout = nn.Dropout(config.dropout) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + self.ffn = SeamlessM4TFeedForwardNetwork(config, ffn_dim=encoder_ffn_dim) + + self.ffn_layer_norm = nn.LayerNorm(config.hidden_size) + self.ffn_dropout = nn.Dropout(config.activation_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + + hidden_states = self.ffn(hidden_states) + hidden_states = self.ffn_dropout(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SeamlessM4TDecoderLayer(nn.Module): + def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_attention_heads=None): + super().__init__() + decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim + decoder_attention_heads = ( + config.decoder_attention_heads if decoder_attention_heads is None else decoder_attention_heads + ) + + self.embed_dim = config.hidden_size + self.self_attn = SeamlessM4TAttention( + embed_dim=self.embed_dim, + num_heads=decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.attn_dropout = nn.Dropout(config.dropout) + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.cross_attention = SeamlessM4TAttention( + self.embed_dim, decoder_attention_heads, config.attention_dropout, is_decoder=True + ) + self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim) + + self.ffn = SeamlessM4TFeedForwardNetwork(config, ffn_dim=decoder_ffn_dim) + + self.ffn_layer_norm = nn.LayerNorm(config.hidden_size) + self.ffn_dropout = nn.Dropout(config.activation_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): + encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by + very large negative values. + layer_head_mask (`torch.FloatTensor`): + mask for attention heads in a given layer of size `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): + mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): + cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.cross_attention_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + past_key_value=cross_attn_past_key_value, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = self.attn_dropout(hidden_states) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value += cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + + hidden_states = self.ffn_layer_norm(hidden_states) + + hidden_states = self.ffn(hidden_states) + hidden_states = self.ffn_dropout(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +############ SUB-MODELS related code ################ + + +class SeamlessM4TPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SeamlessM4TConfig + base_model_prefix = "seamless_m4t" + supports_gradient_checkpointing = True + _no_split_modules = ["SeamlessM4TEncoderLayer", "SeamlessM4TDecoderLayer", "SeamlessM4TConformerEncoderLayer"] + + def _init_weights(self, module): + """Initialize the weights""" + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, SeamlessM4TConformerSelfAttention): + if hasattr(module, "pos_bias_u"): + nn.init.xavier_uniform_(module.pos_bias_u) + if hasattr(module, "pos_bias_v"): + nn.init.xavier_uniform_(module.pos_bias_v) + elif isinstance(module, SeamlessM4TConformerPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, SeamlessM4TConformerFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (SeamlessM4TDecoder, SeamlessM4TEncoder, SeamlessM4TSpeechEncoder)): + module.gradient_checkpointing = value + + def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): + kernel_size, stride = self.config.adaptor_kernel_size, self.config.adaptor_stride + pad = kernel_size // 2 + seq_lens = attention_mask.size(1) - (1 - attention_mask.int()).sum(1) + + seq_lens = ((seq_lens + 2 * pad - kernel_size) / stride) + 1 + + return seq_lens.floor() + + def compute_last_hidden_states_per_sample( + self, + hidden_states: Tuple[Tuple[torch.Tensor]], + beam_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Computes the last hidden states. + + Parameters: + hidden_states (`Tuple[Tuple[torch.Tensor]]`): + The generated hidden states. Tuple (one element for each generated token) of tuples (one element for + each layer of the decoder) of torch.FloatTensor of shape (batch_size*num_beams*num_return_sequences, + generated_length, hidden_size). + beam_indices (`torch.LongTensor`, *optional*): + Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, sequence_length)`. Only required if a `num_beams>1` at + generate-time. + + Return: + `torch.Tensor`: A `torch.Tensor` of shape `(batch_size*num_return_sequences, sequence_length, hidden_size)` + containing + the last hidden states. + ```""" + # 1. First, let's compute last_hidden_states from hidden_states. + # For each generation step, takes the hidden state from the last layer. + # shape: (batch_size*vocab_size*num_return_sequences, # generation_steps, hidden_dim) + last_hidden_states = torch.concat([hidden_states[-1] for hidden_states in hidden_states], dim=1) + + # 2. In absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent + # to a beam search approach were the first (and only) beam is always selected + # in that case, return directly last_hidden_states + if beam_indices is None: + return last_hidden_states + + # 3. cut beam_indices to longest beam length + beam_indices_mask = beam_indices < 0 + max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max() + beam_indices = beam_indices.clone()[:, :max_beam_length] + beam_indices_mask = beam_indices_mask[:, :max_beam_length] + + # 4. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards anyways + beam_indices[beam_indices_mask] = 0 + + # 5. expand beam_indices to last_hidden_states dim + beam_indices = beam_indices.unsqueeze(-1) + beam_indices = beam_indices.expand(-1, -1, last_hidden_states.shape[-1]) + + # 6. select the right candidate for each beam + # in other words, new_last_hidden_states[i,j,k] = last_hidden_states[beam_indices[i,j,k], j, k] for all i, j, k + last_hidden_states = torch.gather(last_hidden_states, 0, beam_indices) + + return last_hidden_states + + +@add_start_docstrings( + """Transformer speech encoder consisting of *config.speech_encoder_layers* conformer self attention layers. + Each layer is a [`SeamlessM4TConformerEncoderLayer`].""", + SEAMLESS_M4T_START_DOCSTRING, +) +class SeamlessM4TSpeechEncoder(SeamlessM4TPreTrainedModel): + main_input_name = "input_features" + + def __init__(self, config: SeamlessM4TConfig): + super().__init__(config) + + self.feature_projection = SeamlessM4TConformerFeatureProjection(config) + self.encoder = SeamlessM4TConformerEncoder(config) + self.intermediate_ffn = SeamlessM4TConformerFeedForward(config, act_fn="relu", dropout=0.0) + self.adapter = SeamlessM4TConformerAdapter(config) if config.add_adapter else None + self.inner_layer_norm = nn.LayerNorm(config.hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_features: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_features is None: + raise ValueError( + """Both `input_features` and `inputs_embeds` are `None` in `SeamlessM4TSpeechEncoder.forward`. + Make sure one of them is not `None`.""" + ) + + hidden_states = self.feature_projection(input_features) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + expanded_hidden_states = self.intermediate_ffn(hidden_states) + hidden_states = hidden_states + 0.5 * expanded_hidden_states + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states, attention_mask=attention_mask) + + hidden_states = self.inner_layer_norm(hidden_states) + + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return Wav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# inspired from MBart and NllbMoe +@add_start_docstrings( + "Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a [`SeamlessM4TEncoderLayer`].", + SEAMLESS_M4T_START_DOCSTRING, + """ + embed_tokens (`nn.Embedding`, *optional*): + Input embedding + is_t2u_encoder (`bool`, *optional*, defaults to `False`): + indicates if it belongs to the text-to-units model, in which case it won't have input embeddings + """, +) +class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel): + def __init__( + self, + config: SeamlessM4TConfig, + embed_tokens: Optional[nn.Embedding] = None, + is_t2u_encoder: bool = False, + ): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + self.padding_idx = config.pad_token_id + embed_dim = config.hidden_size + + self.is_t2u_encoder = is_t2u_encoder + self.max_source_positions = config.max_position_embeddings + + if not self.is_t2u_encoder: + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = SeamlessM4TSinusoidalPositionalEmbedding( + self.max_source_positions, + embed_dim, + self.padding_idx, + ) + + layers = [] + for _ in range(config.encoder_layers): + layers.append( + SeamlessM4TEncoderLayer( + config, + encoder_attention_heads=config.encoder_attention_heads, + encoder_ffn_dim=config.encoder_ffn_dim, + ) + ) + + self.layers = nn.ModuleList(layers) + + self.layer_norm = nn.LayerNorm(config.hidden_size) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and self.is_t2u_encoder: + raise ValueError( + "You cannot pass input_ids to the encoder of the text_to_units model. Pass inputs_embeds instead." + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + if not self.is_t2u_encoder: + embed_pos = self.embed_positions(input) + + hidden_states = inputs_embeds + embed_pos.to(inputs_embeds.device) + else: + hidden_states = inputs_embeds + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != len(self.layers): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +@add_start_docstrings( + "Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`SeamlessM4TDecoderLayer`].", + SEAMLESS_M4T_START_DOCSTRING, + """ + embed_tokens (`nn.Embedding`, *optional*): + Input embedding + """, +) +class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel): + def __init__( + self, + config: SeamlessM4TConfig, + embed_tokens: Optional[nn.Embedding] = None, + ): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + # if embed_tokens defined, use its shape instead + self.embed_tokens = nn.Embedding(embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx) + self.embed_tokens.weight = embed_tokens.weight + else: + self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx) + + self.embed_positions = SeamlessM4TSinusoidalPositionalEmbedding( + self.max_target_positions, + config.hidden_size, + padding_idx=self.padding_idx, + ) + + layers = [] + for _ in range(config.decoder_layers): + layers.append( + SeamlessM4TDecoderLayer( + config, + decoder_attention_heads=config.decoder_attention_heads, + decoder_ffn_dim=config.decoder_ffn_dim, + ) + ) + self.layers = nn.ModuleList(layers) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of + shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more + control over how to convert `input_ids` indices into associated vectors than the model's internal + embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + + # embed positions + positions = self.embed_positions(input, past_key_values_length=past_key_values_length) + + hidden_states = inputs_embeds + positions.to(inputs_embeds.device) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != len(self.layers): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {attn_mask.size()[0]}." + ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[1],) + + if output_attentions: + all_self_attns += (layer_outputs[2],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[3],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "Transformer bare text-to-unit encoder-decoder. The encoder is a [`SeamlessM4TEncoder`] without embeddings and the decoder is a [`SeamlessM4TDecoder`].", + SEAMLESS_M4T_START_DOCSTRING, + """ + embed_tokens_decoder (`nn.Embedding`, *optional*): input embedding of the decoder. + """, +) +class SeamlessM4TTextToUnitModel(SeamlessM4TPreTrainedModel): + def __init__( + self, + config: SeamlessM4TConfig, + embed_tokens_decoder: Optional[nn.Embedding] = None, + ): + super().__init__(config) + + self.encoder = SeamlessM4TEncoder(config, is_t2u_encoder=True) + self.decoder = SeamlessM4TDecoder(config, embed_tokens_decoder) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model.forward + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "Transformer text-to-unit encoder-decoder with a language model head. The base encoder-decoder model is a [`SeamlessM4TTextToUnit`].", + SEAMLESS_M4T_START_DOCSTRING, + """ + embed_tokens_decoder (`nn.Embedding`, *optional*): input embedding of the decoder. + """, +) +class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel): + _keys_to_ignore_on_load_missing = [ + "vocoder", + "speech_encoder", + "text_encoder", + "text_decoder", + ] + _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__( + self, + config: SeamlessM4TConfig, + embed_tokens_decoder: Optional[nn.Embedding] = None, + ): + # update config - used principaly for bos_token_id etc. + config = copy.deepcopy(config) + for param, val in config.to_dict().items(): + if param.startswith("t2u_"): + config.__setattr__(param[4:], val) + super().__init__(config) + + self.model = SeamlessM4TTextToUnitModel(config, embed_tokens_decoder) + + self.lm_head = nn.Linear(config.hidden_size, config.t2u_vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.encoder + + def get_decoder(self): + return self.model.decoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + @add_start_docstrings_to_model_forward(M4T_TEXT_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.t2u_pad_token_id, self.config.t2u_decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.t2u_pad_token_id, self.config.t2u_decoder_start_token_id) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + def _tie_weights(self) -> None: + if getattr(self.config, "tie_word_embeddings", True): + output_embeddings = self.get_output_embeddings() + if output_embeddings is not None: + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) + + +############ VOCODER related code ################ + + +HIFIGAN_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`SeamlessM4TConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +# Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock +class HifiGanResidualBlock(nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1): + super().__init__() + self.leaky_relu_slope = leaky_relu_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=dilation[i], + padding=self.get_padding(kernel_size, dilation[i]), + ) + for i in range(len(dilation)) + ] + ) + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + for _ in range(len(dilation)) + ] + ) + + def get_padding(self, kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + def apply_weight_norm(self): + for layer in self.convs1: + nn.utils.weight_norm(layer) + for layer in self.convs2: + nn.utils.weight_norm(layer) + + def remove_weight_norm(self): + for layer in self.convs1: + nn.utils.remove_weight_norm(layer) + for layer in self.convs2: + nn.utils.remove_weight_norm(layer) + + def forward(self, hidden_states): + for conv1, conv2 in zip(self.convs1, self.convs2): + residual = hidden_states + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv1(hidden_states) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv2(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +class SeamlessM4TVariancePredictor(nn.Module): + def __init__(self, config): + super().__init__() + + embed_dim = config.unit_embed_dim + kernel_size = config.variance_predictor_kernel_size + var_pred_dropout = config.var_pred_dropout + + self.conv1 = nn.Conv1d( + embed_dim, + embed_dim, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ) + self.activation_fuction = nn.ReLU() + self.ln1 = nn.LayerNorm(embed_dim) + self.dropout_module = nn.Dropout(p=var_pred_dropout) + self.conv2 = nn.Conv1d( + embed_dim, + embed_dim, + kernel_size=kernel_size, + padding=1, + ) + self.ln2 = nn.LayerNorm(embed_dim) + self.proj = nn.Linear(embed_dim, 1) + + def forward(self, hidden_states: Tensor) -> Tensor: + # Input: B x T x C; Output: B x T + hidden_states = self.conv1(hidden_states.transpose(1, 2)) + hidden_states = self.activation_fuction(hidden_states).transpose(1, 2) + hidden_states = self.dropout_module(self.ln1(hidden_states)) + hidden_states = self.conv2(hidden_states.transpose(1, 2)) + hidden_states = self.activation_fuction(hidden_states).transpose(1, 2) + hidden_states = self.dropout_module(self.ln2(hidden_states)) + return self.proj(hidden_states).squeeze(dim=2) + + +class SeamlessM4THifiGan(nn.Module): + def __init__(self, config: SeamlessM4TConfig): + super().__init__() + model_in_dim = config.unit_embed_dim + config.lang_embed_dim + config.spkr_embed_dim + self.leaky_relu_slope = config.leaky_relu_slope + self.num_kernels = len(config.resblock_kernel_sizes) + self.num_upsamples = len(config.upsample_rates) + self.conv_pre = nn.Conv1d( + model_in_dim, + config.upsample_initial_channel, + kernel_size=7, + stride=1, + padding=3, + ) + + self.upsampler = nn.ModuleList() + for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)): + self.upsampler.append( + nn.ConvTranspose1d( + config.upsample_initial_channel // (2**i), + config.upsample_initial_channel // (2 ** (i + 1)), + kernel_size=kernel_size, + stride=upsample_rate, + padding=(kernel_size - upsample_rate) // 2, + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.upsampler)): + channels = config.upsample_initial_channel // (2 ** (i + 1)) + for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes): + self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope)) + + self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3) + + def forward(self, input_embeds: torch.FloatTensor) -> torch.FloatTensor: + r""" + Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch + of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech + waveform. + + Args: + spectrogram (`torch.FloatTensor`): + Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length, + model_in_dim)`, or un-batched and of shape `(sequence_length, model_in_dim)`. Note that `model_in_dim` + is the sum of `config.unit_embed_dim`, `config.lang_embed_dim` and `config.spkr_embed_dim`. + + Returns: + `torch.FloatTensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of + shape `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`. + """ + + hidden_states = self.conv_pre(input_embeds) + for i in range(self.num_upsamples): + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = self.upsampler[i](hidden_states) + + res_state = self.resblocks[i * self.num_kernels](hidden_states) + for j in range(1, self.num_kernels): + res_state += self.resblocks[i * self.num_kernels + j](hidden_states) + hidden_states = res_state / self.num_kernels + + hidden_states = nn.functional.leaky_relu(hidden_states) + hidden_states = self.conv_post(hidden_states) + hidden_states = torch.tanh(hidden_states) + + # remove seq-len dim since this collapses to 1 + waveform = hidden_states.squeeze(1) + + return waveform + + +@add_start_docstrings( + """Code HiFi-GAN vocoder as described in this [repository](https://github.com/facebookresearch/speech-resynthesis).""", + HIFIGAN_START_DOCSTRING, +) +class SeamlessM4TCodeHifiGan(PreTrainedModel): + config_class = SeamlessM4TConfig + main_input_name = "input_embeds" + + def __init__(self, config): + super().__init__(config) + + self.pad_token_id = config.t2u_pad_token_id + self.dur_predictor = SeamlessM4TVariancePredictor(config) + + self.unit_embedding = nn.Embedding(config.unit_hifi_gan_vocab_size, config.unit_embed_dim) + self.speaker_embedding = nn.Embedding(config.vocoder_num_spkrs, config.spkr_embed_dim) + self.language_embedding = nn.Embedding(config.vocoder_num_langs, config.lang_embed_dim) + + self.hifi_gan = SeamlessM4THifiGan(config) + + # Initialize weights and apply final processing + self.post_init() + + def _get_dur_output_lengths(self, input_ids, dur_out): + """ + Computes the output length after the duration layer. + """ + unit_lengths = (input_ids != self.pad_token_id).sum(1) + + # take care of edge cases where no padding or too many padding + unit_lengths = torch.clamp(unit_lengths, 0, dur_out.shape[1] - 1) + + cumulative_dur_out = torch.cumsum(dur_out, dim=1) + unit_lengths = cumulative_dur_out.gather(dim=1, index=unit_lengths.unsqueeze(1)).squeeze() + + return unit_lengths + + def _get_output_hifigan_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the hifigan convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride, pad, dilation=1): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return ( + torch.div(input_length + 2 * pad - dilation * (kernel_size - 1) - 1, stride, rounding_mode="floor") + 1 + ) + + def _transpose_conv_out_length(input_length, kernel_size, stride, pad, dilation=1): + return (input_length - 1) * stride - 2 * pad + dilation * (kernel_size - 1) + 1 + + # conv_pre + input_lengths = _conv_out_length(input_lengths, 7, 1, 3) + + # upsampler + for i, (upsample_rate, kernel_size) in enumerate( + zip(self.config.upsample_rates, self.config.upsample_kernel_sizes) + ): + input_lengths = _transpose_conv_out_length( + input_lengths, kernel_size, upsample_rate, (kernel_size - upsample_rate) // 2 + ) + + # resblock + for i in range(len(self.config.upsample_rates)): + for kernel_size, dilation in zip(self.config.resblock_kernel_sizes, self.config.resblock_dilation_sizes): + for dil in dilation: + input_lengths = _conv_out_length( + input_lengths, kernel_size, 1, (kernel_size - 1) * dil // 2, dilation=dil + ) + + for dil in dilation: + input_lengths = _conv_out_length(input_lengths, kernel_size, 1, (kernel_size - 1) // 2, dilation=1) + + # conv_post + input_lengths = _conv_out_length(input_lengths, 7, 1, 3) + + return input_lengths + + def forward( + self, input_ids: torch.LongTensor, spkr_id: torch.Tensor, lang_id: torch.Tensor + ) -> Tuple[torch.Tensor]: + """ + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4TTextToUnitForConditionalGeneration`]. [What are input + IDs?](../glossary#input-ids) + spkr_id (`int`, *optional*): + The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`. + tgt_lang (`str`, *optional*): + The language id to use as target language for translation. + """ + hidden_states = self.unit_embedding(input_ids).transpose(1, 2) + spkr = self.speaker_embedding(spkr_id).transpose(1, 2) + lang = self.language_embedding(lang_id).transpose(1, 2) + + log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2)) + dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1) + # B x C x T + if hidden_states.size(0) == 1: + hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2) + else: + # if batched sample, need to interleave per sample, and pad -> loss of parallelism + if hidden_states.shape[0] > 1 and self.training: + logger.warning( + """`self.training=True` and you use batching. You lose parallelism during the hifigan + forward pass because the samples are interleaved.""" + ) + hidden_states = [ + torch.repeat_interleave(hidden_state, duration, dim=-1).transpose(0, 1) + for (hidden_state, duration) in zip(hidden_states, dur_out) + ] + + hidden_states = nn.utils.rnn.pad_sequence(hidden_states, batch_first=True).transpose(1, 2) + + spkr = spkr.repeat(1, 1, hidden_states.shape[-1]) + lang = lang.repeat(1, 1, hidden_states.shape[-1]) + hidden_states = torch.cat([lang, hidden_states, spkr], dim=1) + + hidden_states = self.hifi_gan(hidden_states) + + unit_lengths = self._get_dur_output_lengths(input_ids, dur_out) + lengths = self._get_output_hifigan_lengths(unit_lengths) + + return hidden_states, lengths + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def apply_weight_norm(self): + nn.utils.weight_norm(self.hifi_gan.conv_pre) + for layer in self.hifi_gan.upsampler: + nn.utils.weight_norm(layer) + for layer in self.hifi_gan.resblocks: + layer.apply_weight_norm() + nn.utils.weight_norm(self.hifi_gan.conv_post) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.hifi_gan.conv_pre) + for layer in self.hifi_gan.upsampler: + nn.utils.remove_weight_norm(layer) + for layer in self.hifi_gan.resblocks: + layer.remove_weight_norm() + nn.utils.remove_weight_norm(self.hifi_gan.conv_post) + + +############ WHOLE MODEL related code ################ + + +@add_start_docstrings( + "The text-to-text SeamlessM4T Model transformer which can be used for T2TT.", + SEAMLESS_M4T_START_DOCSTRING, +) +class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel): + _keys_to_ignore_on_load_missing = ["speech_encoder", "t2u_model", "vocoder"] + main_input_name = "input_ids" + + _tied_weights_keys = [ + "lm_head.weight", + "text_encoder.embed_tokens.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config: SeamlessM4TConfig): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.text_encoder = SeamlessM4TEncoder(config, self.shared) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.text_encoder + + def get_decoder(self): + return self.text_decoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_encoder.embed_tokens = value + self.text_decoder.embed_tokens = value + self.shared = value + + @add_start_docstrings_to_model_forward(M4T_TEXT_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def generate( + self, + input_ids=None, + tgt_lang=None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + **kwargs, + ): + """ + Generates sequences of token ids. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + input_ids (`torch.Tensor` of varying shape depending on the modality, *optional*): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + tgt_lang (`str`, *optional*): + The language to use as target language for translation. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + kwargs (`Dict[str, Any]`, *optional*): + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchEncoderDecoderOutput`], + - [`~generation.SampleEncoderDecoderOutput`], + - [`~generation.BeamSearchEncoderDecoderOutput`], + - [`~generation.BeamSampleEncoderDecoderOutput`] + """ + # prepare text_decoder_input_ids + text_decoder_input_ids = kwargs.pop("decoder_input_ids", None) + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + if tgt_lang is not None: + batch_size = len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds")) + + if hasattr(self.generation_config, "text_decoder_lang_to_code_id"): + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + if tgt_lang not in self.generation_config.text_decoder_lang_to_code_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. Please specify a `tgt_lang` in + {', '.join(self.generation_config.text_decoder_lang_to_code_id.keys())}""" + ) + # tgt_lang gets priority over decoder input ids + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + else: + raise ValueError( + """This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps + the target language to the right token id. Make sure to load the right generation config.""" + ) + else: + # only a warning, otherwise errors appear in the tests + logger.warning( + """You must either specify a `tgt_lang` or pass a correct `text_decoder_input_ids` to get + a correct generation, otherwise the generation will probably make no sense.""" + ) + + return super().generate( + input_ids, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + decoder_input_ids=text_decoder_input_ids, + **kwargs, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +@add_start_docstrings( + "The speech-to-text SeamlessM4T Model transformer which can be used for S2TT.", + SEAMLESS_M4T_START_DOCSTRING, +) +class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel): + _keys_to_ignore_on_load_missing = ["text_decoder", "t2u_model", "vocoder"] + main_input_name = "input_features" + + _tied_weights_keys = [ + "lm_head.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config: SeamlessM4TConfig): + super().__init__(config) + + self.speech_encoder = SeamlessM4TSpeechEncoder(config) + self.text_decoder = SeamlessM4TDecoder(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.speech_encoder + + def get_decoder(self): + return self.text_decoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_decoder.embed_tokens = value + + @add_start_docstrings_to_model_forward(M4T_SPEECH_INPUTS_DOCSTRING) + def forward( + self, + input_features: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.speech_encoder( + input_features=input_features, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_outputs[0].device + ) + encoder_attention_mask = _compute_new_attention_mask( + hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def generate( + self, + input_features=None, + tgt_lang=None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + **kwargs, + ): + """ + Generates sequences of token ids. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`): + Input audio features. This should be returnes by the [`SeamlessM4TFeatureExtractor`] class or the + [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details. + + tgt_lang (`str`, *optional*): + The language to use as target language for translation. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + kwargs (`Dict[str, Any]`, *optional*): + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchEncoderDecoderOutput`], + - [`~generation.SampleEncoderDecoderOutput`], + - [`~generation.BeamSearchEncoderDecoderOutput`], + - [`~generation.BeamSampleEncoderDecoderOutput`] + """ + text_decoder_input_ids = kwargs.pop("decoder_input_ids", None) + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + if tgt_lang is not None: + inputs = kwargs.get("input_embeds") if input_features is None else input_features + inputs = ( + inputs + if inputs is not None + else kwargs.get("encoder_outputs", {"last_hidden_state": None})["last_hidden_state"] + ) + batch_size = len(inputs) + + if hasattr(self.generation_config, "text_decoder_lang_to_code_id"): + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + if tgt_lang not in self.generation_config.text_decoder_lang_to_code_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. Please specify a `tgt_lang` in + {', '.join(self.generation_config.text_decoder_lang_to_code_id.keys())}""" + ) + # tgt_lang gets priority over decoder input ids + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + else: + raise ValueError( + """This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps + the target language to the right token id. Make sure to load the right generation config.""" + ) + else: + # only a warning, otherwise errors appear in the tests + logger.warning( + """You must either specify a `tgt_lang` or pass a correct `text_decoder_input_ids` to get + a correct generation, otherwise the generation will probably make no sense.""" + ) + return super().generate( + input_features, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + decoder_input_ids=text_decoder_input_ids, + **kwargs, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +@add_start_docstrings( + "The text-to-speech SeamlessM4T Model transformer which can be used for T2ST.", + SEAMLESS_M4T_START_DOCSTRING, +) +class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel): + _keys_to_ignore_on_load_missing = ["speech_encoder"] + main_input_name = "input_ids" + + _tied_weights_keys = [ + "lm_head.weight", + "text_encoder.embed_tokens.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config: SeamlessM4TConfig): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.text_encoder = SeamlessM4TEncoder(config, self.shared) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + self.t2u_model = SeamlessM4TTextToUnitForConditionalGeneration(config) + self.vocoder = SeamlessM4TCodeHifiGan(config) + + def get_encoder(self): + return self.text_encoder + + def get_decoder(self): + return self.text_decoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_encoder.embed_tokens = value + self.text_decoder.embed_tokens = value + self.shared = value + + @add_start_docstrings_to_model_forward(M4T_TEXT_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This is the same forward method as `SeamlessM4TForTextToText`." + "It doesn't use the text-to-unit model `SeamlessM4TTextToUnitForConditionalGeneration`." + "If you want to generate speech, use the `.generate` method." + ) + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + return_intermediate_token_ids: Optional[bool] = None, + tgt_lang: Optional[str] = None, + spkr_id: Optional[int] = 0, + **kwargs, + ) -> Union[torch.Tensor, SeamlessM4TGenerationOutput]: + """ + Generates translated audio waveforms. + + + + This method successively calls the `.generate` function of two different sub-models. You can specify keyword + arguments at two different levels: general arguments that will be passed to both models, or prefixed arguments + that will be passed to one of them. + + For example, calling `.generate(input_ids, num_beams=4, speech_do_sample=True)` will successively perform + beam-search decoding on the text model, and multinomial beam-search sampling on the speech model. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + return_intermediate_token_ids (`bool`, *optional*): + If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want + to get translated text alongside the audio. + tgt_lang (`str`, *optional*): + The language to use as target language for translation. + spkr_id (`int`, *optional*, defaults to 0): + The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`. + kwargs (*optional*): + Remaining dictionary of keyword arguments that will be passed to [`GenerationMixin.generate`]. Keyword + arguments are of two types: + + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model, + except for `decoder_input_ids` which will only be passed through the text components. + - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the + text model and speech model respectively. It has the priority over the keywords without a prefix. + + This means you can, for example, specify a generation strategy for one generation but not for the + other. + + + Returns: + `Union[SeamlessM4TGenerationOutput, Tuple[Tensor]]`: + - If `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`]. + - If not `return_intermediate_token_ids`, returns a tuple composed of waveforms of shape `(batch_size, + sequence_length)`and and `waveform_lengths` which gives the length of each sample. + """ + batch_size = len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds")) + + if tgt_lang is None: + raise ValueError("You must specify a `tgt_lang` to generate translated speech.") + else: + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]: + lang_code_to_id = getattr(self.generation_config, key, None) + if lang_code_to_id is None: + raise ValueError( + f"""This model generation config doesn't have a `{key}` key which maps the target language + to the right token id. Make sure to load the right generation config.""" + ) + elif tgt_lang not in lang_code_to_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. + Please specify a `tgt_lang` in {','.join(lang_code_to_id.keys())}. Note that SeamlessM4T supports + more languages for text translation than for speech synthesis.""" + ) + + kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs) + kwargs_text["output_hidden_states"] = True + kwargs_text["return_dict_in_generate"] = True + kwargs_text["output_scores"] = True + + text_decoder_input_ids = kwargs_text.get("decoder_input_ids") + + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + text_generation_output = super().generate(input_ids, **kwargs_text) + sequences = text_generation_output.sequences + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + encoder_hidden_states = text_generation_output.encoder_hidden_states[-1] + + # take care of num_return_sequences + # take most probable hidden states per batch of return_sequences + # (batch_size*num_return_sequences, ...) -> (batch_size,...) + if num_return_sequences > 1: + idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) + idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) + idx_most_probable_sequences_per_batch = ( + idx_most_probable_sequences_per_batch + torch.arange(batch_size).to(self.device) * num_return_sequences + ) + sequences = sequences[idx_most_probable_sequences_per_batch] + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + head_mask=kwargs_text.get("decoder_head_mask"), + cross_attn_head_mask=kwargs_text.get("cross_attn_head_mask"), + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences != pad_token_id).int().sum(1) + t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens) + kwargs_speech["attention_mask"] = t2u_model_attention_mask + + # Compute t2u decoder_input_ids + t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") + t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) + t2u_decoder_input_ids = torch.tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size).to( + self.device + ) + kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids + + # second generation + unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) + output_unit_ids = unit_ids.detach().clone() + + # get rid of t2u_decoder_input_ids + unit_ids = unit_ids[:, kwargs_speech["decoder_input_ids"].shape[1] :] + # replace eos per pad + unit_ids[unit_ids == self.config.t2u_eos_token_id] = self.config.t2u_pad_token_id + # offset of control symbols + unit_ids = torch.where( + unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset + ) + + vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device) + + spkr_id = torch.tensor([[spkr_id]] * len(unit_ids)).to(self.device) + + waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) + + if return_intermediate_token_ids: + return SeamlessM4TGenerationOutput( + waveform=waveform, + waveform_lengths=waveform_lengths, + sequences=sequences, + unit_sequences=output_unit_ids, + ) + + return waveform, waveform_lengths + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +@add_start_docstrings( + "The speech-to-speech SeamlessM4T Model transformer which can be used for S2ST.", + SEAMLESS_M4T_START_DOCSTRING, +) +class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel): + _keys_to_ignore_on_load_missing = ["text_encoder"] + main_input_name = "input_features" + + _tied_weights_keys = [ + "lm_head.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config): + super().__init__(config) + + self.speech_encoder = SeamlessM4TSpeechEncoder(config) + self.text_decoder = SeamlessM4TDecoder(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + self.t2u_model = SeamlessM4TTextToUnitForConditionalGeneration(config) + self.vocoder = SeamlessM4TCodeHifiGan(config) + + def get_encoder(self): + return self.speech_encoder + + def get_decoder(self): + return self.text_decoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_decoder.embed_tokens = value + + @add_start_docstrings_to_model_forward(M4T_SPEECH_INPUTS_DOCSTRING) + def forward( + self, + input_features: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This is the same forward method as `SeamlessM4TForSpeechToText`. It doesn't use `self.t2u_model`." + "If you want to generate speech, use the `generate` method." + ) + + encoder_outputs = self.speech_encoder( + input_features=input_features, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_outputs[0].device + ) + encoder_attention_mask = _compute_new_attention_mask( + hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_features: Optional[torch.Tensor] = None, + return_intermediate_token_ids: Optional[bool] = None, + tgt_lang: Optional[str] = None, + spkr_id: Optional[int] = 0, + **kwargs, + ) -> Union[torch.Tensor, SeamlessM4TGenerationOutput]: + """ + Generates translated audio waveforms. + + + + This method successively calls the `.generate` function of two different sub-models. You can specify keyword + arguments at two different levels: general arguments that will be passed to both models, or prefixed arguments + that will be passed to one of them. + + For example, calling `.generate(input_features, num_beams=4, speech_do_sample=True)` will successively perform + beam-search decoding on the text model, and multinomial beam-search sampling on the speech model. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Args: + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`): + Input audio features. This should be returnes by the [`SeamlessM4TFeatureExtractor`] class or the + [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details. + return_intermediate_token_ids (`bool`, *optional*): + If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want + to get translated text alongside the audio. + tgt_lang (`str`, *optional*): + The language to use as target language for translation. + spkr_id (`int`, *optional*, defaults to 0): + The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`. + + kwargs (*optional*): + Remaining dictionary of keyword arguments that will be passed to [`GenerationMixin.generate`]. Keyword + arguments are of two types: + + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model, + except for `decoder_input_ids` which will only be passed through the text components. + - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the + text model and speech model respectively. It has the priority over the keywords without a prefix. + + This means you can, for example, specify a generation strategy for one generation but not for the + other. + + + Returns: + `Union[SeamlessM4TGenerationOutput, Tuple[Tensor]]`: + - If `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`]. + - If not `return_intermediate_token_ids`, returns a tuple composed of waveforms of shape `(batch_size, + sequence_length)`and and `waveform_lengths` which gives the length of each sample. + """ + batch_size = len(input_features) if input_features is not None else len(kwargs.get("inputs_embeds")) + + if tgt_lang is None: + raise ValueError("You must specify a `tgt_lang` to generate translated speech.") + else: + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]: + lang_code_to_id = getattr(self.generation_config, key, None) + if lang_code_to_id is None: + raise ValueError( + f"""This model generation config doesn't have a `{key}` key which maps the target language + to the right token id. Make sure to load the right generation config.""" + ) + elif tgt_lang not in lang_code_to_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. + Please specify a `tgt_lang` in {','.join(lang_code_to_id.keys())}. Note that SeamlessM4T supports + more languages for text translation than for speech synthesis.""" + ) + + kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs) + kwargs_text["output_hidden_states"] = True + kwargs_text["return_dict_in_generate"] = True + kwargs_text["output_scores"] = True + + text_decoder_input_ids = kwargs_text.get("decoder_input_ids") + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + text_generation_output = super().generate(input_features, **kwargs_text) + sequences = text_generation_output.sequences + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + # get last_hidden_state from encoder + encoder_hidden_states = self.speech_encoder(input_features=input_features, attention_mask=attention_mask)[0] + + # input modality = speech so new attention mask for the decoder + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_hidden_states.device + ) + attention_mask = _compute_new_attention_mask( + hidden_states=encoder_hidden_states, seq_lens=sub_sampled_lengths + ) + + # take care of num_return_sequences + # take most probable hidden states per batch of return_sequences + # (batch_size*num_return_sequences, ...) -> (batch_size,...) + if num_return_sequences > 1: + idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) + idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) + idx_most_probable_sequences_per_batch = ( + idx_most_probable_sequences_per_batch + torch.arange(batch_size).to(self.device) * num_return_sequences + ) + sequences = sequences[idx_most_probable_sequences_per_batch] + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + head_mask=kwargs_text.get("decoder_head_mask"), + cross_attn_head_mask=kwargs_text.get("cross_attn_head_mask"), + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences != pad_token_id).int().sum(1) + t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens) + kwargs_speech["attention_mask"] = t2u_model_attention_mask + + # Compute t2u decoder_input_ids + t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") + t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) + t2u_decoder_input_ids = torch.tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size).to( + self.device + ) + kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids + + # second generation + unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) + output_unit_ids = unit_ids.detach().clone() + + # get rid of t2u_decoder_input_ids + unit_ids = unit_ids[:, kwargs_speech["decoder_input_ids"].shape[1] :] + # replace eos per pad + unit_ids[unit_ids == self.config.t2u_eos_token_id] = self.config.t2u_pad_token_id + # offset of control symbols + unit_ids = torch.where( + unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset + ) + + vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device) + + spkr_id = torch.tensor([[spkr_id]] * len(unit_ids)).to(self.device) + + waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) + + if return_intermediate_token_ids: + return SeamlessM4TGenerationOutput( + waveform=waveform, + waveform_lengths=waveform_lengths, + sequences=sequences, + unit_sequences=output_unit_ids, + ) + + return waveform, waveform_lengths + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + +@add_start_docstrings( + "The original SeamlessM4T Model transformer which can be used for every tasks available (S2ST, S2TT, T2TT, T2ST).", + SEAMLESS_M4T_START_DOCSTRING, + """ + current_modality (`str`, *optional*, defaults to `"text"`): + Default modality. Used to initialize the model. + """, +) +class SeamlessM4TModel(SeamlessM4TPreTrainedModel): + _tied_weights_keys = [ + "lm_head.weight", + "text_encoder.embed_tokens.weight", + "text_decoder.embed_tokens.weight", + ] + + def __init__(self, config, current_modality="text"): + super().__init__(config) + + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.text_encoder = SeamlessM4TEncoder(config, self.shared) + self.speech_encoder = SeamlessM4TSpeechEncoder(config) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + self.current_modality = current_modality + if current_modality == "speech": + self.main_input_name = "input_features" + + # these models already call post_init in their initialization + self.t2u_model = SeamlessM4TTextToUnitForConditionalGeneration(config) + self.vocoder = SeamlessM4TCodeHifiGan(config) + + def set_modality(self, modality="text"): + if modality == "text": + self.main_input_name = "input_ids" + self.current_modality = "text" + elif modality == "speech": + self.main_input_name = "input_features" + self.current_modality = "speech" + else: + raise ValueError(f"`modality={modality}` is not a valid modality. It must be `text` or `speech`.") + + def get_encoder(self): + if self.current_modality == "text": + return self.text_encoder + else: + return self.speech_encoder + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.text_decoder.embed_tokens + + def set_input_embeddings(self, value): + self.text_encoder.embed_tokens = value + self.text_decoder.embed_tokens = value + self.shared = value + + @add_start_docstrings_to_model_forward(M4T_MODEL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + if input_ids is None and input_features is None and inputs_embeds is None and encoder_outputs is None: + raise ValueError( + "`input_ids`,`input_features`, `inputs_embeds` and `encoder_outputs` are all empty. Make sure at least one of them is not." + ) + elif input_features is not None: + if input_ids is not None: + logger.warning( + "`input_ids` is not `None` but `input_features` has been given." + "`input_features` will be used in priority through the `speech_encoder`. " + "Make sure that `input_features` and `input_ids` are mutually exclusive." + ) + + if inputs_embeds is not None: + logger.warning( + "`inputs_embeds` is not `None` but `input_features` has been given." + "`input_features` will be used in priority through `speech_encoder`. " + "`inputs_embeds` will be ignored." + ) + + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This calls the same method `forward` as `SeamlessM4TForTextToText` and `SeamlessM4TForSpeechToText`" + "depending on the input modality. If you want to generate speech, use the `generate` method." + ) + + self.set_modality("speech") + + encoder_outputs = self.speech_encoder( + input_features=input_features, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + elif input_ids is not None or inputs_embeds is not None: + # if encoder_outputs is not None, it's probably used within a .generate method so no need to warn + logger.warning( + "This calls the same method `forward` as `SeamlessM4TForTextToText` and `SeamlessM4TForSpeechToText`" + "depending on the input modality. If you want to generate speech, use the `generate` method." + ) + self.set_modality("text") + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + encoder_attention_mask = attention_mask + # input modality = speech so new attention mask + if self.current_modality == "speech" and attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_outputs[0].device + ) + encoder_attention_mask = _compute_new_attention_mask( + hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + lm_logits = self.lm_head(decoder_outputs[0]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + labels = labels.to(lm_logits.device) + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + outputs = decoder_outputs + encoder_outputs + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + input_features: Optional[torch.Tensor] = None, + return_intermediate_token_ids: Optional[bool] = None, + tgt_lang: Optional[str] = None, + spkr_id: Optional[int] = 0, + generate_speech: Optional[bool] = True, + **kwargs, + ) -> Union[torch.Tensor, SeamlessM4TGenerationOutput]: + """ + Generates translated token ids and/or translated audio waveforms. + + + + This method successively calls the `.generate` function of two different sub-models. You can specify keyword + arguments at two different levels: general arguments that will be passed to both models, or prefixed arguments + that will be passed to one of them. + + For example, calling `.generate(input_ids=input_ids, num_beams=4, speech_do_sample=True)` will successively + perform beam-search decoding on the text model, and multinomial beam-search sampling on the speech model. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SeamlessM4TTokenizer`] or [`SeamlessM4TProcessor`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_banks)`, *optional*): + Input audio features. This should be returnes by the [`SeamlessM4TFeatureExtractor`] class or the + [`SeamlessM4TProcessor`] class. See [`SeamlessM4TFeatureExtractor.__call__`] for details. + return_intermediate_token_ids (`bool`, *optional*): + If `True`, also returns the intermediate generated text and unit tokens. Set to `True` if you also want + to get translated text alongside the audio. Note that if `generate_speech=True`, this parameter will be + ignored. + tgt_lang (`str`, *optional*): + The language to use as target language for translation. + spkr_id (`int`, *optional*, defaults to 0): + The id of the speaker used for speech synthesis. Must be lower than `config.vocoder_num_spkrs`. + generate_speech (`bool`, *optional*, defaults to `True`): + If `False`, will only returns the text tokens and won't generate speech. + + kwargs (*optional*): + Remaining dictionary of keyword arguments that will be passed to [`GenerationMixin.generate`]. Keyword + arguments are of two types: + + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model, + except for `decoder_input_ids` which will only be passed through the text components. + - With a *text_* or *speech_* prefix, they will be input for the `generate` method of the + text model and speech model respectively. It has the priority over the keywords without a prefix. + + This means you can, for example, specify a generation strategy for one generation but not for the + other. + + Returns: + `Union[SeamlessM4TGenerationOutput, Tuple[Tensor], ModelOutput]`: + - If `generate_speech` and `return_intermediate_token_ids`, returns [`SeamlessM4TGenerationOutput`]. + - If `generate_speech` and not `return_intermediate_token_ids`, returns a tuple composed of waveforms of + shape `(batch_size, sequence_length)`and and `waveform_lengths` which gives the length of each sample. + - If `generate_speech=False`, it will returns `ModelOutput`. + """ + if input_ids is None and input_features is None and kwargs.get("inputs_embeds", None) is None: + raise ValueError( + "`input_ids`,`input_features` and `inputs_embeds` are all empty. Make sure at least one of them is not." + ) + + if generate_speech and tgt_lang is None: + raise ValueError("You must specify a `tgt_lang` to generate translated speech.") + + if tgt_lang is not None: + # also accept __xxx__ + tgt_lang = tgt_lang.replace("__", "") + for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]: + lang_code_to_id = getattr(self.generation_config, key, None) + if lang_code_to_id is None: + raise ValueError( + f"""This model generation config doesn't have a `{key}` key which maps the target language + to the right token id. Make sure to load the right generation config.""" + ) + elif tgt_lang not in lang_code_to_id: + raise ValueError( + f"""`tgt_lang={tgt_lang}` is not supported by this model. + Please specify a `tgt_lang` in {','.join(lang_code_to_id.keys())}. Note that SeamlessM4T supports + more languages for text translation than for speech synthesis.""" + ) + + batch_size = ( + len(input_features) + if input_features is not None + else (len(input_ids) if input_ids is not None else len(kwargs.get("inputs_embeds"))) + ) + + kwargs_text, kwargs_speech = format_speech_generation_kwargs(kwargs) + kwargs_text["output_hidden_states"] = True + kwargs_text["return_dict_in_generate"] = True + kwargs_text["output_scores"] = True + + text_decoder_input_ids = kwargs_text.get("decoder_input_ids") + # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + if tgt_lang is not None: + # tgt_lang gets priority over decoder input ids + text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + + kwargs_text["decoder_input_ids"] = text_decoder_input_ids + + # first generation + if input_features is not None: + self.set_modality("speech") + if input_ids is not None: + logger.warning( + "`input_features` and `input_ids` are both non empty. `input_features` will be used in priority " + "through the speech encoder. Make sure `input_features=None` if you want to use the text encoder." + ) + text_generation_output = super().generate(input_features=input_features, **kwargs_text) + else: + self.set_modality("text") + text_generation_output = super().generate(input_ids=input_ids, input_features=None, **kwargs_text) + sequences = text_generation_output.sequences + + if not generate_speech: + return text_generation_output + + # prepare second generation + num_return_sequences = len(sequences) // batch_size + attention_mask = kwargs_speech.get("attention_mask", kwargs_text.get("attention_mask", None)) + + # get encoder last hidden states + if self.current_modality == "speech": + # get last_hidden_state from encoder - must do a pass through the speech encoder + encoder_hidden_states = self.speech_encoder( + input_features=input_features, attention_mask=attention_mask + ).last_hidden_state + + # input modality = speech so new attention mask for the decoder + if attention_mask is not None: + sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(attention_mask).to( + encoder_hidden_states.device + ) + attention_mask = _compute_new_attention_mask( + hidden_states=encoder_hidden_states, seq_lens=sub_sampled_lengths + ) + else: + encoder_hidden_states = text_generation_output.encoder_hidden_states[-1] + + # take care of num_return_sequences + # take most probable hidden states per batch of return_sequences + # (batch_size*num_return_sequences, ...) -> (batch_size,...) + if num_return_sequences > 1: + idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) + idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) + idx_most_probable_sequences_per_batch = ( + idx_most_probable_sequences_per_batch + torch.arange(batch_size).to(self.device) * num_return_sequences + ) + sequences = sequences[idx_most_probable_sequences_per_batch] + + # get decoder last hidden state - must do a pass through the text decoder + t2u_input_embeds = self.text_decoder( + input_ids=sequences, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + head_mask=kwargs_text.get("decoder_head_mask"), + cross_attn_head_mask=kwargs_text.get("cross_attn_head_mask"), + ).last_hidden_state + + pad_token_id = self.generation_config.pad_token_id + + # Compute new attention mask + seq_lens = (sequences != pad_token_id).int().sum(1) + t2u_model_attention_mask = _compute_new_attention_mask(t2u_input_embeds, seq_lens) + kwargs_speech["attention_mask"] = t2u_model_attention_mask + + # Compute t2u decoder_input_ids + t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") + t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) + t2u_decoder_input_ids = torch.tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size).to( + self.device + ) + kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids + + # second generation + unit_ids = self.t2u_model.generate(inputs_embeds=t2u_input_embeds, **kwargs_speech) + output_unit_ids = unit_ids.detach().clone() + + # get rid of t2u_decoder_input_ids + unit_ids = unit_ids[:, kwargs_speech["decoder_input_ids"].shape[1] :] + # replace eos per pad + unit_ids[unit_ids == self.config.t2u_eos_token_id] = self.config.t2u_pad_token_id + # offset of control symbols + unit_ids = torch.where( + unit_ids == self.config.t2u_pad_token_id, unit_ids, unit_ids - self.config.vocoder_offset + ) + + vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device) + + spkr_id = torch.tensor([[spkr_id]] * len(unit_ids)).to(self.device) + + waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) + + if return_intermediate_token_ids: + return SeamlessM4TGenerationOutput( + waveform=waveform, + waveform_lengths=waveform_lengths, + sequences=sequences, + unit_sequences=output_unit_ids, + ) + + return waveform, waveform_lengths + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past diff --git a/src/transformers/models/seamless_m4t/processing_seamless_m4t.py b/src/transformers/models/seamless_m4t/processing_seamless_m4t.py new file mode 100644 index 00000000000000..4f22e9e33d0c3a --- /dev/null +++ b/src/transformers/models/seamless_m4t/processing_seamless_m4t.py @@ -0,0 +1,116 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Audio/Text processor class for SeamlessM4T +""" + +from ...processing_utils import ProcessorMixin + + +class SeamlessM4TProcessor(ProcessorMixin): + r""" + Constructs a SeamlessM4T processor which wraps a SeamlessM4T feature extractor and a SeamlessM4T tokenizer into a + single processor. + + [`SeamlessM4TProcessor`] offers all the functionalities of [`SeamlessM4TFeatureExtractor`] and + [`SeamlessM4TTokenizerFast`]. See the [`~SeamlessM4TProcessor.__call__`] and [`~SeamlessM4TProcessor.decode`] for + more information. + + Args: + feature_extractor ([`SeamlessM4TFeatureExtractor`]): + The audio processor is a required input. + tokenizer ([`SeamlessM4TTokenizerFast`]): + The tokenizer is a required input. + """ + feature_extractor_class = "SeamlessM4TFeatureExtractor" + tokenizer_class = ("SeamlessM4TTokenizer", "SeamlessM4TTokenizerFast") + + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + + def __call__(self, text=None, audios=None, src_lang=None, tgt_lang=None, **kwargs): + """ + Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text` + and `kwargs` arguments to SeamlessM4TTokenizerFast's [`~SeamlessM4TTokenizerFast.__call__`] if `text` is not + `None` to encode the text. To prepare the audio(s), this method forwards the `audios` and `kwrags` arguments to + SeamlessM4TFeatureExtractor's [`~SeamlessM4TFeatureExtractor.__call__`] if `audios` is not `None`. Please refer + to the doctsring of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The audio or batch of audios to be prepared. Each audio can be NumPy array or PyTorch tensor. In case + of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, + and T the sample length of the audio. + src_lang (`str`, *optional*): + The language code of the input texts/audios. If not specified, the last `src_lang` specified will be + used. + tgt_lang (`str`, *optional*): + The code of the target language. If not specified, the last `tgt_lang` specified will be used. + kwargs (*optional*): + Remaining dictionary of keyword arguments that will be passed to the feature extractor and/or the + tokenizer. + Returns: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **input_features** -- Audio input features to be fed to a model. Returned when `audios` is not `None`. + """ + sampling_rate = kwargs.pop("sampling_rate", None) + + if text is None and audios is None: + raise ValueError("You have to specify either text or audios. Both cannot be none.") + elif text is not None and audios is not None: + raise ValueError( + "Text and audios are mututally exclusive when passed to `SeamlessM4T`. Specify one or another." + ) + elif text is not None: + if tgt_lang is not None: + self.tokenizer.tgt_lang = tgt_lang + if src_lang is not None: + self.tokenizer.src_lang = src_lang + encoding = self.tokenizer(text, **kwargs) + + return encoding + + else: + encoding = self.feature_extractor(audios, sampling_rate=sampling_rate, **kwargs) + return encoding + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to SeamlessM4TTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. + Please refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to SeamlessM4TTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + feature_extractor_input_names = self.feature_extractor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names)) diff --git a/src/transformers/models/seamless_m4t/tokenization_seamless_m4t.py b/src/transformers/models/seamless_m4t/tokenization_seamless_m4t.py new file mode 100644 index 00000000000000..2daeb794b86543 --- /dev/null +++ b/src/transformers/models/seamless_m4t/tokenization_seamless_m4t.py @@ -0,0 +1,565 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for SeamlessM4T.""" +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import sentencepiece as spm + +from ...convert_slow_tokenizer import import_protobuf +from ...tokenization_utils import ( + BatchEncoding, + PreTokenizedInput, + PreTrainedTokenizer, + TextInput, +) +from ...tokenization_utils_base import AddedToken +from ...utils import PaddingStrategy, logging + + +logger = logging.get_logger(__name__) + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "facebook/hf-seamless-m4t-medium": ( + "https://huggingface.co/facebook/hf-seamless-m4t-medium/blob/main/sentencepiece.bpe.model" + ), + } +} + +SPIECE_UNDERLINE = "▁" + + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"} + + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "facebook/hf-seamless-m4t-medium": 2048, +} + + +class SeamlessM4TTokenizer(PreTrainedTokenizer): + """ + Construct a SeamlessM4T tokenizer. + + Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on + [SentencePiece](https://github.com/google/sentencepiece). + + The tokenization method is ` ` for source language documents, and ` ` for target language documents. + + Examples: + + ```python + >>> from transformers import SeamlessM4TTokenizer + + >>> tokenizer = SeamlessM4TTokenizer.from_pretrained( + ... "facebook/hf-seamless-m4t-medium", src_lang="eng", tgt_lang="fra" + ... ) + >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria" + >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie." + >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt") + ``` + + Args: + vocab_file (`str`): + Path to the vocabulary file. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + tokenizer_file (`str`, *optional*): + The path to a tokenizer file to use instead of the vocab file. + src_lang (`str`, *optional*, defaults to `"eng"`): + The language to use as source language for translation. + tgt_lang (`str`, *optional*, defaults to `"fra"`): + The language to use as target language for translation. + sp_model_kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments to pass to the model initialization. + additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*): + A tuple or a list of additional special tokens. Can be used to specify the list of languages that will be + supported by the tokenizer. + """ + + vocab_files_names = VOCAB_FILES_NAMES + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + model_input_names = ["input_ids", "attention_mask"] + + prefix_tokens: List[int] = [] + suffix_tokens: List[int] = [] + + def __init__( + self, + vocab_file, + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + tokenizer_file=None, + src_lang="eng", + tgt_lang="fra", + sp_model_kwargs: Optional[Dict[str, Any]] = None, + additional_special_tokens=None, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + # Add this unused argument to keep some important Copied from statements + self.legacy = False + self.vocab_file = vocab_file + + self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False)) + + # Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 + # -------- | ------- | ------- | ------ | ------- | ---- | ---- | ---- | ---- | ---- | ---- + # spm | '' | '' | '' | 'an' | 'en' | '_d' | 'er' | 'in' | '_s' | '_a' + # fairseq | '' | '' | '' | '' | 'an' | 'en' | '▁d' | 'er' | 'in' | '▁s' + + # Mimic fairseq token-to-id alignment for the first 4 token + self._added_tokens_decoder = { + 0: AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token, + 1: AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token, + 2: AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token, + 3: AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token, + } + + # The first "real" token "an" has position 4 in the original fairseq vocab and position 3 in the spm vocab + self.fairseq_offset = 1 + + self.sp_model_size = len(self.sp_model) + + self._src_lang = f"__{src_lang}__" if "__" not in src_lang else src_lang + self._tgt_lang = f"__{tgt_lang}__" if "__" not in tgt_lang else tgt_lang + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + tokenizer_file=tokenizer_file, + src_lang=src_lang, + tgt_lang=tgt_lang, + additional_special_tokens=additional_special_tokens, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + self.set_src_lang_special_tokens(self._src_lang) + self.set_tgt_lang_special_tokens(self._tgt_lang) + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.__getstate__ + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.__setstate__ + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + @property + def vocab_size(self): + return len(self.sp_model) + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair_target: Optional[ + Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] + ] = None, + padding: Union[bool, str, PaddingStrategy] = True, + pad_to_multiple_of: Optional[int] = 2, + src_lang: Optional[str] = None, + tgt_lang: Optional[str] = None, + **kwargs, + ): + """ + Args: + text (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_pair (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_target (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a + list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), + you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_pair_target (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a + list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), + you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + src_lang (`str`, *optional*): + A string representing the source language. If not specified, the last `src_lang` specified (either + during initialization or when calling this tokenizer) will be used. + tgt_lang (`str`, *optional*): + A string representing the target language. If not specified, the last `tgt_lang` specified (either + during initialization or when calling this tokenizer) will be used. + kwargs (*optional*): + Remaining dictionary of keyword arguments that will be passed to [`PreTrainedTokenizer.__call__`]. + """ + if src_lang is not None: + self.src_lang = src_lang + if tgt_lang is not None: + self.tgt_lang = tgt_lang + + output = super().__call__( + text=text, + text_pair=text_pair, + text_target=text_target, + text_pair_target=text_pair_target, + padding=padding, + pad_to_multiple_of=pad_to_multiple_of, + **kwargs, + ) + + return BatchEncoding(output, tensor_type=kwargs.get("return_tensors")) + + @property + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.src_lang + def src_lang(self) -> str: + return self._src_lang + + @src_lang.setter + def src_lang(self, new_src_lang: str) -> None: + if "__" not in new_src_lang: + self._src_lang = f"__{new_src_lang}__" + else: + self._src_lang = new_src_lang + self.set_src_lang_special_tokens(self._src_lang) + + @property + def tgt_lang(self) -> str: + return self._tgt_lang + + @tgt_lang.setter + def tgt_lang(self, new_tgt_lang: str) -> None: + if "__" not in new_tgt_lang: + self._tgt_lang = f"__{new_tgt_lang}__" + else: + self._tgt_lang = new_tgt_lang + self.set_tgt_lang_special_tokens(self._tgt_lang) + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] * len(self.suffix_tokens) + if token_ids_1 is None: + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An NLLB sequence has the following format, where `X` represents the sequence: + + - `input_ids` (for encoder) `X [eos, src_lang_code]` + - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]` + + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. nllb does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model.") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) + if "__" not in tgt_lang: + tgt_lang = f"__{tgt_lang}__" + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + + def get_vocab(self): + vocab = { + self.convert_ids_to_tokens(i): i for i in range(self.fairseq_offset, self.vocab_size + self.fairseq_offset) + } + vocab.update(self.added_tokens_encoder) + return vocab + + @property + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor + def get_spm_processor(self, from_slow=False): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + if self.legacy or from_slow: # no dependency on protobuf + tokenizer.Load(self.vocab_file) + return tokenizer + + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)") + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize + def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the + first token is special. + """ + if self.legacy or len(text) == 0: + return super().tokenize(text, **kwargs) + + tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give + `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the + `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + tokens = self.sp_model.encode(text, out_type=str) + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return tokens + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + spm_id = self.sp_model.PieceToId(token) + + # Need to return unknown token if the SP model returned 0 + return spm_id + self.fairseq_offset if spm_id else self.unk_token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.sp_model.IdToPiece(index - self.fairseq_offset) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + if tokens[0].startswith(SPIECE_UNDERLINE): + tokens[0] = tokens[0][1:] + + out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() + return out_string + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.prepare_seq2seq_batch with eng_Latn->eng, fra_Latn->fra + def prepare_seq2seq_batch( + self, + src_texts: List[str], + src_lang: str = "eng", + tgt_texts: Optional[List[str]] = None, + tgt_lang: str = "fra", + **kwargs, + ) -> BatchEncoding: + self.src_lang = src_lang + self.tgt_lang = tgt_lang + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer._switch_to_input_mode + def _switch_to_input_mode(self): + return self.set_src_lang_special_tokens(self.src_lang) + + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer._switch_to_target_mode + def _switch_to_target_mode(self): + return self.set_tgt_lang_special_tokens(self.tgt_lang) + + def set_src_lang_special_tokens(self, src_lang) -> None: + """Reset the special tokens to the source lang setting. + Prefix=[src_lang_code], suffix = [eos] + """ + self.cur_lang_code = self.convert_tokens_to_ids(src_lang) + self.init_kwargs["src_lang"] = src_lang + + if self.cur_lang_code == self.unk_token_id: + logger.warning_once( + f"`src_lang={src_lang}` has not be found in the vocabulary. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id." + ) + + self.prefix_tokens = [self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] + + # https://github.com/facebookresearch/fairseq2/blob/c53f18e6be6b8b46b722f2249b8397b7eccd7ad3/src/fairseq2/models/nllb/tokenizer.py#L112-L116 + def set_tgt_lang_special_tokens(self, lang: str) -> None: + """Reset the special tokens to the target lang setting. + Prefix=[eos, tgt_lang_code] and suffix=[eos]. + """ + self.cur_lang_code = self.convert_tokens_to_ids(lang) + self.init_kwargs["tgt_lang"] = lang + + if self.cur_lang_code == self.unk_token_id: + logger.warning_once( + f"`tgt_lang={lang}` has not be found in the vocabulary. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id." + ) + + self.prefix_tokens = [self.eos_token_id, self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] diff --git a/src/transformers/models/seamless_m4t/tokenization_seamless_m4t_fast.py b/src/transformers/models/seamless_m4t/tokenization_seamless_m4t_fast.py new file mode 100644 index 00000000000000..8ca03ac6747bb7 --- /dev/null +++ b/src/transformers/models/seamless_m4t/tokenization_seamless_m4t_fast.py @@ -0,0 +1,459 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Tokenization class for SeamlessM4T.""" +import os +from shutil import copyfile +from typing import List, Optional, Tuple, Union + +from tokenizers import processors + +from ...tokenization_utils import ( + BatchEncoding, + PreTokenizedInput, + TextInput, +) +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import PaddingStrategy, logging +from .tokenization_seamless_m4t import ( + SeamlessM4TTokenizer, +) + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "facebook/hf-seamless-m4t-medium": "https://huggingface.co/facebook/hf-seamless-m4t-medium/resolve/main/vocab.txt", + }, + "tokenizer_file": { + "facebook/hf-seamless-m4t-medium": "https://huggingface.co/facebook/hf-seamless-m4t-medium/resolve/main/tokenizer.json", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "facebook/hf-seamless-m4t-medium": 2048, +} + + +class SeamlessM4TTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" SeamlessM4T tokenizer (backed by HuggingFace's *tokenizers* library). Based on + [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + The tokenization method is ` ` for source language documents, and ` ` for target language documents. + + Examples: + + ```python + >>> from transformers import SeamlessM4TTokenizerFast + + >>> tokenizer = SeamlessM4TTokenizerFast.from_pretrained( + ... "facebook/hf-seamless-m4t-medium", src_lang="eng", tgt_lang="fra" + ... ) + >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria" + >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie." + >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt") + ``` + + Args: + vocab_file (`str`, *optional*): + Path to the vocabulary file. + tokenizer_file (`str`, *optional*): + The path to a tokenizer file to use instead of the vocab file. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + src_lang (`str`, *optional*, defaults to `"eng"`): + The language to use as source language for translation. + tgt_lang (`str`, *optional*, defaults to `"fra"`): + The language to use as target language for translation. + additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*): + A tuple or a list of additional special tokens. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + slow_tokenizer_class = SeamlessM4TTokenizer + model_input_names = ["input_ids", "attention_mask"] + + prefix_tokens: List[int] = [] + suffix_tokens: List[int] = [] + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + src_lang="eng", + tgt_lang="fra", + additional_special_tokens=None, + **kwargs, + ): + super().__init__( + vocab_file=vocab_file, + tokenizer_file=tokenizer_file, + bos_token=bos_token, + eos_token=eos_token, + sep_token=sep_token, + cls_token=cls_token, + unk_token=unk_token, + pad_token=pad_token, + src_lang=src_lang, + tgt_lang=tgt_lang, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + self.vocab_file = vocab_file + self._src_lang = f"__{src_lang}__" if "__" not in src_lang else src_lang + self._tgt_lang = f"__{tgt_lang}__" if "__" not in tgt_lang else tgt_lang + self.set_src_lang_special_tokens(self._src_lang) + self.set_tgt_lang_special_tokens(self._tgt_lang) + + @property + def can_save_slow_tokenizer(self) -> bool: + return os.path.isfile(self.vocab_file) if self.vocab_file else False + + @property + # Copied from transformers.models.nllb.tokenization_nllb.NllbTokenizer.src_lang + def src_lang(self) -> str: + return self._src_lang + + @src_lang.setter + def src_lang(self, new_src_lang: str) -> None: + if "__" not in new_src_lang: + self._src_lang = f"__{new_src_lang}__" + else: + self._src_lang = new_src_lang + self.set_src_lang_special_tokens(self._src_lang) + + @property + def tgt_lang(self) -> str: + return self._tgt_lang + + @tgt_lang.setter + def tgt_lang(self, new_tgt_lang: str) -> None: + if "__" not in new_tgt_lang: + self._tgt_lang = f"__{new_tgt_lang}__" + else: + self._tgt_lang = new_tgt_lang + self.set_tgt_lang_special_tokens(self._tgt_lang) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. The special tokens depend on calling set_lang. + + An SeamlessM4T sequence has the following format, where `X` represents the sequence: + + - `input_ids` (for encoder) `[src_lang_code] X [eos]` + - `decoder_input_ids`: (for decoder) `[eos, tgt_lang_code] X [eos]` + + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. nllb does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) + if "__" not in tgt_lang: + tgt_lang = f"__{tgt_lang}__" + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + + # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast.prepare_seq2seq_batch with "fra_Latn"->"fra", "eng_Latn"->"eng" + def prepare_seq2seq_batch( + self, + src_texts: List[str], + src_lang: str = "eng", + tgt_texts: Optional[List[str]] = None, + tgt_lang: str = "fra", + **kwargs, + ) -> BatchEncoding: + self.src_lang = src_lang + self.tgt_lang = tgt_lang + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast._switch_to_input_mode + def _switch_to_input_mode(self): + return self.set_src_lang_special_tokens(self.src_lang) + + # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast._switch_to_target_mode + def _switch_to_target_mode(self): + return self.set_tgt_lang_special_tokens(self.tgt_lang) + + def set_src_lang_special_tokens(self, src_lang) -> None: + """Reset the special tokens to the source lang setting. + Prefix=[src_lang_code], suffix = [eos] + """ + self.cur_lang_code = self.convert_tokens_to_ids(src_lang) + + if self.cur_lang_code == self.unk_token_id: + logger.warning_once( + f"`tgt_lang={src_lang}` has not be found in the `vocabulary`. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id." + ) + + self.init_kwargs["src_lang"] = src_lang + + self.prefix_tokens = [self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] + + prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens) + suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens) + + self._tokenizer.post_processor = processors.TemplateProcessing( + single=prefix_tokens_str + ["$A"] + suffix_tokens_str, + pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str, + special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)), + ) + + def set_tgt_lang_special_tokens(self, lang: str) -> None: + """Reset the special tokens to the target lang setting. + Prefix=[eos, tgt_lang_code] and suffix=[eos]. + """ + self.cur_lang_code = self.convert_tokens_to_ids(lang) + + if self.cur_lang_code == self.unk_token_id: + logger.warning_once( + f"`tgt_lang={lang}` has not be found in the `vocabulary`. Behaviour will probably be unexpected because the language token id will be replaced by the unknown token id." + ) + + self.init_kwargs["tgt_lang"] = lang + + self.prefix_tokens = [self.eos_token_id, self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] + + prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens) + suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens) + + self._tokenizer.post_processor = processors.TemplateProcessing( + single=prefix_tokens_str + ["$A"] + suffix_tokens_str, + pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str, + special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)), + ) + + # Copied from transformers.models.nllb.tokenization_nllb_fast.NllbTokenizerFast.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." + ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory.") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + @classmethod + def _from_pretrained( + cls, + resolved_vocab_files, + pretrained_model_name_or_path, + init_configuration, + *init_inputs, + token=None, + cache_dir=None, + local_files_only=False, + _commit_hash=None, + _is_local=False, + **kwargs, + ): + tokenizer = super()._from_pretrained( + resolved_vocab_files, + pretrained_model_name_or_path, + init_configuration, + *init_inputs, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + _commit_hash=_commit_hash, + _is_local=_is_local, + **kwargs, + ) + + # ensure also set after from pretrained + tokenizer.set_src_lang_special_tokens(tokenizer._src_lang) + tokenizer.set_tgt_lang_special_tokens(tokenizer._tgt_lang) + + return tokenizer + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair_target: Optional[ + Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] + ] = None, + padding: Union[bool, str, PaddingStrategy] = True, + pad_to_multiple_of: Optional[int] = 2, + src_lang: Optional[str] = None, + tgt_lang: Optional[str] = None, + **kwargs, + ): + """ + Args: + text (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_pair (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_target (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a + list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), + you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_pair_target (`str`, `List[str]`, `List[List[str]]`, *optional*): + The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a + list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), + you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + src_lang (`str`, *optional*): + A string representing the source language. If not specified, the last `src_lang` specified (either + during initialization or when calling this tokenizer) will be used. + tgt_lang (`str`, *optional*): + A string representing the target language. If not specified, the last `tgt_lang` specified (either + during initialization or when calling this tokenizer) will be used. + kwargs (*optional*): + Remaining dictionary of keyword arguments that will be passed to [`PreTrainedTokenizerFast.__call__`]. + """ + if src_lang is not None: + self.src_lang = src_lang + if tgt_lang is not None: + self.tgt_lang = tgt_lang + + output = super().__call__( + text=text, + text_pair=text_pair, + text_target=text_target, + text_pair_target=text_pair_target, + padding=padding, + pad_to_multiple_of=pad_to_multiple_of, + **kwargs, + ) + + return output diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 2ab62792372d76..18a2abb11f135b 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6933,6 +6933,79 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +SEAMLESS_M4T_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class SeamlessM4TCodeHifiGan(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SeamlessM4TForSpeechToSpeech(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SeamlessM4TForSpeechToText(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SeamlessM4TForTextToSpeech(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SeamlessM4TForTextToText(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SeamlessM4THifiGan(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SeamlessM4TModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SeamlessM4TPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SeamlessM4TTextToUnitForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SeamlessM4TTextToUnitModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_sentencepiece_objects.py b/src/transformers/utils/dummy_sentencepiece_objects.py index 32bf223d57229b..658645746329c6 100644 --- a/src/transformers/utils/dummy_sentencepiece_objects.py +++ b/src/transformers/utils/dummy_sentencepiece_objects.py @@ -177,6 +177,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["sentencepiece"]) +class SeamlessM4TTokenizer(metaclass=DummyObject): + _backends = ["sentencepiece"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["sentencepiece"]) + + class Speech2TextTokenizer(metaclass=DummyObject): _backends = ["sentencepiece"] diff --git a/src/transformers/utils/dummy_tokenizers_objects.py b/src/transformers/utils/dummy_tokenizers_objects.py index 07234f74dd0a89..b8cc21303a815a 100644 --- a/src/transformers/utils/dummy_tokenizers_objects.py +++ b/src/transformers/utils/dummy_tokenizers_objects.py @@ -366,6 +366,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tokenizers"]) +class SeamlessM4TTokenizerFast(metaclass=DummyObject): + _backends = ["tokenizers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tokenizers"]) + + class SplinterTokenizerFast(metaclass=DummyObject): _backends = ["tokenizers"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 175861fd149e5e..86a3d5efd90b6d 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1588,7 +1588,7 @@ def test_assisted_decoding_sample(self): # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes if any( model_name in model_class.__name__.lower() - for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"] + for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"] ): return diff --git a/tests/models/seamless_m4t/__init__.py b/tests/models/seamless_m4t/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py b/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py new file mode 100644 index 00000000000000..2b55a61d812ad8 --- /dev/null +++ b/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py @@ -0,0 +1,223 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import os +import random +import tempfile +import unittest + +import numpy as np +from datasets import load_dataset + +from transformers import SeamlessM4TFeatureExtractor, is_speech_available +from transformers.testing_utils import check_json_file_has_correct_format, require_torch +from transformers.utils.import_utils import is_torch_available + +from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin + + +if is_torch_available(): + import torch + +global_rng = random.Random() + + +# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list +def floats_list(shape, scale=1.0, rng=None, name=None): + """Creates a random float32 tensor""" + if rng is None: + rng = global_rng + + values = [] + for batch_idx in range(shape[0]): + values.append([]) + for _ in range(shape[1]): + values[-1].append(rng.random() * scale) + + return values + + +@require_torch +class SeamlessM4TFeatureExtractionTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + min_seq_length=400, + max_seq_length=2000, + feature_size=10, + padding_value=0.0, + sampling_rate=4_000, + return_attention_mask=True, + do_normalize=True, + stride=2, + ): + self.parent = parent + self.batch_size = batch_size + self.min_seq_length = min_seq_length + self.max_seq_length = max_seq_length + self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1) + self.padding_value = padding_value + self.sampling_rate = sampling_rate + self.return_attention_mask = return_attention_mask + self.do_normalize = do_normalize + self.feature_size = feature_size + self.stride = stride + self.num_mel_bins = feature_size + + def prepare_feat_extract_dict(self): + return { + "feature_size": self.feature_size, + "num_mel_bins": self.num_mel_bins, + "padding_value": self.padding_value, + "sampling_rate": self.sampling_rate, + "stride": self.stride, + "return_attention_mask": self.return_attention_mask, + "do_normalize": self.do_normalize, + } + + # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTester.prepare_inputs_for_common + def prepare_inputs_for_common(self, equal_length=False, numpify=False): + def _flatten(list_of_lists): + return list(itertools.chain(*list_of_lists)) + + if equal_length: + speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)] + else: + # make sure that inputs increase in size + speech_inputs = [ + floats_list((x, self.feature_size)) + for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff) + ] + if numpify: + speech_inputs = [np.asarray(x) for x in speech_inputs] + return speech_inputs + + +@require_torch +class SeamlessM4TFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): + feature_extraction_class = SeamlessM4TFeatureExtractor if is_speech_available() else None + + def setUp(self): + self.feat_extract_tester = SeamlessM4TFeatureExtractionTester(self) + + def test_feat_extract_from_and_save_pretrained(self): + feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + saved_file = feat_extract_first.save_pretrained(tmpdirname)[0] + check_json_file_has_correct_format(saved_file) + feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname) + + dict_first = feat_extract_first.to_dict() + dict_second = feat_extract_second.to_dict() + self.assertDictEqual(dict_first, dict_second) + + def test_feat_extract_to_json_file(self): + feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + json_file_path = os.path.join(tmpdirname, "feat_extract.json") + feat_extract_first.to_json_file(json_file_path) + feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path) + + dict_first = feat_extract_first.to_dict() + dict_second = feat_extract_second.to_dict() + self.assertEqual(dict_first, dict_second) + + def test_call(self): + # Tests that all call wrap to encode_plus and batch_encode_plus + feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + # create three inputs of length 800, 1000, and 1200 + speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] + np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs] + + # Test feature size + input_features = feature_extractor(np_speech_inputs, padding=True, return_tensors="np").input_features + self.assertTrue(input_features.ndim == 3) + self.assertTrue(input_features.shape[0] == 3) + self.assertTrue(input_features.shape[-1] == feature_extractor.feature_size * feature_extractor.stride) + + # Test not batched input + encoded_sequences_1 = feature_extractor(speech_inputs[0], return_tensors="np").input_features + encoded_sequences_2 = feature_extractor(np_speech_inputs[0], return_tensors="np").input_features + self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3)) + + # Test batched + encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features + encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features + for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): + self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) + + # Test 2-D numpy arrays are batched. + speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)] + np_speech_inputs = np.asarray(speech_inputs) + encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features + encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features + for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): + self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) + + @require_torch + # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.test_double_precision_pad + def test_double_precision_pad(self): + import torch + + feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + np_speech_inputs = np.random.rand(100, 32).astype(np.float64) + py_speech_inputs = np_speech_inputs.tolist() + + for inputs in [py_speech_inputs, np_speech_inputs]: + np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np") + self.assertTrue(np_processed.input_features.dtype == np.float32) + pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt") + self.assertTrue(pt_processed.input_features.dtype == torch.float32) + + def _load_datasample(self, id): + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + # automatic decoding with librispeech + speech_sample = ds.sort("id")[id]["audio"]["array"] + + return torch.from_numpy(speech_sample).unsqueeze(0) + + def test_integration(self): + # fmt: off + EXPECTED_INPUT_FEATURES = torch.tensor( + [ + -1.5621, -1.4236, -1.3335, -1.3991, -1.2881, -1.1133, -0.9710, -0.8895, + -0.8280, -0.7376, -0.7194, -0.6896, -0.6849, -0.6788, -0.6545, -0.6610, + -0.6566, -0.5738, -0.5252, -0.5533, -0.5887, -0.6116, -0.5971, -0.4956, + -0.2881, -0.1512, 0.0299, 0.1762, 0.2728, 0.2236 + ] + ) + # fmt: on + + input_speech = self._load_datasample(10) + feature_extractor = SeamlessM4TFeatureExtractor() + input_features = feature_extractor(input_speech, return_tensors="pt").input_features + + feature_extractor(input_speech, return_tensors="pt").input_features[0, 5, :30] + self.assertEqual(input_features.shape, (1, 279, 160)) + self.assertTrue(torch.allclose(input_features[0, 5, :30], EXPECTED_INPUT_FEATURES, atol=1e-4)) + + def test_zero_mean_unit_variance_normalization_trunc_np_longest(self): + feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + audio = self._load_datasample(1) + audio = ((audio - audio.min()) / (audio.max() - audio.min())) * 65535 # Rescale to [0, 65535] to show issue + audio = feat_extract.zero_mean_unit_var_norm([audio], attention_mask=None)[0] + + self.assertTrue((audio.mean() < 1e-3).all()) + self.assertTrue(((audio.var() - 1).abs() < 1e-3).all()) diff --git a/tests/models/seamless_m4t/test_modeling_seamless_m4t.py b/tests/models/seamless_m4t/test_modeling_seamless_m4t.py new file mode 100644 index 00000000000000..d75629efc35115 --- /dev/null +++ b/tests/models/seamless_m4t/test_modeling_seamless_m4t.py @@ -0,0 +1,1110 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch SeamlessM4T model. """ + + +import copy +import inspect +import tempfile +import unittest + +from transformers import SeamlessM4TConfig, is_speech_available, is_torch_available +from transformers.testing_utils import require_torch, slow, torch_device +from transformers.trainer_utils import set_seed +from transformers.utils import cached_property + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + ModelTesterMixin, + _config_zero_init, + floats_tensor, + ids_tensor, + random_attention_mask, +) + + +if is_torch_available(): + import torch + + from transformers import ( + SeamlessM4TForSpeechToSpeech, + SeamlessM4TForSpeechToText, + SeamlessM4TForTextToSpeech, + SeamlessM4TForTextToText, + SeamlessM4TModel, + ) + from transformers.models.seamless_m4t.modeling_seamless_m4t import ( + SEAMLESS_M4T_PRETRAINED_MODEL_ARCHIVE_LIST, + ) + +if is_speech_available(): + from transformers import SeamlessM4TProcessor + + +class SeamlessM4TModelTester: + def __init__( + self, + parent, + input_modality="speech", + batch_size=2, + seq_length=4, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + initializer_range=0.02, + max_new_tokens=None, + num_labels=3, + num_choices=4, + scope=None, + vocab_size=20, + t2u_vocab_size=20, + hidden_size=6, + num_hidden_layers=2, + intermediate_size=6, + max_position_embeddings=256, + encoder_layers=2, + decoder_layers=2, + encoder_ffn_dim=6, + decoder_ffn_dim=6, + t2u_encoder_layers=2, + t2u_decoder_layers=2, + t2u_encoder_ffn_dim=6, + t2u_decoder_ffn_dim=6, + num_heads=2, + vocoder_num_spkrs=5, + vocoder_num_langs=5, + upsample_initial_channel=32, + unit_embed_dim=25, + spkr_embed_dim=6, + lang_embed_dim=6, + num_conv_pos_embeddings=8, + unit_hifi_gan_vocab_size=20, + t2u_num_langs=0, + t2u_max_new_tokens=25, + t2u_offset_tgt_lang=0, + vocoder_offset=0, + ): + self.parent = parent + self.input_modality = input_modality + + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + + self.vocab_size = vocab_size + self.t2u_vocab_size = t2u_vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.decoder_ffn_dim = decoder_ffn_dim + self.t2u_encoder_layers = t2u_encoder_layers + self.t2u_decoder_layers = t2u_decoder_layers + self.t2u_encoder_ffn_dim = t2u_encoder_ffn_dim + self.t2u_decoder_ffn_dim = t2u_decoder_ffn_dim + self.num_heads = num_heads + self.num_attention_heads = num_heads + + self.vocoder_num_spkrs = vocoder_num_spkrs + self.vocoder_num_langs = vocoder_num_langs + self.upsample_initial_channel = upsample_initial_channel + self.unit_embed_dim = unit_embed_dim + self.spkr_embed_dim = spkr_embed_dim + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.lang_embed_dim = lang_embed_dim + + self.max_new_tokens = max_new_tokens + + self.unit_hifi_gan_vocab_size = unit_hifi_gan_vocab_size + self.t2u_num_langs = t2u_num_langs + self.t2u_max_new_tokens = t2u_max_new_tokens + self.t2u_offset_tgt_lang = t2u_offset_tgt_lang + self.vocoder_offset = vocoder_offset + + def prepare_config_and_inputs(self): + if self.input_modality == "text": + inputs = ids_tensor([self.batch_size, self.seq_length], self.vocab_size - 1) + else: + inputs = ids_tensor([self.batch_size, self.seq_length, 160], self.vocab_size - 1).float() + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size - 1) + + lm_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + + config = self.get_config() + + return config, inputs, decoder_input_ids, input_mask, lm_labels + + def get_config(self): + return SeamlessM4TConfig( + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + initializer_range=self.initializer_range, + vocab_size=self.vocab_size, + t2u_vocab_size=self.t2u_vocab_size, + hidden_size=self.hidden_size, + speech_encoder_layers=self.num_heads, + speech_encoder_intermediate_size=self.intermediate_size, + max_position_embeddings=self.max_position_embeddings, + encoder_layers=self.encoder_layers, + decoder_layers=self.decoder_layers, + encoder_ffn_dim=self.encoder_ffn_dim, + decoder_ffn_dim=self.decoder_ffn_dim, + t2u_encoder_layers=self.t2u_encoder_layers, + t2u_decoder_layers=self.t2u_decoder_layers, + t2u_encoder_ffn_dim=self.t2u_encoder_ffn_dim, + t2u_decoder_ffn_dim=self.t2u_decoder_ffn_dim, + num_attention_heads=self.num_heads, + encoder_attention_heads=self.num_heads, + decoder_attention_heads=self.num_heads, + t2u_encoder_attention_heads=self.num_heads, + t2u_decoder_attention_heads=self.num_heads, + speech_encoder_attention_heads=self.num_heads, + unit_hifigan_vocab_vise=self.t2u_vocab_size, + vocoder_num_spkrs=self.vocoder_num_spkrs, + vocoder_num_langs=self.vocoder_num_langs, + upsample_initial_channel=self.upsample_initial_channel, + unit_embed_dim=self.unit_embed_dim, + spkr_embed_dim=self.spkr_embed_dim, + num_conv_pos_embeddings=self.num_conv_pos_embeddings, + lang_embed_dim=self.lang_embed_dim, + max_new_tokens=self.max_new_tokens, + unit_hifi_gan_vocab_size=self.unit_hifi_gan_vocab_size, + t2u_num_langs=self.t2u_num_langs, + t2u_max_new_tokens=self.t2u_max_new_tokens, + t2u_offset_tgt_lang=self.t2u_offset_tgt_lang, + vocoder_offset=self.vocoder_offset, + ) + + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + decoder_input_ids, + input_mask, + lm_labels, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + decoder_input_ids, + input_mask, + lm_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def create_and_check_model(self, config, input_ids, decoder_input_ids, input_mask, labels): + model = SeamlessM4TModel(config=config) + model.to(torch_device) + model.eval() + if self.input_modality == "text": + result = model(input_ids=input_ids, attention_mask=input_mask, decoder_input_ids=decoder_input_ids) + result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + else: + result = model(input_features=input_ids, attention_mask=input_mask, decoder_input_ids=decoder_input_ids) + result = model(input_features=input_ids, decoder_input_ids=decoder_input_ids) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + decoder_output = result.logits + decoder_past = result.past_key_values + encoder_output = result.encoder_last_hidden_state + + if self.input_modality == "text": + seq_length = self.seq_length + else: + # if speech, expected length has been subsampled. + seq_length = model._compute_sub_sample_lengths_from_attention_mask(input_mask).max().item() + + self.parent.assertEqual(encoder_output.size(), (self.batch_size, seq_length, self.hidden_size)) + self.parent.assertEqual(decoder_output.size(), (self.batch_size, decoder_input_ids.shape[1], self.vocab_size)) + # There should be `num_layers` key value embeddings stored in decoder_past + self.parent.assertEqual(len(decoder_past), config.decoder_layers) + # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple + self.parent.assertEqual(len(decoder_past[0]), 4) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + decoder_input_ids, + input_mask, + lm_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.is_decoder = True + model = SeamlessM4TModel(config=config) + model.to(torch_device) + model.eval() + + # make sure no pad token in decoder_input_ids + decoder_input_ids = torch.clamp(decoder_input_ids, config.pad_token_id + 1) + + # first forward pass + outputs = model( + input_ids, decoder_input_ids=decoder_input_ids, decoder_attention_mask=input_mask, use_cache=True + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + input_ids, + decoder_input_ids=next_input_ids, + decoder_attention_mask=next_attention_mask, + output_hidden_states=True, + ) + output_from_no_past = output_from_no_past["decoder_hidden_states"][0] + output_from_past = model( + input_ids, + decoder_input_ids=next_tokens, + decoder_attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["decoder_hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + decoder_input_ids, + input_mask, + lm_labels, + ) = config_and_inputs + + input_name = "input_ids" if self.input_modality == "text" else "input_features" + + inputs_dict = { + input_name: input_ids, + "attention_mask": input_mask, + "decoder_input_ids": decoder_input_ids, + "labels": lm_labels, + } + return config, inputs_dict + + +@require_torch +class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase): + is_encoder_decoder = True + fx_compatible = False + test_missing_keys = False + test_pruning = False + test_model_parallel = False + test_resize_embeddings = False + test_headmasking = False + test_torchscript = False + + all_model_classes = ( + ( + SeamlessM4TModel, + SeamlessM4TForSpeechToSpeech, + SeamlessM4TForSpeechToText, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = (SeamlessM4TForSpeechToText,) if is_torch_available() else () + + input_name = "input_features" + + def setUp(self): + self.model_tester = SeamlessM4TModelTester(self, input_modality="speech") + self.config_tester = ConfigTester(self, config_class=SeamlessM4TConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in SEAMLESS_M4T_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = SeamlessM4TModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + def _get_input_ids_and_config(self, batch_size=2): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + input_ids = inputs_dict[self.input_name] + + # cut to half length & take max batch_size 3 + sequence_length = input_ids.shape[-1] // 2 + input_ids = input_ids[:batch_size, :sequence_length] + + # generate max 3 tokens + max_length = input_ids.shape[-1] + 3 + if config.eos_token_id is not None and config.pad_token_id is None: + # hack to allow generate for models such as GPT2 as is done in `generate()` + if isinstance(config.eos_token_id, int): + config.eos_token_id = [config.eos_token_id] + config.pad_token_id = config.eos_token_id[0] + + attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.long)[:batch_size, :sequence_length] + + return config, input_ids.float(), attention_mask, max_length + + @staticmethod + def _get_encoder_outputs( + model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1 + ): + encoder = model.get_encoder() + encoder_outputs = encoder( + input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave( + num_interleave, dim=0 + ) + input_ids = ( + torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device) + + model._get_decoder_start_token_id() + ) + attention_mask = None + return encoder_outputs, input_ids, attention_mask + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + uniform_init_parms = [ + "conv.weight", + "masked_spec_embed", + "codevectors", + "quantizer.weight_proj.weight", + "project_hid.weight", + "project_hid.bias", + "project_q.weight", + "project_q.bias", + "pos_bias_v", + "pos_bias_u", + "pointwise_conv1", + "pointwise_conv2", + "feature_projection.projection.weight", + "feature_projection.projection.bias", + "objective.weight", + "adapter", + ] + if param.requires_grad: + if any(x in name for x in uniform_init_parms): + self.assertTrue( + -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + else: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + @unittest.skip(reason="SeamlessM4TSpeechEncoder doesn't have an embedding layer") + def test_inputs_embeds(self): + pass + + @unittest.skip( + reason="Expected missing keys serve when using SeamlessM4TForXXX.from_pretrained from a checkpoint saved by SeamlessM4TModel.save_pretrained." + ) + def test_model_weights_reload_no_missing_tied_weights(self): + pass + + @unittest.skip( + reason="SeamlessM4TModel is base class but has actually a bigger architecture than seamlessM4T task-specific models." + ) + def test_save_load_fast_init_to_base(self): + pass + + @unittest.skip(reason="The speech encoder doesn't support head masking") + def test_generate_with_head_masking(self): + pass + + @unittest.skip(reason="SeamlessM4TModel can takes input_ids or input_features") + def test_forward_signature(self): + pass + + @unittest.skip(reason="SeamlessM4T has no base model") + def test_save_load_fast_init_from_base(self): + pass + + def test_attention_outputs(self): + # expected length is subsampled so need to change a bit this test + if not self.has_attentions: + self.skipTest(reason="Model does not output attentions") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + # no more chunk_length test + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) + + if self.is_encoder_decoder: + correct_outlen = 5 + + # loss is at first position + if "labels" in inputs_dict: + correct_outlen += 1 # loss is added to beginning + if "past_key_values" in outputs: + correct_outlen += 1 # past_key_values have been returned + + self.assertEqual(out_len, correct_outlen) + + # decoder attentions + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length], + ) + + # cross attentions + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) + + sub_sampled_length = ( + model._compute_sub_sample_lengths_from_attention_mask(inputs_dict["attention_mask"]).max().item() + ) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + decoder_seq_length, + sub_sampled_length, + ], + ) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + +@require_torch +class SeamlessM4TModelWithTextInputTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + is_encoder_decoder = True + fx_compatible = False + test_missing_keys = False + test_pruning = False + test_model_parallel = False + test_resize_embeddings = True + test_headmasking = False + test_torchscript = False + + all_model_classes = ( + ( + SeamlessM4TModel, + SeamlessM4TForTextToSpeech, + SeamlessM4TForTextToText, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = (SeamlessM4TForTextToText,) if is_torch_available() else () + + def setUp(self): + self.model_tester = SeamlessM4TModelTester(self, input_modality="text") + self.config_tester = ConfigTester(self, config_class=SeamlessM4TConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in SEAMLESS_M4T_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = SeamlessM4TModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + uniform_init_parms = [ + "conv.weight", + "masked_spec_embed", + "codevectors", + "quantizer.weight_proj.weight", + "project_hid.weight", + "project_hid.bias", + "project_q.weight", + "project_q.bias", + "pos_bias_v", + "pos_bias_u", + "pointwise_conv1", + "pointwise_conv2", + "feature_projection.projection.weight", + "feature_projection.projection.bias", + "objective.weight", + "adapter", + ] + if param.requires_grad: + if any(x in name for x in uniform_init_parms): + self.assertTrue( + -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + else: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + @unittest.skip( + reason="Expected missing keys serve when using SeamlessM4TForXXX.from_pretrained from a checkpoint saved by SeamlessM4TModel.save_pretrained." + ) + def test_model_weights_reload_no_missing_tied_weights(self): + pass + + def test_generate_with_head_masking(self): + """Test designed for encoder-decoder models to ensure the attention head masking is used.""" + attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + model = model_class(config).to(torch_device).eval() + + head_masking = { + "head_mask": torch.zeros(config.encoder_layers, config.encoder_attention_heads, device=torch_device), + "decoder_head_mask": torch.zeros( + config.decoder_layers, config.decoder_attention_heads, device=torch_device + ), + "cross_attn_head_mask": torch.zeros( + config.decoder_layers, config.decoder_attention_heads, device=torch_device + ), + } + + signature = inspect.signature(model.forward) + # We want to test only models where encoder/decoder head masking is implemented + if not set(head_masking.keys()) < {*signature.parameters.keys()}: + continue + + for attn_name, (name, mask) in zip(attention_names, head_masking.items()): + out = model.generate( + input_ids, + attention_mask=attention_mask, + num_beams=1, + output_attentions=True, + return_dict_in_generate=True, + remove_invalid_values=True, + **{name: mask}, + ) + # We check the state of decoder_attentions and cross_attentions just from the last step + attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1] + self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0) + + @unittest.skip(reason="SeamlessM4TModel can take input_ids or input_features") + def test_forward_signature(self): + pass + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + @unittest.skip( + reason="SeamlessM4TModel is base class but has actually a bigger architecture than seamlessM4T task-specific models." + ) + def test_save_load_fast_init_to_base(self): + pass + + @unittest.skip(reason="SeamlessM4T has no base model") + def test_save_load_fast_init_from_base(self): + pass + + +@require_torch +class SeamlessM4TGenerationTest(unittest.TestCase): + # test that non-standard generation works + # test generation of: SeamlessM4TModel, SeamlessM4TForSpeechToSpeech, SeamlessM4TForSpeechToText, SeamlessM4TForTextToSpeech + + def setUp(self): + self.speech_model_tester = SeamlessM4TModelTester(self, input_modality="speech") + self.text_model_tester = SeamlessM4TModelTester(self, input_modality="text") + self.tmpdirname = tempfile.mkdtemp() + + def update_generation(self, model): + lang_code_to_id = { + "fra": 4, + "eng": 4, + } + + generation_config = copy.deepcopy(model.generation_config) + + generation_config.__setattr__("text_decoder_lang_to_code_id", lang_code_to_id) + generation_config.__setattr__("t2u_lang_code_to_id", lang_code_to_id) + generation_config.__setattr__("vocoder_lang_code_to_id", lang_code_to_id) + + generation_config._from_model_config = False + + model.generation_config = generation_config + + def prepare_text_input(self): + config, inputs, decoder_input_ids, input_mask, lm_labels = self.text_model_tester.prepare_config_and_inputs() + + input_dict = { + "input_ids": inputs, + "attention_mask": input_mask, + "tgt_lang": "eng", + "num_beams": 2, + "do_sample": True, + } + + return config, input_dict + + def prepare_speech_input(self): + config, inputs, decoder_input_ids, input_mask, lm_labels = self.speech_model_tester.prepare_config_and_inputs() + + input_dict = { + "input_features": inputs, + "attention_mask": input_mask, + "tgt_lang": "fra", + "num_beams": 2, + "do_sample": True, + } + + return config, input_dict + + def prepare_speech_and_text_input(self): + config, inputs, decoder_input_ids, input_mask, lm_labels = self.speech_model_tester.prepare_config_and_inputs() + + input_speech = { + "input_features": inputs, + "attention_mask": input_mask, + "tgt_lang": "fra", + "num_beams": 2, + "do_sample": True, + } + + config, inputs, decoder_input_ids, input_mask, lm_labels = self.text_model_tester.prepare_config_and_inputs() + + input_text = { + "input_ids": inputs, + "attention_mask": input_mask, + "tgt_lang": "eng", + "num_beams": 2, + "do_sample": True, + } + return config, input_speech, input_text + + def factory_generation_speech_test(self, model, inputs): + set_seed(0) + output = model.generate(**inputs) + return output + + def test_speech_generation(self): + config, input_speech, input_text = self.prepare_speech_and_text_input() + + model = SeamlessM4TModel(config=config) + self.update_generation(model) + model.save_pretrained(self.tmpdirname) + model.to(torch_device) + model.eval() + + output_original_text = self.factory_generation_speech_test(model, input_text) + output_original_speech = self.factory_generation_speech_test(model, input_speech) + + model = SeamlessM4TForTextToSpeech.from_pretrained(self.tmpdirname) + self.update_generation(model) + model.to(torch_device) + model.eval() + + output_text = self.factory_generation_speech_test(model, input_text) + + model = SeamlessM4TForSpeechToSpeech.from_pretrained(self.tmpdirname) + self.update_generation(model) + model.to(torch_device) + model.eval() + + output_speech = self.factory_generation_speech_test(model, input_speech) + + # test same text output from input text + self.assertListEqual(output_original_text[0].ravel().tolist(), output_text[0].ravel().tolist()) + self.assertListEqual(output_original_text[1].ravel().tolist(), output_text[1].ravel().tolist()) + + # test same speech output from input text + self.assertListEqual(output_original_speech[0].ravel().tolist(), output_speech[0].ravel().tolist()) + self.assertListEqual(output_original_speech[1].ravel().tolist(), output_speech[1].ravel().tolist()) + + def test_text_generation(self): + config, input_speech, input_text = self.prepare_speech_and_text_input() + + # to return speech + input_speech["generate_speech"] = False + input_text["generate_speech"] = False + + model = SeamlessM4TModel(config=config) + self.update_generation(model) + model.save_pretrained(self.tmpdirname) + model.to(torch_device) + model.eval() + + output_original_text = self.factory_generation_speech_test(model, input_text) + output_original_speech = self.factory_generation_speech_test(model, input_speech) + + # other models don't need it + input_speech.pop("generate_speech") + input_text.pop("generate_speech") + + model = SeamlessM4TForTextToText.from_pretrained(self.tmpdirname) + self.update_generation(model) + model.to(torch_device) + model.eval() + + output_text = self.factory_generation_speech_test(model, input_text) + + model = SeamlessM4TForSpeechToText.from_pretrained(self.tmpdirname) + self.update_generation(model) + model.to(torch_device) + model.eval() + + output_speech = self.factory_generation_speech_test(model, input_speech) + + # test same text output from input text + self.assertListEqual(output_original_text[0].ravel().tolist(), output_text.ravel().tolist()) + + # test same speech output from input text + self.assertListEqual(output_original_speech[0].ravel().tolist(), output_speech.ravel().tolist()) + + def test_generation(self): + config, input_speech, input_text = self.prepare_speech_and_text_input() + + input_speech["num_beams"] = 3 + input_speech["do_sample"] = True + input_speech["num_return_sequences"] = 3 + + input_text["num_beams"] = 3 + input_text["do_sample"] = True + input_text["num_return_sequences"] = 3 + + for model_class in [SeamlessM4TForSpeechToSpeech, SeamlessM4TForSpeechToText, SeamlessM4TModel]: + model = model_class(config=config) + self.update_generation(model) + model.to(torch_device) + model.eval() + + output = model.generate(**input_speech) + output = output[0] if isinstance(output, tuple) else output + + self.assertEqual(output.shape[0], 3 * input_speech["input_features"].shape[0]) + + for model_class in [SeamlessM4TForTextToSpeech, SeamlessM4TForTextToText, SeamlessM4TModel]: + model = model_class(config=config) + self.update_generation(model) + model.to(torch_device) + model.eval() + + output = model.generate(**input_text) + + output = output[0] if isinstance(output, tuple) else output + + self.assertEqual(output.shape[0], 3 * input_text["input_ids"].shape[0]) + + +@require_torch +class SeamlessM4TModelIntegrationTest(unittest.TestCase): + repo_id = "facebook/hf-seamless-m4t-medium" + + def assertListAlmostEqual(self, list1, list2, tol=1e-3): + self.assertEqual(len(list1), len(list2)) + for a, b in zip(list1, list2): + self.assertAlmostEqual(a, b, delta=tol) + + @cached_property + def processor(self): + return SeamlessM4TProcessor.from_pretrained(self.repo_id) + + @cached_property + def input_text(self): + # corresponds to "C'est un test." with seamlessM4T_medium checkpoint + + # fmt: off + input_ids = torch.tensor([[256057, 152, 248116, 354, 159, 7356, 248075, 3]]) + # fmt: on + + input_ids = input_ids.to(torch_device) + + attention_mask = torch.ones_like(input_ids).to(torch_device) + + inputs = { + "attention_mask": attention_mask, + "input_ids": input_ids, + } + + return inputs + + @cached_property + def input_audio(self): + set_seed(0) + seq_len = 20000 + sampling_rate = 16000 + input_features = torch.rand((2, seq_len)) + + return self.processor(audios=[input_features.tolist()], sampling_rate=sampling_rate, return_tensors="pt").to( + torch_device + ) + + def factory_test_task(self, class1, class2, inputs, class1_kwargs, class2_kwargs): + model1 = class1.from_pretrained(self.repo_id).to(torch_device) + model2 = class2.from_pretrained(self.repo_id).to(torch_device) + + set_seed(0) + output_1 = model1.generate(**inputs, **class1_kwargs) + set_seed(0) + output_2 = model2.generate(**inputs, **class2_kwargs) + + for key in output_1: + if isinstance(output_1[key], torch.Tensor): + if len(output_1[key].shape) == 0: + self.assertEqual(output_1[key].item(), output_2[key].item()) + else: + self.assertListAlmostEqual(output_1[key].squeeze().tolist(), output_2[key].squeeze().tolist()) + + @slow + def test_to_eng_text(self): + model = SeamlessM4TModel.from_pretrained(self.repo_id).to(torch_device) + + # test text - tgt lang: eng + + # fmt: off + expected_text_tokens = [3, 256047, 3291, 248116, 248066, 9, 7356, 248075, 3] + # fmt: on + + # fmt: off + expected_unit_tokens = [ + 2,10051,8980,8212,949,1270,4311,1123,5918,2333,5311,3882,2415,5284,1123,612,8816,6370,5386,7334,4345,5645, + 9437,5748,1378,9818,4319,7968,7375,2909,9119,5151,8728,5335,3896,4013,8939,8885,6048,9530,3167,5833,1072,693, + 431,9867,364,7909,4608,5938,1889,9984,7947,4944,6171,3767,9861,9169,1187,8365,4571,7635,7784,7635,800,2393, + 32,5380,5852,8289,2530,2762,1833,2056,3553,4641,3553,5683,370,2288,1344,1518,7534,703,8359,7699,2 + ] + # fmt: on + + # fmt: off + expected_wav_slice = [-3e-05, -0.0004, -0.00037, -0.00013, -6e-05, 0.00012, -0.00016, 0.00025, 7e-05, -3e-05] + # fmt: on + + set_seed(0) + output = model.generate(**self.input_text, num_beams=1, tgt_lang="eng", return_intermediate_token_ids=True) + + self.assertListEqual(expected_text_tokens, output.sequences.squeeze().tolist()) + # FOR NOW, only first units correspondance + self.assertListEqual(expected_unit_tokens[:10], output.unit_sequences.squeeze().tolist()[:10]) + + self.assertListAlmostEqual(expected_wav_slice, output.waveform.squeeze().tolist()[50:60]) + + @slow + def test_to_swh_text(self): + model = SeamlessM4TModel.from_pretrained(self.repo_id).to(torch_device) + + # test text - tgt lang: swh + + # fmt: off + expected_text_tokens = [3, 256168, 1665, 188589, 7040, 248075, 3] + # fmt: on + + # fmt: off + expected_unit_tokens = [ + 2,10071,5729,9995,3089,7546,1204,1721,2532,4340,5623,3496,432,7730,9096,7677,3143,8211,6447,8399,4248,3565, + 4529,7700,9308,217,6476,3485,9667,3194,8476,4923,5593,1148,4466,7416,4872,463,4872,253,2348,4640,3450,2133, + 6318,2806,817,7613,2698,6563,8712,8344,9286,6878,6387,4281,6387,640,6387,3200,640,8355,640,6708,979,1738,2 + ] + # fmt: on + + # fmt: off + expected_wav_slice = [1e-05, -7e-05, -4e-05, -4e-05, -6e-05, -9e-05, -0.0001, -2e-05, -7e-05, -2e-05] + # fmt: on + + set_seed(0) + output = model.generate(**self.input_text, num_beams=1, tgt_lang="swh", return_intermediate_token_ids=True) + + self.assertListEqual(expected_text_tokens, output.sequences.squeeze().tolist()) + self.assertListEqual(expected_unit_tokens[:10], output.unit_sequences.squeeze().tolist()[:10]) + + self.assertListAlmostEqual(expected_wav_slice, output.waveform.squeeze().tolist()[50:60]) + + @slow + def test_to_rus_speech(self): + model = SeamlessM4TModel.from_pretrained(self.repo_id).to(torch_device) + + # test audio - tgt lang: rus + + # fmt: off + expected_text_tokens = [3, 256147, 1197, 73565, 3413, 537, 233331, 248075, 3] + # fmt: on + + # fmt: off + expected_unit_tokens = [ + 2, 10067, 5729, 4798, 9631, 8378, 4446, 2393, 6901, 5983, 2817, 4629, 8532, 1991, 2931, 8576, 8857, 5936, 4317, + 9000, 7740, 7995, 1225, 5980, 6094, 1420, 5373, 8771, 6600, 4487, 7029, 3630, 6740, 4870, 1483, 3003, 5585, 5511, + 7465, 3222, 32, 6272, 1950, 3120, 5368, 639, 3713, 5935, 7943, 567, 6129, 6822, 1226, 5063, 9878, 7756, 8825, 1078, 5943, + 457, 9282, 9668, 817, 7613, 2698, 6563, 8712, 8704, 9286, 8704, 6387, 4281, 6387, 640, 3200, 6387, 640, 8355, 6708, 979, 1738, 2 + ] + # fmt: on + + # fmt: off + expected_wav_slice = [0.00013, 0.00012, 0.00014, 3e-05, 0.0, -6e-05, -0.00018, -0.00016, -0.00021, -0.00018] + # fmt: on + + set_seed(0) + output = model.generate(**self.input_audio, num_beams=1, tgt_lang="rus", return_intermediate_token_ids=True) + + self.assertListEqual(expected_text_tokens, output.sequences.squeeze().tolist()) + self.assertListEqual(expected_unit_tokens[:10], output.unit_sequences.squeeze().tolist()[:10]) + + self.assertListAlmostEqual(expected_wav_slice, output.waveform.squeeze().tolist()[50:60]) + + @slow + def test_text_to_text_model(self): + kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True, "generate_speech": False} + kwargs2 = { + "tgt_lang": "eng", + "output_hidden_states": True, + "return_dict_in_generate": True, + "output_scores": True, + } + self.factory_test_task(SeamlessM4TModel, SeamlessM4TForTextToText, self.input_text, kwargs1, kwargs2) + + @slow + def test_speech_to_text_model(self): + kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True, "generate_speech": False} + kwargs2 = { + "tgt_lang": "eng", + "output_hidden_states": True, + "return_dict_in_generate": True, + "output_scores": True, + } + self.factory_test_task(SeamlessM4TModel, SeamlessM4TForSpeechToText, self.input_audio, kwargs1, kwargs2) + + @slow + def test_speech_to_speech_model(self): + kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True} + self.factory_test_task(SeamlessM4TModel, SeamlessM4TForSpeechToSpeech, self.input_audio, kwargs1, kwargs1) + + @slow + def test_text_to_speech_model(self): + kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True} + + self.factory_test_task(SeamlessM4TModel, SeamlessM4TForTextToSpeech, self.input_text, kwargs1, kwargs1) diff --git a/tests/models/seamless_m4t/test_processor_seamless_m4t.py b/tests/models/seamless_m4t/test_processor_seamless_m4t.py new file mode 100644 index 00000000000000..7beefb16bda7ea --- /dev/null +++ b/tests/models/seamless_m4t/test_processor_seamless_m4t.py @@ -0,0 +1,126 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import tempfile +import unittest + +from transformers import SeamlessM4TFeatureExtractor, SeamlessM4TProcessor +from transformers.models.seamless_m4t import ( + SeamlessM4TTokenizer, + SeamlessM4TTokenizerFast, +) +from transformers.testing_utils import require_torch + +from .test_feature_extraction_seamless_m4t import floats_list + + +@require_torch +class SeamlessM4TProcessorTest(unittest.TestCase): + def setUp(self): + self.checkpoint = "facebook/hf-seamless-m4t-medium" + self.tmpdirname = tempfile.mkdtemp() + + def get_tokenizer(self, **kwargs): + return SeamlessM4TTokenizer.from_pretrained(self.checkpoint, **kwargs) + + def get_feature_extractor(self, **kwargs): + return SeamlessM4TFeatureExtractor.from_pretrained(self.checkpoint, **kwargs) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def test_save_load_pretrained_default(self): + tokenizer = self.get_tokenizer() + feature_extractor = self.get_feature_extractor() + + processor = SeamlessM4TProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + processor.save_pretrained(self.tmpdirname) + processor = SeamlessM4TProcessor.from_pretrained(self.tmpdirname) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab()) + tokenizer_instance = isinstance(processor.tokenizer, SeamlessM4TTokenizerFast) or isinstance( + processor.tokenizer, SeamlessM4TTokenizer + ) + self.assertTrue(tokenizer_instance) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string()) + self.assertIsInstance(processor.feature_extractor, SeamlessM4TFeatureExtractor) + + def test_save_load_pretrained_additional_features(self): + processor = SeamlessM4TProcessor( + tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor() + ) + processor.save_pretrained(self.tmpdirname) + tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)") + feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0) + + processor = SeamlessM4TProcessor.from_pretrained( + self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0 + ) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.feature_extractor, SeamlessM4TFeatureExtractor) + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) + + tokenizer_instance = isinstance(processor.tokenizer, SeamlessM4TTokenizerFast) or isinstance( + processor.tokenizer, SeamlessM4TTokenizer + ) + self.assertTrue(tokenizer_instance) + + # Copied from test.models.whisper.test_processor_whisper.WhisperProcessorTest.test_feature_extractor with Whisper->SeamlessM4T + def test_feature_extractor(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = SeamlessM4TProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + raw_speech = floats_list((3, 1000)) + + input_feat_extract = feature_extractor(raw_speech, return_tensors="np") + input_processor = processor(audios=raw_speech, return_tensors="np") + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + # Copied from test.models.whisper.test_processor_whisper.WhisperProcessorTest.test_tokenizer with Whisper->SeamlessM4T + def test_tokenizer(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = SeamlessM4TProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "This is a test string" + + encoded_processor = processor(text=input_str) + + encoded_tok = tokenizer(input_str) + + for key in encoded_tok.keys(): + self.assertListEqual(encoded_tok[key], encoded_processor[key]) + + # Copied from test.models.whisper.test_processor_whisper.WhisperProcessorTest.test_tokenizer_decode with Whisper->SeamlessM4T + def test_tokenizer_decode(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = SeamlessM4TProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]] + + decoded_processor = processor.batch_decode(predicted_ids) + decoded_tok = tokenizer.batch_decode(predicted_ids) + + self.assertListEqual(decoded_tok, decoded_processor) diff --git a/tests/models/seamless_m4t/test_tokenization_seamless_m4t.py b/tests/models/seamless_m4t/test_tokenization_seamless_m4t.py new file mode 100644 index 00000000000000..2cd9e8c56b52e6 --- /dev/null +++ b/tests/models/seamless_m4t/test_tokenization_seamless_m4t.py @@ -0,0 +1,672 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +from transformers import ( + SPIECE_UNDERLINE, + AddedToken, + BatchEncoding, + PreTrainedTokenizerFast, + SeamlessM4TTokenizer, + SeamlessM4TTokenizerFast, + is_torch_available, +) +from transformers.testing_utils import ( + get_tests_dir, + nested_simplify, + require_sentencepiece, + require_tokenizers, + require_torch, +) + +from ...test_tokenization_common import TokenizerTesterMixin + + +SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") + + +if is_torch_available(): + from transformers.models.m2m_100.modeling_m2m_100 import shift_tokens_right + +EN_CODE = 256047 +RO_CODE = 256145 + +SMALL_TRAINING_CORPUS = [ + ["This is the first sentence.", "This is the second one."], + ["This sentence (contains #) over symbols and numbers 12 3.", "But not this one."], +] + + +@require_sentencepiece +@require_tokenizers +class SeamlessM4TTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + tokenizer_class = SeamlessM4TTokenizer + rust_tokenizer_class = SeamlessM4TTokenizerFast + test_rust_tokenizer = True + test_sentencepiece = True + from_pretrained_kwargs = {} + + def setUp(self): + super().setUp() + + # We have a SentencePiece fixture for testing + tokenizer = SeamlessM4TTokenizer(SAMPLE_VOCAB, keep_accents=True) + tokenizer.save_pretrained(self.tmpdirname) + + def test_full_tokenizer(self): + tokenizer = SeamlessM4TTokenizer(SAMPLE_VOCAB, keep_accents=True) + + tokens = tokenizer.tokenize("This is a test") + self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"]) + + self.assertListEqual( + tokenizer.convert_tokens_to_ids(tokens), + [value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]], + ) + + tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.") + self.assertListEqual( + tokens, + [ + SPIECE_UNDERLINE + "I", + SPIECE_UNDERLINE + "was", + SPIECE_UNDERLINE + "b", + "or", + "n", + SPIECE_UNDERLINE + "in", + SPIECE_UNDERLINE + "", + "9", + "2", + "0", + "0", + "0", + ",", + SPIECE_UNDERLINE + "and", + SPIECE_UNDERLINE + "this", + SPIECE_UNDERLINE + "is", + SPIECE_UNDERLINE + "f", + "al", + "s", + "é", + ".", + ], + ) + ids = tokenizer.convert_tokens_to_ids(tokens) + self.assertListEqual( + ids, + [ + value + tokenizer.fairseq_offset + for value in [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4] + ], + ) + + back_tokens = tokenizer.convert_ids_to_tokens(ids) + self.assertListEqual( + back_tokens, + [ + SPIECE_UNDERLINE + "I", + SPIECE_UNDERLINE + "was", + SPIECE_UNDERLINE + "b", + "or", + "n", + SPIECE_UNDERLINE + "in", + SPIECE_UNDERLINE + "", + "", + "2", + "0", + "0", + "0", + ",", + SPIECE_UNDERLINE + "and", + SPIECE_UNDERLINE + "this", + SPIECE_UNDERLINE + "is", + SPIECE_UNDERLINE + "f", + "al", + "s", + "", + ".", + ], + ) + + def test_maximum_encoding_length_single_input(self): + tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + seq_0, ids = self.get_clean_sequence(tokenizer, max_length=20) + + sequence = tokenizer.encode(seq_0, add_special_tokens=False) + total_length = len(sequence) + + self.assertGreater( + total_length, 4, "Issue with the testing sequence, please update it, it's too short" + ) + + # Test with max model input length + model_max_length = tokenizer.model_max_length + self.assertEqual(model_max_length, 100) + seq_1 = seq_0 * model_max_length + + sequence1 = tokenizer(seq_1, add_special_tokens=False) + total_length1 = len(sequence1["input_ids"]) + self.assertGreater( + total_length1, + model_max_length, + "Issue with the testing sequence, please update it, it's too short", + ) + + # Simple + padding_strategies = ( + [False, True, "longest"] if tokenizer.pad_token and tokenizer.pad_token_id >= 0 else [False] + ) + for padding_state in padding_strategies: + with self.subTest(f"Padding: {padding_state}"): + for truncation_state in [True, "longest_first", "only_first"]: + with self.subTest(f"Truncation: {truncation_state}"): + output = tokenizer(seq_1, padding=padding_state, truncation=truncation_state) + self.assertEqual(len(output["input_ids"]), model_max_length) + + output = tokenizer([seq_1], padding=padding_state, truncation=truncation_state) + self.assertEqual(len(output["input_ids"][0]), model_max_length) + + # Simple with no truncation + # Reset warnings + tokenizer.deprecation_warnings = {} + with self.assertLogs("transformers", level="WARNING") as cm: + output = tokenizer(seq_1, padding=padding_state, truncation=False) + self.assertNotEqual(len(output["input_ids"]), model_max_length) + self.assertEqual(len(cm.records), 1) + self.assertTrue( + cm.records[0].message.startswith( + "Token indices sequence length is longer than the specified maximum sequence length" + " for this model" + ) + ) + + tokenizer.deprecation_warnings = {} + with self.assertLogs("transformers", level="WARNING") as cm: + output = tokenizer([seq_1], padding=padding_state, truncation=False) + self.assertNotEqual(len(output["input_ids"][0]), model_max_length) + self.assertEqual(len(cm.records), 1) + self.assertTrue( + cm.records[0].message.startswith( + "Token indices sequence length is longer than the specified maximum sequence length" + " for this model" + ) + ) + + # Overflowing tokens + stride = 2 + + # modify padding because it's activated by default in seamlessM4T + information = tokenizer( + seq_0, + max_length=total_length - 2, + add_special_tokens=False, + stride=stride, + truncation="longest_first", + return_overflowing_tokens=True, + padding=False, + # add_prefix_space=False, + ) + + # Overflowing tokens are handled quite differently in slow and fast tokenizers + if isinstance(tokenizer, PreTrainedTokenizerFast): + truncated_sequence = information["input_ids"][0] + overflowing_tokens = information["input_ids"][1] + self.assertEqual(len(information["input_ids"]), 2) + + self.assertEqual(len(truncated_sequence), total_length - 2) + self.assertEqual(truncated_sequence, sequence[:-2]) + + self.assertEqual(len(overflowing_tokens), 2 + stride) + self.assertEqual(overflowing_tokens, sequence[-(2 + stride) :]) + else: + truncated_sequence = information["input_ids"] + overflowing_tokens = information["overflowing_tokens"] + + self.assertEqual(len(truncated_sequence), total_length - 2) + self.assertEqual(truncated_sequence, sequence[:-2]) + + self.assertEqual(len(overflowing_tokens), 2 + stride) + self.assertEqual(overflowing_tokens, sequence[-(2 + stride) :]) + + @unittest.skip("By defaults, uses pad_to_multiple_of which breaks the test") + def test_maximum_encoding_length_pair_input(self): + pass + + def test_padding_to_multiple_of(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + if tokenizer.pad_token is None: + self.skipTest("No padding token.") + else: + empty_tokens = tokenizer("", padding=True, pad_to_multiple_of=8) + normal_tokens = tokenizer("This is a sample input", padding=True, pad_to_multiple_of=8) + for key, value in empty_tokens.items(): + self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8") + for key, value in normal_tokens.items(): + self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8") + + # default to padding=True so need to precise which padding is called + normal_tokens = tokenizer("This", pad_to_multiple_of=8, padding=False) + for key, value in normal_tokens.items(): + self.assertNotEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8") + + # Should also work with truncation + normal_tokens = tokenizer("This", padding=True, truncation=True, pad_to_multiple_of=8) + for key, value in normal_tokens.items(): + self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8") + + # truncation to something which is not a multiple of pad_to_multiple_of raises an error + self.assertRaises( + ValueError, + tokenizer.__call__, + "This", + padding=True, + truncation=True, + max_length=12, + pad_to_multiple_of=8, + ) + + @require_torch + def test_prepare_seq2seq_batch(self): + if not self.test_seq2seq: + return + + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + # Longer text that will definitely require truncation. + src_text = [ + " UN Chief Says There Is No Military Solution in Syria", + " Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for" + " Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons" + " will only worsen the violence and misery for millions of people.", + ] + tgt_text = [ + "Şeful ONU declară că nu există o soluţie militară în Siria", + "Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al" + ' Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi' + " că noi arme nu vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.", + ] + try: + batch = tokenizer.prepare_seq2seq_batch( + src_texts=src_text, + tgt_texts=tgt_text, + max_length=3, + max_target_length=10, + return_tensors="pt", + src_lang="eng", + tgt_lang="ron", + pad_to_multiple_of=None, + ) + except NotImplementedError: + return + self.assertEqual(batch.input_ids.shape[1], 3) + self.assertEqual(batch.labels.shape[1], 10) + + # TODO: not working for tgt_text + # max_target_length will default to max_length if not specified + batch = tokenizer.prepare_seq2seq_batch( + src_texts=src_text, + tgt_texts=tgt_text, + max_length=4, + return_tensors="pt", + pad_to_multiple_of=None, + ) + self.assertEqual(batch.input_ids.shape[1], 4) + self.assertEqual(batch.labels.shape[1], 4) + + batch_encoder_only = tokenizer.prepare_seq2seq_batch( + src_texts=src_text, + max_length=4, + max_target_length=10, + return_tensors="pt", + pad_to_multiple_of=None, + ) + self.assertEqual(batch_encoder_only.input_ids.shape[1], 4) + self.assertEqual(batch_encoder_only.attention_mask.shape[1], 4) + self.assertNotIn("decoder_input_ids", batch_encoder_only) + + @unittest.skip("Unfortunately way too slow to build a BPE with SentencePiece.") + def test_save_slow_from_fast_and_reload_fast(self): + pass + + # Copied from tests.models.nllb.test_tokenization_nllb.NllbTokenizationTest.test_special_tokens_initialization + def test_special_tokens_initialization(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + added_tokens = [AddedToken("", lstrip=True)] + + tokenizer_r = self.rust_tokenizer_class.from_pretrained( + pretrained_name, additional_special_tokens=added_tokens, **kwargs + ) + r_output = tokenizer_r.encode("Hey this is a token") + + special_token_id = tokenizer_r.encode("", add_special_tokens=False)[0] + + self.assertTrue(special_token_id in r_output) + + if self.test_slow_tokenizer: + tokenizer_cr = self.rust_tokenizer_class.from_pretrained( + pretrained_name, + additional_special_tokens=added_tokens, + **kwargs, # , from_slow=True <- unfortunately too slow to convert + ) + tokenizer_p = self.tokenizer_class.from_pretrained( + pretrained_name, additional_special_tokens=added_tokens, **kwargs + ) + + p_output = tokenizer_p.encode("Hey this is a token") + + cr_output = tokenizer_cr.encode("Hey this is a token") + + self.assertEqual(p_output, r_output) + self.assertEqual(cr_output, r_output) + self.assertTrue(special_token_id in p_output) + self.assertTrue(special_token_id in cr_output) + + @unittest.skip( + "encode_plus and batch_encode_plus are deprecated and __call__ do some processing, so we expect different results." + ) + def test_call(self): + pass + + def test_training_new_tokenizer(self): + # This feature only exists for fast tokenizers + if not self.test_rust_tokenizer: + return + + tokenizer = self.get_rust_tokenizer() + new_tokenizer = tokenizer.train_new_from_iterator(SMALL_TRAINING_CORPUS, 100) + + # Test we can use the new tokenizer with something not seen during training + inputs = new_tokenizer(["This is the first sentence", "This sentence is different 🤗."]) + self.assertEqual(len(inputs["input_ids"]), 2) + decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True) + expected_result = "This is the first sentence" + + if tokenizer.backend_tokenizer.normalizer is not None: + expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result) + self.assertEqual(expected_result, decoded_input) + + # We check that the parameters of the tokenizer remained the same + # Check we have the same number of added_tokens for both pair and non-pair inputs. + # make sure it has the same prefix tokens first + new_tokenizer.tgt_lang = tokenizer.tgt_lang + tokenizer.tgt_lang = tokenizer.tgt_lang + self.assertEqual(tokenizer.num_special_tokens_to_add(False), new_tokenizer.num_special_tokens_to_add(False)) + self.assertEqual(tokenizer.num_special_tokens_to_add(True), new_tokenizer.num_special_tokens_to_add(True)) + + # Check we have the correct max_length for both pair and non-pair inputs. + self.assertEqual(tokenizer.max_len_single_sentence, new_tokenizer.max_len_single_sentence) + self.assertEqual(tokenizer.max_len_sentences_pair, new_tokenizer.max_len_sentences_pair) + + # Assert the set of special tokens match as we didn't ask to change them + self.assertSequenceEqual( + tokenizer.all_special_tokens_extended, + new_tokenizer.all_special_tokens_extended, + ) + + self.assertDictEqual(tokenizer.special_tokens_map, new_tokenizer.special_tokens_map) + + @unittest.skip("Fails because of the hack of adding in _tokenize") + def test_pickle_subword_regularization_tokenizer(self): + pass + + @unittest.skip("Fails because of the hack of adding in _tokenize") + def test_subword_regularization_tokenizer(self): + pass + + +@require_torch +@require_sentencepiece +@require_tokenizers +class SeamlessM4TDistilledIntegrationTest(unittest.TestCase): + checkpoint_name = "facebook/hf-seamless-m4t-medium" + src_text = [ + " UN Chief Says There Is No Military Solution in Syria", + """ Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""", + ] + tgt_text = [ + "Şeful ONU declară că nu există o soluţie militară în Siria", + "Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei" + ' pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor' + " face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.", + ] + + # fmt: off + expected_src_tokens = [256047, 16297, 134408, 8165, 248066, 14734, 950, 1135, 105721, 3573, 83, 27352, 108, 49486, 3] + # fmt: on + + @classmethod + def setUpClass(cls): + cls.tokenizer: SeamlessM4TTokenizer = SeamlessM4TTokenizer.from_pretrained( + cls.checkpoint_name, src_lang="eng", tgt_lang="ron" + ) + # cls.pad_token_id = 1 + return cls + + def test_language_codes(self): + self.assertEqual(self.tokenizer.convert_tokens_to_ids("__ace_Latn__"), 256002) + self.assertEqual(self.tokenizer.convert_tokens_to_ids("__shn__"), 256152) + self.assertEqual(self.tokenizer.convert_tokens_to_ids("__eng__"), 256047) + self.assertEqual(self.tokenizer.convert_tokens_to_ids("__fra__"), 256057) + self.assertEqual(self.tokenizer.convert_tokens_to_ids("__quy__"), 256144) + + def test_tokenizer_tgt_lang(self): + ids = self.tokenizer(self.src_text, src_lang="fra").input_ids[0] + self.assertListEqual(self.expected_src_tokens[1:], ids[1 : len(self.expected_src_tokens)]) + self.assertEqual(256057, ids[0]) + + rest_ids = ids[len(self.expected_src_tokens) :] + self.assertListEqual([0] * len(rest_ids), rest_ids) + + ids = self.tokenizer(self.src_text, src_lang="__shn__").input_ids[0] + self.assertListEqual(self.expected_src_tokens[1:], ids[1 : len(self.expected_src_tokens)]) + self.assertEqual(256152, ids[0]) + + # Copied from tests.models.nllb.test_tokenization_nllb.NllbDistilledIntegrationTest.test_enro_tokenizer_decode_ignores_language_codes + def test_enro_tokenizer_decode_ignores_language_codes(self): + self.assertIn(RO_CODE, self.tokenizer.all_special_ids) + # fmt: off + generated_ids = [RO_CODE, 4254, 98068, 112923, 39072, 3909, 713, 102767, 26, 17314, 35642, 14683, 33118, 2022, 66987, 2, 256047] + # fmt: on + + result = self.tokenizer.decode(generated_ids, skip_special_tokens=True) + expected_romanian = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True) + self.assertEqual(result, expected_romanian) + self.assertNotIn(self.tokenizer.eos_token, result) + + def test_enro_tokenizer_truncation(self): + src_text = ["this is gunna be a long sentence " * 20] + assert isinstance(src_text[0], str) + desired_max_length = 10 + ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0] + self.assertEqual(ids[-1], 3) + self.assertEqual(ids[0], EN_CODE) + self.assertEqual(len(ids), desired_max_length) + + # Copied from tests.models.nllb.test_tokenization_nllb.NllbDistilledIntegrationTest.test_special_tokens_unaffacted_by_save_load with fairseq_tokens_to_ids->additional_special_tokens, Nllb->SeamlessM4T, Dict->List + def test_special_tokens_unaffacted_by_save_load(self): + tmpdirname = tempfile.mkdtemp() + original_special_tokens = self.tokenizer.additional_special_tokens + self.tokenizer.save_pretrained(tmpdirname) + new_tok = SeamlessM4TTokenizer.from_pretrained(tmpdirname) + self.assertListEqual(new_tok.additional_special_tokens, original_special_tokens) + + @require_torch + def test_enro_tokenizer_prepare_batch(self): + batch = self.tokenizer( + self.src_text, + text_target=self.tgt_text, + padding=True, + truncation=True, + max_length=len(self.expected_src_tokens), + pad_to_multiple_of=None, + return_tensors="pt", + ) + batch["decoder_input_ids"] = shift_tokens_right( + batch["labels"], self.tokenizer.pad_token_id, self.tokenizer.convert_tokens_to_ids("__ron__") + ) + + self.assertIsInstance(batch, BatchEncoding) + + self.assertEqual((2, 15), batch.input_ids.shape) + self.assertEqual((2, 15), batch.attention_mask.shape) + result = batch.input_ids.tolist()[0] + self.assertListEqual(self.expected_src_tokens, result) + self.assertEqual(RO_CODE, batch.decoder_input_ids[0, 0]) # EOS + # Test that special tokens are reset + self.assertEqual(self.tokenizer.prefix_tokens, [EN_CODE]) + self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id]) + + def test_seq2seq_max_length(self): + batch = self.tokenizer( + self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt", pad_to_multiple_of=None + ) + targets = self.tokenizer( + text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt" + ) + labels = targets["input_ids"] + batch["decoder_input_ids"] = shift_tokens_right( + labels, + self.tokenizer.pad_token_id, + decoder_start_token_id=self.tokenizer.convert_tokens_to_ids(self.tokenizer.tgt_lang), + ) + + self.assertEqual(batch.input_ids.shape[1], 3) + self.assertEqual(batch.decoder_input_ids.shape[1], 10) + + @require_torch + def test_tokenizer_translation(self): + inputs = self.tokenizer._build_translation_inputs( + "A test", return_tensors="pt", src_lang="eng", tgt_lang="fra" + ) + + self.assertEqual( + nested_simplify(inputs), + { + # A, test, EOS, en_XX + "input_ids": [[256047, 70, 7356, 3]], + "attention_mask": [[1, 1, 1, 1]], + # ar_AR + "forced_bos_token_id": 256057, + }, + ) + + +@require_sentencepiece +@require_tokenizers +class CommonSpmIntegrationTests(unittest.TestCase): + """ + A class that regroups important test to make sure that we properly handle the special tokens. + """ + + @classmethod + def setUpClass(cls): + tokenizer = SeamlessM4TTokenizer(SAMPLE_VOCAB, extra_ids=0, add_bos_token=False, legacy=False) + tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("", rstrip=False, lstrip=False)]}) + cls.tokenizer = tokenizer + return cls + + def test_add_dummy_prefix(self): + # make sure `'▁'` is prepended, and outputs match sp_model's + # `sentencepiece.NormalizerSpec.add_dummy_prefix` attribute + input_ids = self.tokenizer.encode(". Hello") + self.assertEqual(input_ids, [3, 1, 8, 5, 157, 87, 21, 3]) + sp_encode = self.tokenizer.sp_model.encode(". Hello") + + # [bos, lang_id, _] + offset_sp_encode + self.assertEqual(input_ids[:-1], [3, 1, 8] + [i + self.tokenizer.fairseq_offset for i in sp_encode]) + tokens = self.tokenizer.tokenize(". Hello") + self.assertEqual(tokens, ["▁", ".", "▁He", "ll", "o"]) + + tokens = self.tokenizer.tokenize("") + self.assertEqual(tokens, []) + self.assertEqual(tokens, self.tokenizer.sp_model.encode("", out_type=str)) + + tokens = self.tokenizer.tokenize(" ") + self.assertEqual(tokens, []) + self.assertEqual(tokens, self.tokenizer.sp_model.encode(" ", out_type=str)) + + tokens = self.tokenizer.tokenize("▁") + self.assertEqual(tokens, []) + self.assertEqual(tokens, self.tokenizer.sp_model.encode("▁", out_type=str)) + + def test_remove_extra_whitespaces(self): + # make sure the extra spaces are eaten. Since the sample vocab does not have + # `______`. sentencepiece.NormalizerSpec.remove_extra_whitespaces attribute is set to False + + input_ids = self.tokenizer.encode(" . Hello") + self.assertEqual(input_ids, [3, 1, 8, 5, 157, 87, 21, 3]) + sp_encode = self.tokenizer.sp_model.encode(" . Hello") + self.assertEqual([i - self.tokenizer.fairseq_offset for i in input_ids[2:-1]], [7] + sp_encode) + tokens = self.tokenizer.tokenize(" . Hello") + self.assertEqual(tokens, ["▁", ".", "▁He", "ll", "o"]) + + # `'▁'` is also a whitespace + input_ids = self.tokenizer.encode("▁He is not") + self.assertEqual(input_ids, [3, 1, 157, 47, 45, 3]) + tokens = self.tokenizer.tokenize("▁He is not") + sp_encode = [ + self.tokenizer.sp_model.piece_to_id("▁He"), + self.tokenizer.sp_model.piece_to_id("▁is"), + self.tokenizer.sp_model.piece_to_id("▁not"), + ] + self.assertEqual([i - self.tokenizer.fairseq_offset for i in input_ids[2:-1]], sp_encode) + self.assertEqual(tokens, ["▁He", "▁is", "▁not"]) # no extra space added + + input_ids = self.tokenizer.encode("▁He is not ▁He") + self.assertEqual(input_ids, [3, 1, 157, 47, 45, 2, 157, 3]) + tokens = self.tokenizer.tokenize("▁He is not ▁He") + self.assertEqual(tokens, ["▁He", "▁is", "▁not", "", "▁He"]) # spaces are eaten by spm + our strip + # make sure that the output after the extra id is the same as if + # extra_id was not there + input_ids = self.tokenizer.encode("▁He is not ▁He") + self.assertEqual(input_ids, [3, 1, 157, 47, 45, 157, 3]) + tokens = self.tokenizer.tokenize("▁He is not ▁He") + self.assertEqual(tokens, ["▁He", "▁is", "▁not", "▁He"]) # spaces are eaten by spm even if not start + + def test_character_after_special_token(self): + # Make sure that `tokenizer.tokenize` is similar to + # adding the equivalent special token to the vocab + input_ids = self.tokenizer.encode("Hey I") + self.assertEqual(input_ids, [3, 1, 157, 31, 2, 101, 3]) + sp_encode = self.tokenizer.sp_model.encode("Hey .I") + + # the last token besides eos should be 100 offset + self.assertEqual(input_ids[-2] - self.tokenizer.fairseq_offset, sp_encode[-1]) + tokens = self.tokenizer.tokenize("I") + self.assertEqual(tokens, ["", "I"]) + + input_ids = self.tokenizer.encode("Hello, ,") + self.assertEqual(input_ids, [3, 1, 157, 87, 21, 4, 2, 4, 3]) + tokens = self.tokenizer.tokenize("Hello, ,") + self.assertEqual(tokens, ["▁He", "ll", "o", ",", "", ","]) + + def test_special_tokens_strip(self): + input_ids = self.tokenizer.encode(" ,") + self.assertEqual(input_ids, [3, 1, 2, 8, 4, 3]) + tokens = self.tokenizer.tokenize(" ,") + # spaces are eaten by rstrip / lstrip + spm sp_model.encode(" ") = [] + self.assertEqual(tokens, ["", "▁", ","]) + + input_ids = self.tokenizer.encode("No ▁He") + self.assertEqual(input_ids, [3, 1, 285, 2, 157, 3]) + tokens = self.tokenizer.tokenize("No ▁He") + self.assertEqual(tokens, ["▁No", "", "▁He"]) # spaces are eaten by rstrip / lstrip diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index b77028b2945465..7a116d5af3178b 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -84,6 +84,18 @@ "ClapAudioConfig": ["num_classes"], # Not used, but providing useful information to users "SpeechT5HifiGanConfig": ["sampling_rate"], + # Actually used in the config or generation config, in that case necessary for the sub-components generation + "SeamlessM4TConfig": [ + "max_new_tokens", + "t2u_max_new_tokens", + "t2u_decoder_attention_heads", + "t2u_decoder_ffn_dim", + "t2u_decoder_layers", + "t2u_encoder_attention_heads", + "t2u_encoder_ffn_dim", + "t2u_encoder_layers", + "t2u_max_position_embeddings", + ], } diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index aeacafbc289e8c..0552988c7a4cc2 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -463,6 +463,7 @@ "SEWForCTC", "SamConfig", "SamPromptEncoderConfig", + "SeamlessM4TConfig", # use of unconventional markdown "Seq2SeqTrainingArguments", "SpecialTokensMixin", "Speech2Text2Config", diff --git a/utils/check_repo.py b/utils/check_repo.py index 1e5fc5ce216098..329e9ec076e9df 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -112,6 +112,9 @@ "BridgeTowerVisionModel", # No need to test it as it is tested by BridgeTowerModel model. "BarkCausalModel", # Building part of bigger (tested) model. "BarkModel", # Does not have a forward signature - generation tested with integration tests + "SeamlessM4TTextToUnitModel", # Building part of bigger (tested) model. + "SeamlessM4TCodeHifiGan", # Building part of bigger (tested) model. + "SeamlessM4TTextToUnitForConditionalGeneration", # Building part of bigger (tested) model. ] # Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't @@ -281,6 +284,10 @@ "SpeechT5ForTextToSpeech", "SpeechT5HifiGan", "VitMatteForImageMatting", + "SeamlessM4TTextToUnitModel", + "SeamlessM4TTextToUnitForConditionalGeneration", + "SeamlessM4TCodeHifiGan", + "SeamlessM4TForSpeechToSpeech", # no auto class for speech-to-speech ] # DO NOT edit this list! diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt index c2cedb66b2176b..9ee13cc0836c4b 100644 --- a/utils/not_doctested.txt +++ b/utils/not_doctested.txt @@ -768,6 +768,7 @@ src/transformers/models/sam/image_processing_sam.py src/transformers/models/sam/modeling_sam.py src/transformers/models/sam/modeling_tf_sam.py src/transformers/models/sam/processing_sam.py +src/transformers/models/seamless_m4t/convert_fairseq2_to_hf.py src/transformers/models/segformer/configuration_segformer.py src/transformers/models/segformer/convert_segformer_original_to_pytorch.py src/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py diff --git a/utils/slow_documentation_tests.txt b/utils/slow_documentation_tests.txt index 104929f816a60b..83e400ba7e9546 100644 --- a/utils/slow_documentation_tests.txt +++ b/utils/slow_documentation_tests.txt @@ -1,5 +1,6 @@ docs/source/en/generation_strategies.md docs/source/en/model_doc/ctrl.md +docs/source/en/model_doc/seamless_m4t.md docs/source/en/task_summary.md docs/source/en/tasks/prompting.md src/transformers/models/blip_2/modeling_blip_2.py From 244a53e0f6a8d95d429559cfc49a07a4e85cc680 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 23 Oct 2023 14:54:22 +0200 Subject: [PATCH 08/38] [`NLLB-MoE`] Fix NLLB MoE 4bit inference (#27012) fix NLLB MoE 4bit --- src/transformers/models/nllb_moe/modeling_nllb_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 3701bbecef2e73..a88d53a340f4f6 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -417,7 +417,7 @@ def forward(self, hidden_states): if ( isinstance(self.fc2.weight, torch.Tensor) and hidden_states.dtype != self.fc2.weight.dtype - and self.fc2.weight.dtype != torch.int8 + and (self.fc2.weight.dtype != torch.int8 and self.fc2.weight.dtype != torch.uint8) ): hidden_states = hidden_states.to(self.fc2.weight.dtype) hidden_states = self.fc2(hidden_states) From f9f27b0fc229a096418d24af04b8966408b6c2d2 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Mon, 23 Oct 2023 15:25:06 +0200 Subject: [PATCH 09/38] [`SeamlessM4T`] fix copies with NLLB MoE int8 (#27018) fix copies on newly merged model --- src/transformers/models/seamless_m4t/modeling_seamless_m4t.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 7df6fcd98907ca..91ec6a8f9b8795 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -1300,7 +1300,7 @@ def forward(self, hidden_states): if ( isinstance(self.fc2.weight, torch.Tensor) and hidden_states.dtype != self.fc2.weight.dtype - and self.fc2.weight.dtype != torch.int8 + and (self.fc2.weight.dtype != torch.int8 and self.fc2.weight.dtype != torch.uint8) ): hidden_states = hidden_states.to(self.fc2.weight.dtype) hidden_states = self.fc2(hidden_states) From c0b5ad94736093d1cdecb76a627494ca8ce1a162 Mon Sep 17 00:00:00 2001 From: Rafael Padilla <31217453+rafaelpadilla@users.noreply.github.com> Date: Mon, 23 Oct 2023 11:08:39 -0300 Subject: [PATCH 10/38] small typos found (#26988) just very small typos found --- src/transformers/image_transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index 3cea0c2d17698b..e3d0d23474d322 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -189,7 +189,7 @@ def to_pil_image( elif not isinstance(image, np.ndarray): raise ValueError("Input image type not supported: {}".format(type(image))) - # If the channel as been moved to first dim, we put it back at the end. + # If the channel has been moved to first dim, we put it back at the end. image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format) # If there is a single channel, we squeeze it, as otherwise PIL can't handle it. @@ -593,7 +593,7 @@ def corners_to_center_format(bboxes_corners: TensorType) -> TensorType: """ Converts bounding boxes from corners format to center format. - corners format: contains the coodinates for the top-left and bottom-right corners of the box + corners format: contains the coordinates for the top-left and bottom-right corners of the box (top_left_x, top_left_y, bottom_right_x, bottom_right_y) center format: contains the coordinate for the center of the box and its the width, height dimensions (center_x, center_y, width, height) From f7354a3bd6ee5c5224fbc869c593eb5c4aa197ab Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 23 Oct 2023 16:18:02 +0100 Subject: [PATCH 11/38] Remove token_type_ids from default TF GPT-2 signature (#26962) Remove token_type_ids from default GPT-2 signature --- src/transformers/models/gpt2/modeling_tf_gpt2.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index 525207268e2279..824a49a1d41da4 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -521,6 +521,16 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel): # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model _keys_to_ignore_on_load_unexpected = [r"h.\d+.attn.bias", r"h.\d+.crossattention.bias"] + @property + def input_signature(self): + # Although GPT-2 supports token_type_ids in theory, in practice they are rarely used, and the implementation + # means that passing token_type_ids=0 yields different outputs from token_type_ids=None. + # Therefore, we remove the token_type_ids argument by default, even though it would usually be included. + return { + "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"), + "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"), + } + @dataclass class TFGPT2DoubleHeadsModelOutput(ModelOutput): From f09a081d2765c6535256b0e2d65bf54fc03f7fee Mon Sep 17 00:00:00 2001 From: jiaqiw09 <60021713+jiaqiw09@users.noreply.github.com> Date: Mon, 23 Oct 2023 23:58:00 +0800 Subject: [PATCH 12/38] Translate `pipeline_tutorial.md` to chinese (#26954) * update translation of pipeline_tutorial and preprocessing(Version1.0) * update translation of pipeline_tutorial and preprocessing(Version2.0) * update translation docs * update to fix problems mentioned in review --------- Co-authored-by: jiaqiw --- docs/source/zh/_toctree.yml | 4 +- docs/source/zh/pipeline_tutorial.md | 308 ++++++++++++++++++++++++++++ 2 files changed, 311 insertions(+), 1 deletion(-) create mode 100644 docs/source/zh/pipeline_tutorial.md diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml index 2b885d48f0d8ab..67cbc397bb0d62 100644 --- a/docs/source/zh/_toctree.yml +++ b/docs/source/zh/_toctree.yml @@ -9,8 +9,10 @@ - sections: - local: accelerate title: 加速分布式训练 + - local: pipeline_tutorial + title: pipeline教程 title: 教程 - sections: - local: fast_tokenizers title: 使用 🤗 Tokenizers 中的分词器 - title: 开发者指南 \ No newline at end of file + title: 开发者指南 diff --git a/docs/source/zh/pipeline_tutorial.md b/docs/source/zh/pipeline_tutorial.md new file mode 100644 index 00000000000000..01e621840cd3c8 --- /dev/null +++ b/docs/source/zh/pipeline_tutorial.md @@ -0,0 +1,308 @@ + + +# 推理pipeline + +[`pipeline`] 让使用[Hub](https://huggingface.co/models)上的任何模型进行任何语言、计算机视觉、语音以及多模态任务的推理变得非常简单。即使您对特定的模态没有经验,或者不熟悉模型的源码,您仍然可以使用[`pipeline`]进行推理!本教程将教您: + +- 如何使用[`pipeline`] 进行推理。 +- 如何使用特定的`tokenizer`(分词器)或模型。 +- 如何使用[`pipeline`] 进行音频、视觉和多模态任务的推理。 + + + +请查看[`pipeline`]文档以获取已支持的任务和可用参数的完整列表。 + + + +## Pipeline使用 + +虽然每个任务都有一个关联的[`pipeline`],但使用通用的抽象的[`pipeline`]更加简单,其中包含所有特定任务的`pipelines`。[`pipeline`]会自动加载一个默认模型和一个能够进行任务推理的预处理类。让我们以使用[`pipeline`]进行自动语音识别(ASR)或语音转文本为例。 + +1. 首先,创建一个[`pipeline`]并指定推理任务: + +```py +>>> from transformers import pipeline + +>>> transcriber = pipeline(task="automatic-speech-recognition") +``` + +2. 将您的输入传递给[`pipeline`]。对于语音识别,这通常是一个音频输入文件: + + +```py +>>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac") +{'text': 'I HAVE A DREAM BUT ONE DAY THIS NATION WILL RISE UP LIVE UP THE TRUE MEANING OF ITS TREES'} +``` + +您没有得到您期望的结果?可以在Hub上查看一些[最受欢迎的自动语音识别模型](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&sort=trending) +,看看是否可以获得更好的转录。 + +让我们尝试来自 OpenAI 的[Whisper large-v2](https://huggingface.co/openai/whisper-large) 模型。Whisperb比Wav2Vec2晚2年发布,使用接近10倍的数据进行了训练。因此,它在大多数下游基准测试上击败了Wav2Vec2。 +它还具有预测标点和大小写的附加优势,而Wav2Vec2则无法实现这些功能。 + +让我们在这里尝试一下,看看它的表现如何: + + +```py +>>> transcriber = pipeline(model="openai/whisper-large-v2") +>>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac") +{'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its creed.'} +``` + +现在这个结果看起来更准确了!要进行深入的Wav2Vec2与Whisper比较,请参阅[音频变换器课程](https://huggingface.co/learn/audio-course/chapter5/asr_models)。 +我们鼓励您在 Hub 上查看不同语言的模型,以及专业领域的模型等。您可以在Hub上直接查看并比较模型的结果,以确定是否适合或处理边缘情况是否比其他模型更好。如果您没有找到适用于您的用例的模型,您始终可以[训练](training)自己的模型! + +如果您有多个输入,您可以将输入作为列表传递: + + +```py +transcriber( + [ + "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac", + "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac", + ] +) +``` + +`Pipelines`非常适合用于测试,因为从一个模型切换到另一个模型非常琐碎;但是,还有一些方法可以将它们优化后用于大型工作负载而不仅仅是测试。请查看以下指南,深入探讨如何迭代整个数据集或在Web服务器中使用`Pipelines`: +* [在数据集上使用流水线](#using-pipelines-on-a-dataset) +* [在Web服务器中使用流水线](./pipeline_webserver) + + +## 参数 + +[`pipeline`] 支持许多参数;有些是适用于特定任务的,而有些适用于所有`pipeline`。通常情况下,您可以在任何地方指定对应参数: + + +```py +transcriber = pipeline(model="openai/whisper-large-v2", my_parameter=1) + +out = transcriber(...) # This will use `my_parameter=1`. +out = transcriber(..., my_parameter=2) # This will override and use `my_parameter=2`. +out = transcriber(...) # This will go back to using `my_parameter=1`. +``` + +让我们查看其中的三个重要参数: + + +### 设备 + +如果您使用 `device=n`,`pipeline`会自动将模型放在指定的设备上。无论您使用PyTorch还是Tensorflow,这都可以工作。 + + +```py +transcriber = pipeline(model="openai/whisper-large-v2", device=0) +``` + +如果模型对于单个GPU来说过于庞大,并且您正在使用PyTorch,您可以设置 `device_map="auto"` 以自动确定如何加载和存储模型权重。使用 `device_map` 参数需要安装🤗 [Accelerate](https://huggingface.co/docs/accelerate) 软件包: + + +```bash +pip install --upgrade accelerate +``` + +以下代码会自动在各个设备上加载和存储模型权重: + + +```py +transcriber = pipeline(model="openai/whisper-large-v2", device_map="auto") +``` + +请注意,如果传递了 `device_map="auto"`,在实例化您的 `pipeline` 时不需要添加 `device=device` 参数,否则可能会遇到一些意外的状况! + +### 批量大小 + +默认情况下,`pipelines`不会进行批量推理,原因在[这里](https://huggingface.co/docs/transformers/main_classes/pipelines#pipeline-batching)详细解释。因为批处理不一定更快,实际上在某些情况下可能会更慢。 + +但如果在您的用例中起作用,您可以使用: + + +```py +transcriber = pipeline(model="openai/whisper-large-v2", device=0, batch_size=2) +audio_filenames = [f"https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/{i}.flac" for i in range(1, 5)] +texts = transcriber(audio_filenames) +``` + +以上代码会在提供的4个音频文件上运行`pipeline`,它会将它们以2个一组的批次传递给模型(模型在GPU上,此时批处理更有可能有所帮助),而您无需编写额外的代码。输出应始终与没有批处理时收到的结果相一致。它只是一种帮助您更快地使用`pipeline`的方式。 + +`pipeline`也可以减轻一些批处理的复杂性,因为对于某些`pipeline`,需要将单个项目(如长音频文件)分成多个部分以供模型处理。`pipeline`为您执行这种[*chunk batching*](./main_classes/pipelines#pipeline-chunk-batching)。 + +### 任务特定参数 + +所有任务都提供了特定于任务的参数,这些参数提供额外的灵活性和选择,以帮助您完成工作。 +例如,[`transformers.AutomaticSpeechRecognitionPipeline.__call__`] 方法具有一个 `return_timestamps` 参数,对于字幕视频似乎很有帮助: + +```py +>>> transcriber = pipeline(model="openai/whisper-large-v2", return_timestamps=True) +>>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac") +{'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its creed.', 'chunks': [{'timestamp': (0.0, 11.88), 'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its'}, {'timestamp': (11.88, 12.38), 'text': ' creed.'}]} +``` + +正如您所看到的,模型推断出了文本,还输出了各个句子发音的**时间**。 + +每个任务都有许多可用的参数,因此请查看每个任务的API参考,以了解您可以进行哪些调整!例如,[`~transformers.AutomaticSpeechRecognitionPipeline`] 具有 `chunk_length_s` 参数,对于处理非常长的音频文件(例如,为整部电影或长达一小时的视频配字幕)非常有帮助,这通常是模型无法单独处理的: + +```python +>>> transcriber = pipeline(model="openai/whisper-large-v2", chunk_length_s=30, return_timestamps=True) +>>> transcriber("https://huggingface.co/datasets/sanchit-gandhi/librispeech_long/resolve/main/audio.wav") +{'text': " Chapter 16. I might have told you of the beginning of this liaison in a few lines, but I wanted you to see every step by which we came. I, too, agree to whatever Marguerite wished, Marguerite to be unable to live apart from me. It was the day after the evening... +``` + +如果您找不到一个真正有帮助的参数,欢迎[提出请求](https://github.com/huggingface/transformers/issues/new?assignees=&labels=feature&template=feature-request.yml)! + +## 在数据集上使用pipelines + +`pipelines`也可以对大型数据集进行推理。我们建议使用迭代器来完成这一任务,这是最简单的方法: + + +```py +def data(): + for i in range(1000): + yield f"My example {i}" + + +pipe = pipeline(model="gpt2", device=0) +generated_characters = 0 +for out in pipe(data()): + generated_characters += len(out[0]["generated_text"]) +``` + +迭代器 `data()` 会产生每个结果,`pipelines`会自动识别输入为可迭代对象,并在GPU上处理数据的同时开始获取数据(在底层使用[DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader))。这一点非常重要,因为您不必为整个数据集分配内存,可以尽可能快地将数据传送到GPU。 + +由于批处理可以加速处理,因此在这里尝试调整 `batch_size` 参数可能会很有用。 + +迭代数据集的最简单方法就是从🤗 [Datasets](https://github.com/huggingface/datasets/) 中加载数据集: + + +```py +# KeyDataset is a util that will just output the item we're interested in. +from transformers.pipelines.pt_utils import KeyDataset +from datasets import load_dataset + +pipe = pipeline(model="hf-internal-testing/tiny-random-wav2vec2", device=0) +dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:10]") + +for out in pipe(KeyDataset(dataset, "audio")): + print(out) +``` + + +## 在Web服务器上使用pipelines + + +创建推理引擎是一个复杂的主题,值得有自己的页面。 + + +[链接](./pipeline_webserver) + +## 视觉流水线 + +对于视觉任务,使用[`pipeline`] 几乎是相同的。 + +指定您的任务并将图像传递给分类器。图像可以是链接、本地路径或base64编码的图像。例如,下面显示的是哪种品种的猫? + +![pipeline-cat-chonk](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg) +```py +>>> from transformers import pipeline + +>>> vision_classifier = pipeline(model="google/vit-base-patch16-224") +>>> preds = vision_classifier( +... images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" +... ) +>>> preds = [{"score": round(pred["score"], 4), "label": pred["label"]} for pred in preds] +>>> preds +[{'score': 0.4335, 'label': 'lynx, catamount'}, {'score': 0.0348, 'label': 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor'}, {'score': 0.0324, 'label': 'snow leopard, ounce, Panthera uncia'}, {'score': 0.0239, 'label': 'Egyptian cat'}, {'score': 0.0229, 'label': 'tiger cat'}] +``` + +## 文本流水线 + +对于NLP任务,使用[`pipeline`] 几乎是相同的。 + + +```py +>>> from transformers import pipeline + +>>> # This model is a `zero-shot-classification` model. +>>> # It will classify text, except you are free to choose any label you might imagine +>>> classifier = pipeline(model="facebook/bart-large-mnli") +>>> classifier( +... "I have a problem with my iphone that needs to be resolved asap!!", +... candidate_labels=["urgent", "not urgent", "phone", "tablet", "computer"], +... ) +{'sequence': 'I have a problem with my iphone that needs to be resolved asap!!', 'labels': ['urgent', 'phone', 'computer', 'not urgent', 'tablet'], 'scores': [0.504, 0.479, 0.013, 0.003, 0.002]} +``` + +## 多模态流水线 + +[`pipeline`] 支持多个模态。例如,视觉问题回答(VQA)任务结合了文本和图像。请随意使用您喜欢的任何图像链接和您想要问关于该图像的问题。图像可以是URL或图像的本地路径。 + +例如,如果您使用这个[invoice image](https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png): + + +```py +>>> from transformers import pipeline + +>>> vqa = pipeline(model="impira/layoutlm-document-qa") +>>> vqa( +... image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", +... question="What is the invoice number?", +... ) +[{'score': 0.42515, 'answer': 'us-001', 'start': 16, 'end': 16}] +``` + + + +要运行上面的示例,除了🤗 Transformers之外,您需要安装[`pytesseract`](https://pypi.org/project/pytesseract/)。 + + +```bash +sudo apt install -y tesseract-ocr +pip install pytesseract +``` + + + +## 在大模型上使用🤗 `accelerate`和`pipeline`: + +您可以轻松地使用🤗 `accelerate`在大模型上运行 `pipeline`!首先确保您已经使用 `pip install accelerate` 安装了 `accelerate`。 + +首先使用 `device_map="auto"` 加载您的模型!我们将在示例中使用 `facebook/opt-1.3b`。 + + +```py +# pip install accelerate +import torch +from transformers import pipeline + +pipe = pipeline(model="facebook/opt-1.3b", torch_dtype=torch.bfloat16, device_map="auto") +output = pipe("This is a cool example!", do_sample=True, top_p=0.95) +``` + +如果安装 `bitsandbytes` 并添加参数 `load_in_8bit=True`,您还可以传递8位加载的模型。 + + +```py +# pip install accelerate bitsandbytes +import torch +from transformers import pipeline + +pipe = pipeline(model="facebook/opt-1.3b", device_map="auto", model_kwargs={"load_in_8bit": True}) +output = pipe("This is a cool example!", do_sample=True, top_p=0.95) +``` + +请注意,您可以将`checkpoint `替换为任何支持大模型加载的Hugging Face模型,比如BLOOM! + From 33f98cfded1724e12be11a3b3333e4821de9bbbe Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 23 Oct 2023 18:54:00 +0200 Subject: [PATCH 13/38] Remove ambiguous `padding_mask` and instead use a 2D->4D Attn Mask Mapper (#26792) * [Attn Mask Converter] refactor attn mask * up * Apply suggestions from code review Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> * improve * rename * better cache * renaming * improve more * improve * fix bug * finalize * make style & make fix-copies * correct more * start moving attention_mask * fix llama * improve falcon * up * improve more * improve more * Update src/transformers/models/owlv2/modeling_owlv2.py * make style * make style * rename to converter * Apply suggestions from code review --------- Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> --- .../open_llama/modeling_open_llama.py | 2 +- .../models/falcon/modeling_falcon.py | 285 ++++++++++++----- .../models/llama/modeling_llama.py | 271 +++++++++++----- .../models/mistral/modeling_mistral.py | 298 ++++++++++++------ .../models/persimmon/modeling_persimmon.py | 1 - tests/models/llama/test_modeling_llama.py | 143 +++++++++ 6 files changed, 735 insertions(+), 265 deletions(-) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 6853f5333f137c..3ace323e822414 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -560,7 +560,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index e9dca6df989472..6eaeed4199771c 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -15,6 +15,7 @@ """PyTorch Falcon model.""" import math +import warnings from typing import Optional, Tuple, Union import torch @@ -76,9 +77,9 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(padding_mask): - seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return ( @@ -88,6 +89,143 @@ def _get_unpad_data(padding_mask): ) +# Copied from transformers.models.llama.modeling_llama.AttnMaskConverter +class AttnMaskConverter: + """ + A utility attention mask class that allows: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. + """ + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + device: Union[torch.device, "str"] = "cpu", + ) -> torch.Tensor: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). + """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask + + return expanded_4d_mask + + def _make_causal_mask( + self, + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window + 1 + + context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) + mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + # TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities class FalconRotaryEmbedding(nn.Module): """Implementation of RotaryEmbedding from GPT-NeoX. @@ -311,6 +449,7 @@ def __init__(self, config: FalconConfig): self.head_dim = self.hidden_size // self.num_heads self.split_size = self.hidden_size self.hidden_dropout = config.hidden_dropout + self.is_causal = True if self.head_dim * self.num_heads != self.hidden_size: raise ValueError( @@ -431,8 +570,13 @@ def forward( head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, - padding_mask: Optional[torch.LongTensor] = None, + **kwargs, ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads # 3 x [batch_size, seq_length, num_heads, head_dim] @@ -465,9 +609,6 @@ def forward( else: present = None - float_min = torch.finfo(query_layer.dtype).min - attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float_min).to(query_layer.dtype) - query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim) key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) @@ -482,16 +623,14 @@ def forward( ) attn_output = F.scaled_dot_product_attention( - query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False + query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False ) attention_scores = None else: attention_scores = query_layer_ @ key_layer_.transpose(-1, -2) attention_scores /= math.sqrt(self.head_dim) - attention_scores = F.softmax( - attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype - ) + attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) attn_output = attention_scores @ value_layer_ attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim) @@ -517,12 +656,12 @@ def forward( if input_dtype == torch.float16 or input_dtype == torch.bfloat16: attention_scores = attention_scores.to(torch.float32) # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by - # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically + # adding (alibi * self.inv_norm_factor) to attention_mask. I think this would be mathematically # equivalent and more performant, but there might be a numerical difference. If you're reading this # and you'd like to experiment and maybe file a PR, feel free! attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) attention_logits *= self.inv_norm_factor - attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype) + attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) # [batch_size, num_heads, q_length, kv_length] attention_probs = self.attention_dropout(attention_probs) @@ -563,8 +702,16 @@ def forward( head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, - padding_mask: Optional[torch.LongTensor] = None, + **kwargs, ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads # 3 x [batch_size, seq_length, num_heads, head_dim] @@ -630,7 +777,7 @@ def forward( value_layer = value_layer.to(target_dtype) attn_output = self._flash_attention_forward( - query_layer, key_layer, value_layer, padding_mask, query_length, dropout=attn_dropout + query_layer, key_layer, value_layer, attention_mask, query_length, dropout=attn_dropout ) attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) @@ -643,7 +790,7 @@ def forward( # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward def _flash_attention_forward( - self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -656,7 +803,7 @@ def _flash_attention_forward( Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API - padding_mask (`torch.Tensor`): + attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`int`, *optional*): @@ -665,10 +812,10 @@ def _flash_attention_forward( The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ # Contains at least one padding token in the sequence - if padding_mask is not None: + if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, padding_mask, query_length + query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens @@ -684,7 +831,7 @@ def _flash_attention_forward( max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=True, + causal=self.is_causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) @@ -696,8 +843,8 @@ def _flash_attention_forward( return attn_output # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( @@ -722,8 +869,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_l query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. - padding_mask = padding_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -752,7 +899,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FalconDecoderLayer(nn.Module): - def __init__(self, config: FalconConfig): + def __init__(self, config: FalconConfig, attn_mask_converter=None): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads @@ -786,8 +933,13 @@ def forward( head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, - padding_mask: Optional[torch.LongTensor] = None, + **kwargs, ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states if self.config.new_decoder_architecture: @@ -806,7 +958,7 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, - padding_mask=padding_mask, + **kwargs, ) attention_output = attn_outputs[0] @@ -1001,6 +1153,10 @@ def __init__(self, config: FalconConfig): # Embedding + LN Embedding self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) + # create attention mask cache that trickles down to each attention layer + # so that the attention_mask cache can be shared among layers + self.attn_mask_converter = AttnMaskConverter(is_causal=True) + # Transformer blocks self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)]) @@ -1015,37 +1171,6 @@ def __init__(self, config: FalconConfig): def get_input_embeddings(self): return self.word_embeddings - @staticmethod - def _prepare_attn_mask( - attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int - ) -> torch.BoolTensor: - # Create a causal mask - # The attention mask we receive as input should cover the whole extended sequence, including any past - # cache, so its shape should be [batch_size, seq_length + past_key_values_length] - # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length] - if input_shape[1] + past_key_values_length != attention_mask.shape[1]: - raise ValueError( - "Attention mask shape should be (batch_size, seq_length + past_key_values_length)" - f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length" - f" {past_key_values_length}." - ) - combined_attention_mask = None - device = attention_mask.device - _, seq_length = input_shape - - if seq_length > 1: - combined_attention_mask = _make_causal_mask( - input_shape, device=device, past_key_values_length=past_key_values_length - ) - - # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length] - expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask - ) - - return combined_attention_mask - def set_input_embeddings(self, new_embeddings: torch.Tensor): self.word_embeddings = new_embeddings @@ -1114,19 +1239,16 @@ def forward( past_key_values_length = 0 if past_key_values[0] is not None: past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device) - padding_mask = None - else: - attention_mask = attention_mask.to(hidden_states.device) - - if 0 in attention_mask: - padding_mask = attention_mask - else: - padding_mask = None if self.use_alibi: - alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + mask = ( + torch.ones( + (batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long + ) + if attention_mask is None + else attention_mask + ) + alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype) else: alibi = None if position_ids is None: @@ -1136,11 +1258,20 @@ def forward( ) position_ids = position_ids.unsqueeze(0) - causal_mask = self._prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, - ) + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + key_value_length = seq_length + past_key_values_length + # 4d mask is passed through the layers + if attention_mask is not None: + attention_mask = self.attn_mask_converter.to_4d( + attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype + ) + else: + attention_mask = self.attn_mask_converter.to_causal_4d( + batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: @@ -1159,22 +1290,20 @@ def custom_forward(*inputs): create_custom_forward(block), hidden_states, alibi, - causal_mask, + attention_mask, position_ids, head_mask[i], - padding_mask, ) else: outputs = block( hidden_states, layer_past=layer_past, - attention_mask=causal_mask, + attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, - padding_mask=padding_mask, ) hidden_states = outputs[0] diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index b67719ac327162..541455d86afde8 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -19,6 +19,7 @@ # limitations under the License. """ PyTorch LLaMA model.""" import math +import warnings from typing import List, Optional, Tuple, Union import torch @@ -51,9 +52,9 @@ _CONFIG_FOR_DOC = "LlamaConfig" -def _get_unpad_data(padding_mask): - seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return ( @@ -63,37 +64,140 @@ def _get_unpad_data(padding_mask): ) -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): +class AttnMaskConverter: """ - Make causal mask used for bi-directional self-attention. + A utility attention mask class that allows: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + device: Union[torch.device, "str"] = "cpu", + ) -> torch.Tensor: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). + """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + return expanded_4d_mask - inverted_mask = 1.0 - expanded_mask + def _make_causal_mask( + self, + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window + 1 + + context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) + mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) class LlamaRMSNorm(nn.Module): @@ -272,6 +376,7 @@ def __init__(self, config: LlamaConfig): self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta + self.is_causal = True if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( @@ -322,8 +427,13 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() if self.config.pretraining_tp > 1: @@ -420,14 +530,22 @@ class LlamaFlashAttention2(LlamaAttention): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # LlamaFlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -492,7 +610,7 @@ def forward( value_states = value_states.to(target_dtype) attn_output = self._flash_attention_forward( - query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() @@ -504,7 +622,7 @@ def forward( return attn_output, attn_weights, past_key_value def _flash_attention_forward( - self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -517,7 +635,7 @@ def _flash_attention_forward( Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API - padding_mask (`torch.Tensor`): + attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`int`, *optional*): @@ -526,10 +644,10 @@ def _flash_attention_forward( The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ # Contains at least one padding token in the sequence - if padding_mask is not None: + if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, padding_mask, query_length + query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens @@ -545,7 +663,7 @@ def _flash_attention_forward( max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=True, + causal=self.is_causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) @@ -556,8 +674,8 @@ def _flash_attention_forward( return attn_output - def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( @@ -582,8 +700,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_l query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. - padding_mask = padding_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -616,13 +734,13 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - padding_mask: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(batch, sequence_length)` where padding elements are indicated by 0. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -631,6 +749,10 @@ def forward( (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) residual = hidden_states @@ -644,7 +766,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - padding_mask=padding_mask, + **kwargs, ) hidden_states = residual + hidden_states @@ -791,6 +913,10 @@ def __init__(self, config: LlamaConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + # create attention mask cache that trickles down to each attention layer + # so that the attention_mask cache can be shared among layers + self.attn_mask_converter = AttnMaskConverter(is_causal=True) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -805,30 +931,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, @@ -854,18 +956,15 @@ def forward( if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: - batch_size, seq_length = input_ids.shape + batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape + batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") - seq_length_with_past = seq_length past_key_values_length = 0 - if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -876,22 +975,23 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - padding_mask = None + + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: - if 0 in attention_mask: - padding_mask = attention_mask + key_value_length = seq_length + past_key_values_length + # 4d mask is passed through the layers + if attention_mask is not None: + attention_mask = self.attn_mask_converter.to_4d( + attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype + ) else: - padding_mask = None - - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) + attention_mask = self.attn_mask_converter.to_causal_4d( + batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + # embed positions hidden_states = inputs_embeds if self.gradient_checkpointing and self.training: @@ -917,7 +1017,7 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) + return module(*inputs, past_key_value, output_attentions) return custom_forward @@ -932,7 +1032,6 @@ def custom_forward(*inputs): past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - padding_mask=padding_mask, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index d650d60b8a553e..f735c471268136 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -20,6 +20,7 @@ """ PyTorch Mistral model.""" import inspect import math +import warnings from typing import List, Optional, Tuple, Union import torch @@ -53,10 +54,147 @@ _CONFIG_FOR_DOC = "MistralConfig" +# Copied from transformers.models.llama.modeling_llama.AttnMaskConverter +class AttnMaskConverter: + """ + A utility attention mask class that allows: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. + """ + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + device: Union[torch.device, "str"] = "cpu", + ) -> torch.Tensor: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). + """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask + + return expanded_4d_mask + + def _make_causal_mask( + self, + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window + 1 + + context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) + mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + # Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(padding_mask): - seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return ( @@ -66,33 +204,6 @@ def _get_unpad_data(padding_mask): ) -def _make_sliding_window_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, - sliding_window: int = 4096, -): - """ - Make causal mask used for sliding window attention - """ - bsz, tgt_len = input_ids_shape - - tensor = torch.full( - (tgt_len, tgt_len), - fill_value=1, - device=device, - ) - mask = torch.tril(tensor, diagonal=0) - # make the mask banded to account for sliding window - mask = torch.triu(mask, diagonal=-sliding_window) - mask = torch.log(mask).to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - # Copied from transformers.models.bart.modeling_bart._expand_mask def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ @@ -223,6 +334,7 @@ def __init__(self, config: MistralConfig): self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta + self.is_causal = True if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( @@ -251,8 +363,12 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - padding_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -332,8 +448,15 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, + **kwargs, ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -385,9 +508,9 @@ def forward( past_key_value = (past_key, past_value) - if padding_mask is not None: - padding_mask = padding_mask[:, slicing_tokens:] - padding_mask = torch.cat([padding_mask, torch.ones_like(padding_mask[:, -1:])], dim=-1) + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) @@ -433,7 +556,7 @@ def forward( query_states, key_states, value_states, - padding_mask, + attention_mask, q_len, dropout=dropout_rate, use_sliding_windows=use_sliding_windows, @@ -452,7 +575,7 @@ def _flash_attention_forward( query_states, key_states, value_states, - padding_mask, + attention_mask, query_length, dropout=0.0, softmax_scale=None, @@ -469,7 +592,7 @@ def _flash_attention_forward( Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API - padding_mask (`torch.Tensor`): + attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`int`, *optional*): @@ -480,10 +603,10 @@ def _flash_attention_forward( Whether to activate sliding window attention. """ # Contains at least one padding token in the sequence - if padding_mask is not None: + if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, padding_mask, query_length + query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens @@ -500,7 +623,7 @@ def _flash_attention_forward( max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=True, + causal=self.is_causal, ) else: attn_output_unpad = flash_attn_varlen_func( @@ -513,7 +636,7 @@ def _flash_attention_forward( max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=True, + causal=self.is_causal, window_size=(self.config.sliding_window, self.config.sliding_window), ) @@ -536,16 +659,16 @@ def _flash_attention_forward( return attn_output - def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape # On the first iteration we need to properly re-create the padding mask # by slicing it on the proper place - if kv_seq_len != padding_mask.shape[-1]: - padding_mask_num_tokens = padding_mask.shape[-1] - padding_mask = padding_mask[:, padding_mask_num_tokens - kv_seq_len :] + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) @@ -566,8 +689,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_l query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. - padding_mask = padding_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -600,13 +723,17 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - padding_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(batch, sequence_length)` where padding elements are indicated by 0. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -628,7 +755,6 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - padding_mask=padding_mask, ) hidden_states = residual + hidden_states @@ -775,6 +901,10 @@ def __init__(self, config: MistralConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + # create attention mask cache that trickles down to each attention layer + # so that the attention_mask cache can be shared among layers + self.attn_mask_converter = AttnMaskConverter(is_causal=True, sliding_window=config.sliding_window) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -789,32 +919,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length, sliding_window - ): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_sliding_window_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - sliding_window=sliding_window, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) def forward( self, @@ -865,23 +969,13 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - padding_mask = None - - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - elif 0 in attention_mask: - padding_mask = attention_mask - if ( - padding_mask is not None + attention_mask is not None and hasattr(self.config, "_flash_attn_2_enabled") and self.config._flash_attn_2_enabled and past_key_values is not None ): - is_padding_right = padding_mask[:, -1].sum().item() != batch_size + is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( "You are attempting to perform batched generation with padding_side='right'" @@ -889,13 +983,20 @@ def forward( " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + key_value_length = seq_length + past_key_values_length + # 4d mask is passed through the layers + if attention_mask is not None: + attention_mask = self.attn_mask_converter.to_4d( + attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype + ) + else: + attention_mask = self.attn_mask_converter.to_causal_4d( + batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) hidden_states = inputs_embeds @@ -922,7 +1023,7 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) + return module(*inputs, past_key_value, output_attentions) return custom_forward @@ -940,7 +1041,6 @@ def custom_forward(*inputs): past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - padding_mask=padding_mask, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index a0bc5726382336..d73cc4484484fa 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -548,7 +548,6 @@ class PersimmonModel(PersimmonPreTrainedModel): config: PersimmonConfig """ - # Copied from transformers.models.llama.modeling_llama.LlamaModel.__init__ with LLAMA->PERSIMMON,Llama->Persimmon,PersimmonRMSNorm->nn.LayerNorm,norm->final_layernorm,rms_final_layernorm_eps->layer_norm_eps def __init__(self, config: PersimmonConfig): super().__init__(config) self.padding_idx = config.pad_token_id diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 2402986900fda6..df41c6c5f52044 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -39,6 +39,149 @@ LlamaModel, LlamaTokenizer, ) + from transformers.models.llama.modeling_llama import AttnMaskConverter + + +@require_torch +class AttentionMaskTester(unittest.TestCase): + def check_non_causal(self, bsz, q_len, kv_len, mask_2d, mask_4d): + mask_indices = (mask_2d != 1)[:, None].broadcast_to((bsz, q_len, kv_len)) + mask_4d_values = mask_4d[:, 0][mask_indices] + is_inf = mask_4d_values == -float("inf") + is_min = mask_4d_values == torch.finfo(mask_4d.dtype).min + assert torch.logical_or(is_inf, is_min).all() + + def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3): + mask_2d = torch.ones((bsz, kv_len), device=torch_device, dtype=torch.long) + + if additional_mask is not None: + for bsz_idx, seq_idx in additional_mask: + mask_2d[bsz_idx, seq_idx] = 0 + + mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len) + + assert mask_4d.shape == (bsz, 1, q_len, kv_len) + + context = mask_converter.sliding_window + if mask_converter.is_causal and context is None: + # k * (k+1) / 2 tokens are masked in triangualar masks + num_tokens_masked = bsz * (q_len * (q_len - 1) // 2) + + if 0 not in mask_2d: + assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked + if 0 in mask_2d: + # at least causal mask + maybe more + assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked + self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d) + elif not mask_converter.is_causal and context is None: + if 0 not in mask_2d: + assert (mask_4d != 0).sum().cpu().item() == 0 + if 0 in mask_2d: + self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d) + elif mask_converter.is_causal and context is not None: + # k * (k+1) / 2 tokens are masked in triangualar masks + num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len) + num_tokens_masked = bsz * num_tokens_masked + + if 0 not in mask_2d: + assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked + if 0 in mask_2d: + # at least causal mask + maybe more + assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked + self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d) + + def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3): + mask_4d = mask_converter.to_causal_4d(bsz, query_length=q_len, key_value_length=kv_len, device=torch_device) + + if q_len == 1 and mask_converter.sliding_window is None: + # no causal mask if q_len is 1 + assert mask_4d is None + return + + context = mask_converter.sliding_window + if mask_converter.is_causal and context is None: + # k * (k+1) / 2 tokens are masked in triangualar masks + num_tokens_masked = bsz * (q_len * (q_len - 1) // 2) + + assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked + elif not mask_converter.is_causal and context is None: + assert (mask_4d != 0).sum().cpu().item() == 0 + elif mask_converter.is_causal and context is not None: + # k * (k+1) / 2 tokens are masked in triangualar masks + num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len) + num_tokens_masked = bsz * num_tokens_masked + + assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked + + def compute_num_context_mask(self, kv_len, context, q_len): + # This function computes the # of attention tokens that are added for + # the sliding window + c_mask_len = kv_len - context + num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2 + cut_mask_len = max(c_mask_len - q_len, 0) + num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2 + return num_mask_triangle - num_cut_mask + + def test_2d_to_4d_causal(self): + mask_converter = AttnMaskConverter(is_causal=True) + + # auto-regressive use case + self.check_to_4d(mask_converter, q_len=1, kv_len=7) + # special auto-regressive case + self.check_to_4d(mask_converter, q_len=3, kv_len=7) + # non auto-regressive case + self.check_to_4d(mask_converter, q_len=7, kv_len=7) + + # same with extra attention masks + self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + + def test_2d_to_4d(self): + torch.ones((3, 7), device=torch_device, dtype=torch.long) + mask_converter = AttnMaskConverter(is_causal=False) + + # non auto-regressive case + self.check_to_4d(mask_converter, q_len=7, kv_len=7) + + # same with extra attention masks + self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + + def test_2d_to_4d_causal_sliding(self): + torch.ones((3, 7), device=torch_device, dtype=torch.long) + mask_converter = AttnMaskConverter(is_causal=True, sliding_window=5) + + # auto-regressive use case + self.check_to_4d(mask_converter, q_len=1, kv_len=7) + # special auto-regressive case + self.check_to_4d(mask_converter, q_len=3, kv_len=7) + # non auto-regressive case + self.check_to_4d(mask_converter, q_len=7, kv_len=7) + + # same with extra attention masks + self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + + def test_causal_mask(self): + mask_converter = AttnMaskConverter(is_causal=True) + + # auto-regressive use case + self.check_to_causal(mask_converter, q_len=1, kv_len=7) + # special auto-regressive case + self.check_to_causal(mask_converter, q_len=3, kv_len=7) + # non auto-regressive case + self.check_to_causal(mask_converter, q_len=7, kv_len=7) + + def test_causal_mask_sliding(self): + mask_converter = AttnMaskConverter(is_causal=True, sliding_window=3) + + # auto-regressive use case + self.check_to_causal(mask_converter, q_len=1, kv_len=7) + # special auto-regressive case + self.check_to_causal(mask_converter, q_len=3, kv_len=7) + # non auto-regressive case + self.check_to_causal(mask_converter, q_len=7, kv_len=7) class LlamaModelTester: From 19ae0505aefcfcfbfeedaf05517d6a632ee7707d Mon Sep 17 00:00:00 2001 From: Yeyang <76979429+yyLeaves@users.noreply.github.com> Date: Mon, 23 Oct 2023 17:35:17 +0000 Subject: [PATCH 14/38] =?UTF-8?q?=F0=9F=8C=90=20[i18n-ZH]=20Translate=20mu?= =?UTF-8?q?ltilingual=20into=20Chinese=20(#26935)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit translate multilingual into Chinese Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/zh/_toctree.yml | 2 + docs/source/zh/multilingual.md | 178 +++++++++++++++++++++++++++++++++ 2 files changed, 180 insertions(+) create mode 100644 docs/source/zh/multilingual.md diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml index 67cbc397bb0d62..0609a02ce73739 100644 --- a/docs/source/zh/_toctree.yml +++ b/docs/source/zh/_toctree.yml @@ -15,4 +15,6 @@ - sections: - local: fast_tokenizers title: 使用 🤗 Tokenizers 中的分词器 + - local: multilingual + title: 使用多语言模型进行推理 title: 开发者指南 diff --git a/docs/source/zh/multilingual.md b/docs/source/zh/multilingual.md new file mode 100644 index 00000000000000..7e8ab1336d9933 --- /dev/null +++ b/docs/source/zh/multilingual.md @@ -0,0 +1,178 @@ + + +# 用于推理的多语言模型 + +[[open-in-colab]] + +🤗 Transformers 中有多种多语言模型,它们的推理用法与单语言模型不同。但是,并非*所有*的多语言模型用法都不同。一些模型,例如 [bert-base-multilingual-uncased](https://huggingface.co/bert-base-multilingual-uncased) 就可以像单语言模型一样使用。本指南将向您展示如何使用不同用途的多语言模型进行推理。 + +## XLM + +XLM 有十个不同的检查点,其中只有一个是单语言的。剩下的九个检查点可以归为两类:使用语言嵌入的检查点和不使用语言嵌入的检查点。 + +### 带有语言嵌入的 XLM + +以下 XLM 模型使用语言嵌入来指定推理中使用的语言: + +- `xlm-mlm-ende-1024` (掩码语言建模,英语-德语) +- `xlm-mlm-enfr-1024` (掩码语言建模,英语-法语) +- `xlm-mlm-enro-1024` (掩码语言建模,英语-罗马尼亚语) +- `xlm-mlm-xnli15-1024` (掩码语言建模,XNLI 数据集语言) +- `xlm-mlm-tlm-xnli15-1024` (掩码语言建模+翻译,XNLI 数据集语言) +- `xlm-clm-enfr-1024` (因果语言建模,英语-法语) +- `xlm-clm-ende-1024` (因果语言建模,英语-德语) + +语言嵌入被表示一个张量,其形状与传递给模型的 `input_ids` 相同。这些张量中的值取决于所使用的语言,并由分词器的 `lang2id` 和 `id2lang` 属性识别。 + +在此示例中,加载 `xlm-clm-enfr-1024` 检查点(因果语言建模,英语-法语): + +```py +>>> import torch +>>> from transformers import XLMTokenizer, XLMWithLMHeadModel + +>>> tokenizer = XLMTokenizer.from_pretrained("xlm-clm-enfr-1024") +>>> model = XLMWithLMHeadModel.from_pretrained("xlm-clm-enfr-1024") +``` + +分词器的 `lang2id` 属性显示了该模型的语言及其对应的id: + +```py +>>> print(tokenizer.lang2id) +{'en': 0, 'fr': 1} +``` + +接下来,创建一个示例输入: + +```py +>>> input_ids = torch.tensor([tokenizer.encode("Wikipedia was used to")]) # batch size 为 1 +``` + +将语言 id 设置为 `"en"` 并用其定义语言嵌入。语言嵌入是一个用 `0` 填充的张量,这个张量应该与 `input_ids` 大小相同。 + +```py +>>> language_id = tokenizer.lang2id["en"] # 0 +>>> langs = torch.tensor([language_id] * input_ids.shape[1]) # torch.tensor([0, 0, 0, ..., 0]) + +>>> # 我们将其 reshape 为 (batch_size, sequence_length) 大小 +>>> langs = langs.view(1, -1) # 现在的形状是 [1, sequence_length] (我们的 batch size 为 1) +``` + +现在,你可以将 `input_ids` 和语言嵌入传递给模型: + +```py +>>> outputs = model(input_ids, langs=langs) +``` + +[run_generation.py](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-generation/run_generation.py) 脚本可以使用 `xlm-clm` 检查点生成带有语言嵌入的文本。 + +### 不带语言嵌入的 XLM + +以下 XLM 模型在推理时不需要语言嵌入: + +- `xlm-mlm-17-1280` (掩码语言建模,支持 17 种语言) +- `xlm-mlm-100-1280` (掩码语言建模,支持 100 种语言) + +与之前的 XLM 检查点不同,这些模型用于通用句子表示。 + +## BERT + +以下 BERT 模型可用于多语言任务: + +- `bert-base-multilingual-uncased` (掩码语言建模 + 下一句预测,支持 102 种语言) +- `bert-base-multilingual-cased` (掩码语言建模 + 下一句预测,支持 104 种语言) + +这些模型在推理时不需要语言嵌入。它们应该能够从上下文中识别语言并进行相应的推理。 + +## XLM-RoBERTa + +以下 XLM-RoBERTa 模型可用于多语言任务: + +- `xlm-roberta-base` (掩码语言建模,支持 100 种语言) +- `xlm-roberta-large` (掩码语言建模,支持 100 种语言) + +XLM-RoBERTa 使用 100 种语言的 2.5TB 新创建和清理的 CommonCrawl 数据进行了训练。与之前发布的 mBERT 或 XLM 等多语言模型相比,它在分类、序列标记和问答等下游任务上提供了更强大的优势。 + +## M2M100 + +以下 M2M100 模型可用于多语言翻译: + +- `facebook/m2m100_418M` (翻译) +- `facebook/m2m100_1.2B` (翻译) + +在此示例中,加载 `facebook/m2m100_418M` 检查点以将中文翻译为英文。你可以在分词器中设置源语言: + +```py +>>> from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer + +>>> en_text = "Do not meddle in the affairs of wizards, for they are subtle and quick to anger." +>>> chinese_text = "不要插手巫師的事務, 因為他們是微妙的, 很快就會發怒." + +>>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="zh") +>>> model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M") +``` + +对文本进行分词: + +```py +>>> encoded_zh = tokenizer(chinese_text, return_tensors="pt") +``` + +M2M100 强制将目标语言 id 作为第一个生成的标记,以进行到目标语言的翻译。在 `generate` 方法中将 `forced_bos_token_id` 设置为 `en` 以翻译成英语: + +```py +>>> generated_tokens = model.generate(**encoded_zh, forced_bos_token_id=tokenizer.get_lang_id("en")) +>>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) +'Do not interfere with the matters of the witches, because they are delicate and will soon be angry.' +``` + +## MBart + +以下 MBart 模型可用于多语言翻译: + +- `facebook/mbart-large-50-one-to-many-mmt` (一对多多语言机器翻译,支持 50 种语言) +- `facebook/mbart-large-50-many-to-many-mmt` (多对多多语言机器翻译,支持 50 种语言) +- `facebook/mbart-large-50-many-to-one-mmt` (多对一多语言机器翻译,支持 50 种语言) +- `facebook/mbart-large-50` (多语言翻译,支持 50 种语言) +- `facebook/mbart-large-cc25` + +在此示例中,加载 `facebook/mbart-large-50-many-to-many-mmt` 检查点以将芬兰语翻译为英语。 你可以在分词器中设置源语言: + +```py +>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + +>>> en_text = "Do not meddle in the affairs of wizards, for they are subtle and quick to anger." +>>> fi_text = "Älä sekaannu velhojen asioihin, sillä ne ovat hienovaraisia ja nopeasti vihaisia." + +>>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt", src_lang="fi_FI") +>>> model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-many-to-many-mmt") +``` + +对文本进行分词: + +```py +>>> encoded_en = tokenizer(en_text, return_tensors="pt") +``` + +MBart 强制将目标语言 id 作为第一个生成的标记,以进行到目标语言的翻译。在 `generate` 方法中将 `forced_bos_token_id` 设置为 `en` 以翻译成英语: + +```py +>>> generated_tokens = model.generate(**encoded_en, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"]) +>>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) +"Don't interfere with the wizard's affairs, because they are subtle, will soon get angry." +``` + +如果你使用的是 `facebook/mbart-large-50-many-to-one-mmt` 检查点,则无需强制目标语言 id 作为第一个生成的令牌,否则用法是相同的。 From b0d1d7f71a985a4a57fb8e71e9d4c04e25529505 Mon Sep 17 00:00:00 2001 From: jiaqiw09 <60021713+jiaqiw09@users.noreply.github.com> Date: Tue, 24 Oct 2023 01:36:24 +0800 Subject: [PATCH 15/38] translate `preprocessing.md` to Chinese (#26955) * translate preprocessing.md to Chinese * update files fixing problems mentioned in review * update files fixing problems mentioned in review --------- Co-authored-by: jiaqiw --- docs/source/zh/_toctree.yml | 2 + docs/source/zh/preprocessing.md | 541 ++++++++++++++++++++++++++++++++ 2 files changed, 543 insertions(+) create mode 100644 docs/source/zh/preprocessing.md diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml index 0609a02ce73739..933f6f78b16ea0 100644 --- a/docs/source/zh/_toctree.yml +++ b/docs/source/zh/_toctree.yml @@ -9,6 +9,8 @@ - sections: - local: accelerate title: 加速分布式训练 + - local: preprocessing + title: 预处理 - local: pipeline_tutorial title: pipeline教程 title: 教程 diff --git a/docs/source/zh/preprocessing.md b/docs/source/zh/preprocessing.md new file mode 100644 index 00000000000000..95b799989c9197 --- /dev/null +++ b/docs/source/zh/preprocessing.md @@ -0,0 +1,541 @@ + + +# 预处理 + +[[open-in-colab]] + +在您可以在数据集上训练模型之前,数据需要被预处理为期望的模型输入格式。无论您的数据是文本、图像还是音频,它们都需要被转换并组合成批量的张量。🤗 Transformers 提供了一组预处理类来帮助准备数据以供模型使用。在本教程中,您将了解以下内容: + +* 对于文本,使用[分词器](./main_classes/tokenizer)(`Tokenizer`)将文本转换为一系列标记(`tokens`),并创建`tokens`的数字表示,将它们组合成张量。 +* 对于语音和音频,使用[特征提取器](./main_classes/feature_extractor)(`Feature extractor`)从音频波形中提取顺序特征并将其转换为张量。 +* 图像输入使用[图像处理器](./main_classes/image)(`ImageProcessor`)将图像转换为张量。 +* 多模态输入,使用[处理器](./main_classes/processors)(`Processor`)结合了`Tokenizer`和`ImageProcessor`或`Processor`。 + + + +`AutoProcessor` **始终**有效的自动选择适用于您使用的模型的正确`class`,无论您使用的是`Tokenizer`、`ImageProcessor`、`Feature extractor`还是`Processor`。 + + + +在开始之前,请安装🤗 Datasets,以便您可以加载一些数据集来进行实验: + + +```bash +pip install datasets +``` + +## 自然语言处理 + + + +处理文本数据的主要工具是[Tokenizer](main_classes/tokenizer)。`Tokenizer`根据一组规则将文本拆分为`tokens`。然后将这些`tokens`转换为数字,然后转换为张量,成为模型的输入。模型所需的任何附加输入都由`Tokenizer`添加。 + + + +如果您计划使用预训练模型,重要的是使用与之关联的预训练`Tokenizer`。这确保文本的拆分方式与预训练语料库相同,并在预训练期间使用相同的标记-索引的对应关系(通常称为*词汇表*-`vocab`)。 + + + +开始使用[`AutoTokenizer.from_pretrained`]方法加载一个预训练`tokenizer`。这将下载模型预训练的`vocab`: + + +```py +>>> from transformers import AutoTokenizer + +>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") +``` + +然后将您的文本传递给`tokenizer`: + + +```py +>>> encoded_input = tokenizer("Do not meddle in the affairs of wizards, for they are subtle and quick to anger.") +>>> print(encoded_input) +{'input_ids': [101, 2079, 2025, 19960, 10362, 1999, 1996, 3821, 1997, 16657, 1010, 2005, 2027, 2024, 11259, 1998, 4248, 2000, 4963, 1012, 102], + 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]} +``` + +`tokenizer`返回一个包含三个重要对象的字典: + +* [input_ids](glossary#input-ids) 是与句子中每个`token`对应的索引。 +* [attention_mask](glossary#attention-mask) 指示是否应该关注一个`toekn`。 +* [token_type_ids](glossary#token-type-ids) 在存在多个序列时标识一个`token`属于哪个序列。 + +通过解码 `input_ids` 来返回您的输入: + + +```py +>>> tokenizer.decode(encoded_input["input_ids"]) +'[CLS] Do not meddle in the affairs of wizards, for they are subtle and quick to anger. [SEP]' +``` + +如您所见,`tokenizer`向句子中添加了两个特殊`token` - `CLS` 和 `SEP`(分类器和分隔符)。并非所有模型都需要特殊`token`,但如果需要,`tokenizer`会自动为您添加。 + +如果有多个句子需要预处理,将它们作为列表传递给`tokenizer`: + + +```py +>>> batch_sentences = [ +... "But what about second breakfast?", +... "Don't think he knows about second breakfast, Pip.", +... "What about elevensies?", +... ] +>>> encoded_inputs = tokenizer(batch_sentences) +>>> print(encoded_inputs) +{'input_ids': [[101, 1252, 1184, 1164, 1248, 6462, 136, 102], + [101, 1790, 112, 189, 1341, 1119, 3520, 1164, 1248, 6462, 117, 21902, 1643, 119, 102], + [101, 1327, 1164, 5450, 23434, 136, 102]], + 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0]], + 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1]]} +``` + +### 填充 + +句子的长度并不总是相同,这可能会成为一个问题,因为模型输入的张量需要具有统一的形状。填充是一种策略,通过在较短的句子中添加一个特殊的`padding token`,以确保张量是矩形的。 + +将 `padding` 参数设置为 `True`,以使批次中较短的序列填充到与最长序列相匹配的长度: + +```py +>>> batch_sentences = [ +... "But what about second breakfast?", +... "Don't think he knows about second breakfast, Pip.", +... "What about elevensies?", +... ] +>>> encoded_input = tokenizer(batch_sentences, padding=True) +>>> print(encoded_input) +{'input_ids': [[101, 1252, 1184, 1164, 1248, 6462, 136, 102, 0, 0, 0, 0, 0, 0, 0], + [101, 1790, 112, 189, 1341, 1119, 3520, 1164, 1248, 6462, 117, 21902, 1643, 119, 102], + [101, 1327, 1164, 5450, 23434, 136, 102, 0, 0, 0, 0, 0, 0, 0, 0]], + 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]]} +``` + +第一句和第三句因为较短,通过`0`进行填充,。 + +### 截断 + +另一方面,有时候一个序列可能对模型来说太长了。在这种情况下,您需要将序列截断为更短的长度。 + +将 `truncation` 参数设置为 `True`,以将序列截断为模型接受的最大长度: + + +```py +>>> batch_sentences = [ +... "But what about second breakfast?", +... "Don't think he knows about second breakfast, Pip.", +... "What about elevensies?", +... ] +>>> encoded_input = tokenizer(batch_sentences, padding=True, truncation=True) +>>> print(encoded_input) +{'input_ids': [[101, 1252, 1184, 1164, 1248, 6462, 136, 102, 0, 0, 0, 0, 0, 0, 0], + [101, 1790, 112, 189, 1341, 1119, 3520, 1164, 1248, 6462, 117, 21902, 1643, 119, 102], + [101, 1327, 1164, 5450, 23434, 136, 102, 0, 0, 0, 0, 0, 0, 0, 0]], + 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]]} +``` + + + +查看[填充和截断](./pad_truncation)概念指南,了解更多有关填充和截断参数的信息。 + + + +### 构建张量 + +最后,`tokenizer`可以返回实际输入到模型的张量。 + +将 `return_tensors` 参数设置为 `pt`(对于PyTorch)或 `tf`(对于TensorFlow): + + + + + +```py +>>> batch_sentences = [ +... "But what about second breakfast?", +... "Don't think he knows about second breakfast, Pip.", +... "What about elevensies?", +... ] +>>> encoded_input = tokenizer(batch_sentences, padding=True, truncation=True, return_tensors="pt") +>>> print(encoded_input) +{'input_ids': tensor([[101, 1252, 1184, 1164, 1248, 6462, 136, 102, 0, 0, 0, 0, 0, 0, 0], + [101, 1790, 112, 189, 1341, 1119, 3520, 1164, 1248, 6462, 117, 21902, 1643, 119, 102], + [101, 1327, 1164, 5450, 23434, 136, 102, 0, 0, 0, 0, 0, 0, 0, 0]]), + 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), + 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]])} +``` + + + +```py +>>> batch_sentences = [ +... "But what about second breakfast?", +... "Don't think he knows about second breakfast, Pip.", +... "What about elevensies?", +... ] +>>> encoded_input = tokenizer(batch_sentences, padding=True, truncation=True, return_tensors="tf") +>>> print(encoded_input) +{'input_ids': , + 'token_type_ids': , + 'attention_mask': } +``` + + + +## 音频 + +对于音频任务,您需要[feature extractor](main_classes/feature_extractor)来准备您的数据集以供模型使用。`feature extractor`旨在从原始音频数据中提取特征,并将它们转换为张量。 + +加载[MInDS-14](https://huggingface.co/datasets/PolyAI/minds14)数据集(有关如何加载数据集的更多详细信息,请参阅🤗 [Datasets教程](https://huggingface.co/docs/datasets/load_hub.html))以了解如何在音频数据集中使用`feature extractor`: + + +```py +>>> from datasets import load_dataset, Audio + +>>> dataset = load_dataset("PolyAI/minds14", name="en-US", split="train") +``` + +访问 `audio` 列的第一个元素以查看输入。调用 `audio` 列会自动加载和重新采样音频文件: + +```py +>>> dataset[0]["audio"] +{'array': array([ 0. , 0.00024414, -0.00024414, ..., -0.00024414, + 0. , 0. ], dtype=float32), + 'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~JOINT_ACCOUNT/602ba55abb1e6d0fbce92065.wav', + 'sampling_rate': 8000} +``` + +这会返回三个对象: + +* `array` 是加载的语音信号 - 并在必要时重新采为`1D array`。 +* `path` 指向音频文件的位置。 +* `sampling_rate` 是每秒测量的语音信号数据点数量。 + +对于本教程,您将使用[Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base)模型。查看模型卡片,您将了解到Wav2Vec2是在16kHz采样的语音音频数据上预训练的。重要的是,您的音频数据的采样率要与用于预训练模型的数据集的采样率匹配。如果您的数据的采样率不同,那么您需要对数据进行重新采样。 + +1. 使用🤗 Datasets的[`~datasets.Dataset.cast_column`]方法将采样率提升到16kHz: + +```py +>>> dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000)) +``` + +2. 再次调用 `audio` 列以重新采样音频文件: + + +```py +>>> dataset[0]["audio"] +{'array': array([ 2.3443763e-05, 2.1729663e-04, 2.2145823e-04, ..., + 3.8356509e-05, -7.3497440e-06, -2.1754686e-05], dtype=float32), + 'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~JOINT_ACCOUNT/602ba55abb1e6d0fbce92065.wav', + 'sampling_rate': 16000} +``` + +接下来,加载一个`feature extractor`以对输入进行标准化和填充。当填充文本数据时,会为较短的序列添加 `0`。相同的理念适用于音频数据。`feature extractor`添加 `0` - 被解释为静音 - 到`array` 。 + +使用 [`AutoFeatureExtractor.from_pretrained`] 加载`feature extractor`: + + +```py +>>> from transformers import AutoFeatureExtractor + +>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base") +``` + +将音频 `array` 传递给`feature extractor`。我们还建议在`feature extractor`中添加 `sampling_rate` 参数,以更好地调试可能发生的静音错误: + + +```py +>>> audio_input = [dataset[0]["audio"]["array"]] +>>> feature_extractor(audio_input, sampling_rate=16000) +{'input_values': [array([ 3.8106556e-04, 2.7506407e-03, 2.8015103e-03, ..., + 5.6335266e-04, 4.6588284e-06, -1.7142107e-04], dtype=float32)]} +``` + +就像`tokenizer`一样,您可以应用填充或截断来处理批次中的可变序列。请查看这两个音频样本的序列长度: + + +```py +>>> dataset[0]["audio"]["array"].shape +(173398,) + +>>> dataset[1]["audio"]["array"].shape +(106496,) +``` + +创建一个函数来预处理数据集,以使音频样本具有相同的长度。通过指定最大样本长度,`feature extractor`将填充或截断序列以使其匹配: + + +```py +>>> def preprocess_function(examples): +... audio_arrays = [x["array"] for x in examples["audio"]] +... inputs = feature_extractor( +... audio_arrays, +... sampling_rate=16000, +... padding=True, +... max_length=100000, +... truncation=True, +... ) +... return inputs +``` + +将`preprocess_function`应用于数据集中的前几个示例: + + +```py +>>> processed_dataset = preprocess_function(dataset[:5]) +``` + +现在样本长度是相同的,并且与指定的最大长度匹配。您现在可以将经过处理的数据集传递给模型了! + + +```py +>>> processed_dataset["input_values"][0].shape +(100000,) + +>>> processed_dataset["input_values"][1].shape +(100000,) +``` + +## 计算机视觉 + +对于计算机视觉任务,您需要一个[ image processor](main_classes/image_processor)来准备数据集以供模型使用。图像预处理包括多个步骤将图像转换为模型期望输入的格式。这些步骤包括但不限于调整大小、标准化、颜色通道校正以及将图像转换为张量。 + + + +图像预处理通常遵循某种形式的图像增强。图像预处理和图像增强都会改变图像数据,但它们有不同的目的: + +* 图像增强可以帮助防止过拟合并增加模型的鲁棒性。您可以在数据增强方面充分发挥创造性 - 调整亮度和颜色、裁剪、旋转、调整大小、缩放等。但要注意不要改变图像的含义。 +* 图像预处理确保图像与模型预期的输入格式匹配。在微调计算机视觉模型时,必须对图像进行与模型训练时相同的预处理。 + +您可以使用任何您喜欢的图像增强库。对于图像预处理,请使用与模型相关联的`ImageProcessor`。 + + + +加载[food101](https://huggingface.co/datasets/food101)数据集(有关如何加载数据集的更多详细信息,请参阅🤗 [Datasets教程](https://huggingface.co/docs/datasets/load_hub.html))以了解如何在计算机视觉数据集中使用图像处理器: + + + +因为数据集相当大,请使用🤗 Datasets的`split`参数加载训练集中的少量样本! + + + + +```py +>>> from datasets import load_dataset + +>>> dataset = load_dataset("food101", split="train[:100]") +``` + +接下来,使用🤗 Datasets的[`Image`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=image#datasets.Image)功能查看图像: + + +```py +>>> dataset[0]["image"] +``` + +
+ +
+ +使用 [`AutoImageProcessor.from_pretrained`] 加载`image processor`: + +```py +>>> from transformers import AutoImageProcessor + +>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") +``` + +首先,让我们进行图像增强。您可以使用任何您喜欢的库,但在本教程中,我们将使用torchvision的[`transforms`](https://pytorch.org/vision/stable/transforms.html)模块。如果您有兴趣使用其他数据增强库,请参阅[Albumentations](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification_albumentations.ipynb)或[Kornia notebooks](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification_kornia.ipynb)中的示例。 + +1. 在这里,我们使用[`Compose`](https://pytorch.org/vision/master/generated/torchvision.transforms.Compose.html)将[`RandomResizedCrop`](https://pytorch.org/vision/main/generated/torchvision.transforms.RandomResizedCrop.html)和 [`ColorJitter`](https://pytorch.org/vision/main/generated/torchvision.transforms.ColorJitter.html)变换连接在一起。请注意,对于调整大小,我们可以从`image_processor`中获取图像尺寸要求。对于一些模型,精确的高度和宽度需要被定义,对于其他模型只需定义`shortest_edge`。 + + +```py +>>> from torchvision.transforms import RandomResizedCrop, ColorJitter, Compose + +>>> size = ( +... image_processor.size["shortest_edge"] +... if "shortest_edge" in image_processor.size +... else (image_processor.size["height"], image_processor.size["width"]) +... ) + +>>> _transforms = Compose([RandomResizedCrop(size), ColorJitter(brightness=0.5, hue=0.5)]) +``` + +2. 模型接受 [`pixel_values`](model_doc/visionencoderdecoder#transformers.VisionEncoderDecoderModel.forward.pixel_values) 作为输入。`ImageProcessor` 可以进行图像的标准化,并生成适当的张量。创建一个函数,将图像增强和图像预处理步骤组合起来处理批量图像,并生成 `pixel_values`: + + +```py +>>> def transforms(examples): +... images = [_transforms(img.convert("RGB")) for img in examples["image"]] +... examples["pixel_values"] = image_processor(images, do_resize=False, return_tensors="pt")["pixel_values"] +... return examples +``` + + + +在上面的示例中,我们设置`do_resize=False`,因为我们已经在图像增强转换中调整了图像的大小,并利用了适当的`image_processor`的`size`属性。如果您在图像增强期间不调整图像的大小,请将此参数排除在外。默认情况下`ImageProcessor`将处理调整大小。 + +如果希望将图像标准化步骤为图像增强的一部分,请使用`image_processor.image_mean`和`image_processor.image_std`。 + + + +3. 然后使用🤗 Datasets的[`set_transform`](https://huggingface.co/docs/datasets/process.html#format-transform)在运行时应用这些变换: + + +```py +>>> dataset.set_transform(transforms) +``` + +4. 现在,当您访问图像时,您将注意到`image processor`已添加了 `pixel_values`。您现在可以将经过处理的数据集传递给模型了! + + +```py +>>> dataset[0].keys() +``` + +这是在应用变换后的图像样子。图像已被随机裁剪,并其颜色属性发生了变化。 + + +```py +>>> import numpy as np +>>> import matplotlib.pyplot as plt + +>>> img = dataset[0]["pixel_values"] +>>> plt.imshow(img.permute(1, 2, 0)) +``` + +
+ +
+ + + +对于诸如目标检测、语义分割、实例分割和全景分割等任务,`ImageProcessor`提供了训练后处理方法。这些方法将模型的原始输出转换为有意义的预测,如边界框或分割地图。 + + + +### 填充 + +在某些情况下,例如,在微调[DETR](./model_doc/detr)时,模型在训练时应用了尺度增强。这可能导致批处理中的图像大小不同。您可以使用[`DetrImageProcessor.pad`]来指定自定义的`collate_fn`将图像批处理在一起。 + +```py +>>> def collate_fn(batch): +... pixel_values = [item["pixel_values"] for item in batch] +... encoding = image_processor.pad(pixel_values, return_tensors="pt") +... labels = [item["labels"] for item in batch] +... batch = {} +... batch["pixel_values"] = encoding["pixel_values"] +... batch["pixel_mask"] = encoding["pixel_mask"] +... batch["labels"] = labels +... return batch +``` + +## 多模态 + +对于涉及多模态输入的任务,您需要[processor](main_classes/processors)来为模型准备数据集。`processor`将两个处理对象-例如`tokenizer`和`feature extractor`-组合在一起。 + +加载[LJ Speech](https://huggingface.co/datasets/lj_speech)数据集(有关如何加载数据集的更多详细信息,请参阅🤗 [Datasets 教程](https://huggingface.co/docs/datasets/load_hub.html))以了解如何使用`processor`进行自动语音识别(ASR): + + +```py +>>> from datasets import load_dataset + +>>> lj_speech = load_dataset("lj_speech", split="train") +``` + +对于ASR(自动语音识别),主要关注`audio`和`text`,因此可以删除其他列: + + +```py +>>> lj_speech = lj_speech.map(remove_columns=["file", "id", "normalized_text"]) +``` + +现在查看`audio`和`text`列: + +```py +>>> lj_speech[0]["audio"] +{'array': array([-7.3242188e-04, -7.6293945e-04, -6.4086914e-04, ..., + 7.3242188e-04, 2.1362305e-04, 6.1035156e-05], dtype=float32), + 'path': '/root/.cache/huggingface/datasets/downloads/extracted/917ece08c95cf0c4115e45294e3cd0dee724a1165b7fc11798369308a465bd26/LJSpeech-1.1/wavs/LJ001-0001.wav', + 'sampling_rate': 22050} + +>>> lj_speech[0]["text"] +'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition' +``` + +请记住,您应始终[重新采样](preprocessing#audio)音频数据集的采样率,以匹配用于预训练模型数据集的采样率! + + +```py +>>> lj_speech = lj_speech.cast_column("audio", Audio(sampling_rate=16_000)) +``` + +使用[`AutoProcessor.from_pretrained`]加载一个`processor`: + + +```py +>>> from transformers import AutoProcessor + +>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h") +``` + +1. 创建一个函数,用于将包含在 `array` 中的音频数据处理为 `input_values`,并将 `text` 标记为 `labels`。这些将是输入模型的数据: + +```py +>>> def prepare_dataset(example): +... audio = example["audio"] + +... example.update(processor(audio=audio["array"], text=example["text"], sampling_rate=16000)) + +... return example +``` + +2. 将 `prepare_dataset` 函数应用于一个示例: + +```py +>>> prepare_dataset(lj_speech[0]) +``` + +`processor`现在已经添加了 `input_values` 和 `labels`,并且采样率也正确降低为为16kHz。现在可以将处理后的数据集传递给模型! From f370bebdc352cd7c1bea2f88ae0c140ab694c5fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pedro=20Gabriel=20Gengo=20Louren=C3=A7o?= Date: Mon, 23 Oct 2023 15:34:27 -0300 Subject: [PATCH 16/38] Bugfix device map detr model (#26849) * Fixed replace_batch_norm when on meta device * lint fix * Adding coauthor Co-authored-by: Pi Esposito * Removed tests * Remove unused deps * Try to fix copy issue * try fix copy one more time * Reverted import changes --------- Co-authored-by: Pi Esposito --- .../conditional_detr/modeling_conditional_detr.py | 10 ++++++---- .../deformable_detr/modeling_deformable_detr.py | 10 ++++++---- src/transformers/models/deta/modeling_deta.py | 12 +++++++----- src/transformers/models/detr/modeling_detr.py | 10 ++++++---- .../modeling_table_transformer.py | 14 ++++++++++---- 5 files changed, 35 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 15f24084f46995..69937afefce453 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -322,10 +322,11 @@ def replace_batch_norm(model): if isinstance(module, nn.BatchNorm2d): new_module = ConditionalDetrFrozenBatchNorm2d(module.num_features) - new_module.weight.data.copy_(module.weight) - new_module.bias.data.copy_(module.bias) - new_module.running_mean.data.copy_(module.running_mean) - new_module.running_var.data.copy_(module.running_var) + if not module.weight.device == torch.device("meta"): + new_module.weight.data.copy_(module.weight) + new_module.bias.data.copy_(module.bias) + new_module.running_mean.data.copy_(module.running_mean) + new_module.running_var.data.copy_(module.running_var) model._modules[name] = new_module @@ -1145,6 +1146,7 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel): config_class = ConditionalDetrConfig base_model_prefix = "model" main_input_name = "pixel_values" + _no_split_modules = [r"ConditionalDetrConvEncoder", r"ConditionalDetrEncoderLayer", r"ConditionalDetrDecoderLayer"] def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index f541ca130544dd..ea4555d5ae217b 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -369,10 +369,11 @@ def replace_batch_norm(model): if isinstance(module, nn.BatchNorm2d): new_module = DeformableDetrFrozenBatchNorm2d(module.num_features) - new_module.weight.data.copy_(module.weight) - new_module.bias.data.copy_(module.bias) - new_module.running_mean.data.copy_(module.running_mean) - new_module.running_var.data.copy_(module.running_var) + if not module.weight.device == torch.device("meta"): + new_module.weight.data.copy_(module.weight) + new_module.bias.data.copy_(module.bias) + new_module.running_mean.data.copy_(module.running_mean) + new_module.running_var.data.copy_(module.running_var) model._modules[name] = new_module @@ -1061,6 +1062,7 @@ class DeformableDetrPreTrainedModel(PreTrainedModel): config_class = DeformableDetrConfig base_model_prefix = "model" main_input_name = "pixel_values" + _no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"] def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 9cd29e94088730..1aab38c2891384 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -307,10 +307,11 @@ def replace_batch_norm(model): if isinstance(module, nn.BatchNorm2d): new_module = DetaFrozenBatchNorm2d(module.num_features) - new_module.weight.data.copy_(module.weight) - new_module.bias.data.copy_(module.bias) - new_module.running_mean.data.copy_(module.running_mean) - new_module.running_var.data.copy_(module.running_var) + if not module.weight.device == torch.device("meta"): + new_module.weight.data.copy_(module.weight) + new_module.bias.data.copy_(module.bias) + new_module.running_mean.data.copy_(module.running_mean) + new_module.running_var.data.copy_(module.running_var) model._modules[name] = new_module @@ -947,11 +948,12 @@ def forward(self, hidden_states: torch.Tensor): return hidden_states -# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrPreTrainedModel with DeformableDetr->Deta +# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrPreTrainedModel with DeformableDetrConvEncoder->DetaBackboneWithPositionalEncodings,DeformableDetr->Deta class DetaPreTrainedModel(PreTrainedModel): config_class = DetaConfig base_model_prefix = "model" main_input_name = "pixel_values" + _no_split_modules = [r"DetaBackboneWithPositionalEncodings", r"DetaEncoderLayer", r"DetaDecoderLayer"] def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 3dda00a20082cc..d2b6ea07d7b7ca 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -316,10 +316,11 @@ def replace_batch_norm(model): if isinstance(module, nn.BatchNorm2d): new_module = DetrFrozenBatchNorm2d(module.num_features) - new_module.weight.data.copy_(module.weight) - new_module.bias.data.copy_(module.bias) - new_module.running_mean.data.copy_(module.running_mean) - new_module.running_var.data.copy_(module.running_var) + if not module.weight.device == torch.device("meta"): + new_module.weight.data.copy_(module.weight) + new_module.bias.data.copy_(module.bias) + new_module.running_mean.data.copy_(module.running_mean) + new_module.running_var.data.copy_(module.running_var) model._modules[name] = new_module @@ -901,6 +902,7 @@ class DetrPreTrainedModel(PreTrainedModel): config_class = DetrConfig base_model_prefix = "model" main_input_name = "pixel_values" + _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"] def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index 8f59bd4b6e1785..b6012700ee7de0 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -251,10 +251,11 @@ def replace_batch_norm(model): if isinstance(module, nn.BatchNorm2d): new_module = TableTransformerFrozenBatchNorm2d(module.num_features) - new_module.weight.data.copy_(module.weight) - new_module.bias.data.copy_(module.bias) - new_module.running_mean.data.copy_(module.running_mean) - new_module.running_var.data.copy_(module.running_var) + if not module.weight.device == torch.device("meta"): + new_module.weight.data.copy_(module.weight) + new_module.bias.data.copy_(module.bias) + new_module.running_mean.data.copy_(module.running_mean) + new_module.running_var.data.copy_(module.running_var) model._modules[name] = new_module @@ -813,6 +814,11 @@ class TableTransformerPreTrainedModel(PreTrainedModel): config_class = TableTransformerConfig base_model_prefix = "model" main_input_name = "pixel_values" + _no_split_modules = [ + r"TableTransformerConvEncoder", + r"TableTransformerEncoderLayer", + r"TableTransformerDecoderLayer", + ] def _init_weights(self, module): std = self.config.init_std From 25c022d7c5c89860b86def2beb141012e597d86c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mert=20Yan=C4=B1k?= Date: Tue, 24 Oct 2023 01:36:42 +0300 Subject: [PATCH 17/38] Fix little typo (#27028) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 758675e29716b1..20813d78d3428a 100644 --- a/README.md +++ b/README.md @@ -498,7 +498,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[ViT Hybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 1. **[VitDet](https://huggingface.co/docs/transformers/model_doc/vitdet)** (from Meta AI) released with the paper [Exploring Plain Vision Transformer Backbones for Object Detection](https://arxiv.org/abs/2203.16527) by Yanghao Li, Hanzi Mao, Ross Girshick, Kaiming He. 1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick. -1. **[ViTMatte](https://huggingface.co/docs/transformers/model_doc/vitmatte)** (from HUST-VL) rreleased with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang. +1. **[ViTMatte](https://huggingface.co/docs/transformers/model_doc/vitmatte)** (from HUST-VL) released with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang. 1. **[ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn)** (from Meta AI) released with the paper [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas. 1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son. 1. **[ViViT](https://huggingface.co/docs/transformers/model_doc/vivit)** (from Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid. From 32f799db0d625ec5cf82624ff2604c5a891ebf61 Mon Sep 17 00:00:00 2001 From: Yeyang <76979429+yyLeaves@users.noreply.github.com> Date: Mon, 23 Oct 2023 22:44:42 +0000 Subject: [PATCH 18/38] =?UTF-8?q?=F0=9F=8C=90=20[i18n-ZH]=20Translate=20cr?= =?UTF-8?q?eate=5Fa=5Fmodel.md=20into=20Chinese=20(#27026)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit docs(zh): translate create_a_model.md --- docs/source/zh/_toctree.yml | 2 + docs/source/zh/create_a_model.md | 389 +++++++++++++++++++++++++++++++ 2 files changed, 391 insertions(+) create mode 100644 docs/source/zh/create_a_model.md diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml index 933f6f78b16ea0..0f8bc9f3e86392 100644 --- a/docs/source/zh/_toctree.yml +++ b/docs/source/zh/_toctree.yml @@ -19,4 +19,6 @@ title: 使用 🤗 Tokenizers 中的分词器 - local: multilingual title: 使用多语言模型进行推理 + - local: create_a_model + title: 使用特定于模型的 API title: 开发者指南 diff --git a/docs/source/zh/create_a_model.md b/docs/source/zh/create_a_model.md new file mode 100644 index 00000000000000..b934708074d30d --- /dev/null +++ b/docs/source/zh/create_a_model.md @@ -0,0 +1,389 @@ + + +# 创建自定义架构 + +[`AutoClass`](model_doc/auto) 自动推断模型架构并下载预训练的配置和权重。一般来说,我们建议使用 `AutoClass` 生成与检查点(checkpoint)无关的代码。希望对特定模型参数有更多控制的用户,可以仅从几个基类创建自定义的 🤗 Transformers 模型。这对于任何有兴趣学习、训练或试验 🤗 Transformers 模型的人可能特别有用。通过本指南,深入了解如何不通过 `AutoClass` 创建自定义模型。了解如何: + +- 加载并自定义模型配置。 +- 创建模型架构。 +- 为文本创建慢速和快速分词器。 +- 为视觉任务创建图像处理器。 +- 为音频任务创建特征提取器。 +- 为多模态任务创建处理器。 + +## 配置 + +[配置](main_classes/configuration) 涉及到模型的具体属性。每个模型配置都有不同的属性;例如,所有 NLP 模型都共享 `hidden_size`、`num_attention_heads`、 `num_hidden_layers` 和 `vocab_size` 属性。这些属性用于指定构建模型时的注意力头数量或隐藏层层数。 + +访问 [`DistilBertConfig`] 以更近一步了解 [DistilBERT](model_doc/distilbert),检查它的属性: + +```py +>>> from transformers import DistilBertConfig + +>>> config = DistilBertConfig() +>>> print(config) +DistilBertConfig { + "activation": "gelu", + "attention_dropout": 0.1, + "dim": 768, + "dropout": 0.1, + "hidden_dim": 3072, + "initializer_range": 0.02, + "max_position_embeddings": 512, + "model_type": "distilbert", + "n_heads": 12, + "n_layers": 6, + "pad_token_id": 0, + "qa_dropout": 0.1, + "seq_classif_dropout": 0.2, + "sinusoidal_pos_embds": false, + "transformers_version": "4.16.2", + "vocab_size": 30522 +} +``` + +[`DistilBertConfig`] 显示了构建基础 [`DistilBertModel`] 所使用的所有默认属性。所有属性都可以进行自定义,为实验创造了空间。例如,您可以将默认模型自定义为: + +- 使用 `activation` 参数尝试不同的激活函数。 +- 使用 `attention_dropout` 参数为 attention probabilities 使用更高的 dropout ratio。 + +```py +>>> my_config = DistilBertConfig(activation="relu", attention_dropout=0.4) +>>> print(my_config) +DistilBertConfig { + "activation": "relu", + "attention_dropout": 0.4, + "dim": 768, + "dropout": 0.1, + "hidden_dim": 3072, + "initializer_range": 0.02, + "max_position_embeddings": 512, + "model_type": "distilbert", + "n_heads": 12, + "n_layers": 6, + "pad_token_id": 0, + "qa_dropout": 0.1, + "seq_classif_dropout": 0.2, + "sinusoidal_pos_embds": false, + "transformers_version": "4.16.2", + "vocab_size": 30522 +} +``` + +预训练模型的属性可以在 [`~PretrainedConfig.from_pretrained`] 函数中进行修改: + +```py +>>> my_config = DistilBertConfig.from_pretrained("distilbert-base-uncased", activation="relu", attention_dropout=0.4) +``` + +当你对模型配置满意时,可以使用 [`~PretrainedConfig.save_pretrained`] 来保存配置。你的配置文件将以 JSON 文件的形式存储在指定的保存目录中: + +```py +>>> my_config.save_pretrained(save_directory="./your_model_save_path") +``` + +要重用配置文件,请使用 [`~PretrainedConfig.from_pretrained`] 进行加载: + +```py +>>> my_config = DistilBertConfig.from_pretrained("./your_model_save_path/config.json") +``` + + + +你还可以将配置文件保存为字典,甚至只保存自定义配置属性与默认配置属性之间的差异!有关更多详细信息,请参阅 [配置](main_classes/configuration) 文档。 + + + +## 模型 + +接下来,创建一个[模型](main_classes/models)。模型,也可泛指架构,定义了每一层网络的行为以及进行的操作。配置中的 `num_hidden_layers` 等属性用于定义架构。每个模型都共享基类 [`PreTrainedModel`] 和一些常用方法,例如调整输入嵌入的大小和修剪自注意力头。此外,所有模型都是 [`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)、[`tf.keras.Model`](https://www.tensorflow.org/api_docs/python/tf/keras/Model) 或 [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module) 的子类。这意味着模型与各自框架的用法兼容。 + + + +将自定义配置属性加载到模型中: + +```py +>>> from transformers import DistilBertModel + +>>> my_config = DistilBertConfig.from_pretrained("./your_model_save_path/config.json") +>>> model = DistilBertModel(my_config) +``` + +这段代码创建了一个具有随机参数而不是预训练权重的模型。在训练该模型之前,您还无法将该模型用于任何用途。训练是一项昂贵且耗时的过程。通常来说,最好使用预训练模型来更快地获得更好的结果,同时仅使用训练所需资源的一小部分。 + +使用 [`~PreTrainedModel.from_pretrained`] 创建预训练模型: + +```py +>>> model = DistilBertModel.from_pretrained("distilbert-base-uncased") +``` + +当加载预训练权重时,如果模型是由 🤗 Transformers 提供的,将自动加载默认模型配置。然而,如果你愿意,仍然可以将默认模型配置的某些或者所有属性替换成你自己的配置: + +```py +>>> model = DistilBertModel.from_pretrained("distilbert-base-uncased", config=my_config) +``` + + +将自定义配置属性加载到模型中: + +```py +>>> from transformers import TFDistilBertModel + +>>> my_config = DistilBertConfig.from_pretrained("./your_model_save_path/my_config.json") +>>> tf_model = TFDistilBertModel(my_config) +``` + +这段代码创建了一个具有随机参数而不是预训练权重的模型。在训练该模型之前,您还无法将该模型用于任何用途。训练是一项昂贵且耗时的过程。通常来说,最好使用预训练模型来更快地获得更好的结果,同时仅使用训练所需资源的一小部分。 + +使用 [`~TFPreTrainedModel.from_pretrained`] 创建预训练模型: + +```py +>>> tf_model = TFDistilBertModel.from_pretrained("distilbert-base-uncased") +``` + +当加载预训练权重时,如果模型是由 🤗 Transformers 提供的,将自动加载默认模型配置。然而,如果你愿意,仍然可以将默认模型配置的某些或者所有属性替换成自己的配置: + +```py +>>> tf_model = TFDistilBertModel.from_pretrained("distilbert-base-uncased", config=my_config) +``` + + + +### 模型头(Model heads) + +此时,你已经有了一个输出*隐藏状态*的基础 DistilBERT 模型。隐藏状态作为输入传递到模型头以生成最终输出。🤗 Transformers 为每个任务提供不同的模型头,只要模型支持该任务(即,您不能使用 DistilBERT 来执行像翻译这样的序列到序列任务)。 + + + +例如,[`DistilBertForSequenceClassification`] 是一个带有序列分类头(sequence classification head)的基础 DistilBERT 模型。序列分类头是池化输出之上的线性层。 + +```py +>>> from transformers import DistilBertForSequenceClassification + +>>> model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased") +``` + +通过切换到不同的模型头,可以轻松地将此检查点重复用于其他任务。对于问答任务,你可以使用 [`DistilBertForQuestionAnswering`] 模型头。问答头(question answering head)与序列分类头类似,不同点在于它是隐藏状态输出之上的线性层。 + +```py +>>> from transformers import DistilBertForQuestionAnswering + +>>> model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased") +``` + + +例如,[`TFDistilBertForSequenceClassification`] 是一个带有序列分类头(sequence classification head)的基础 DistilBERT 模型。序列分类头是池化输出之上的线性层。 + +```py +>>> from transformers import TFDistilBertForSequenceClassification + +>>> tf_model = TFDistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased") +``` + +通过切换到不同的模型头,可以轻松地将此检查点重复用于其他任务。对于问答任务,你可以使用 [`TFDistilBertForQuestionAnswering`] 模型头。问答头(question answering head)与序列分类头类似,不同点在于它是隐藏状态输出之上的线性层。 + +```py +>>> from transformers import TFDistilBertForQuestionAnswering + +>>> tf_model = TFDistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased") +``` + + + +## 分词器 + +在将模型用于文本数据之前,你需要的最后一个基类是 [tokenizer](main_classes/tokenizer),它用于将原始文本转换为张量。🤗 Transformers 支持两种类型的分词器: + +- [`PreTrainedTokenizer`]:分词器的Python实现 +- [`PreTrainedTokenizerFast`]:来自我们基于 Rust 的 [🤗 Tokenizer](https://huggingface.co/docs/tokenizers/python/latest/) 库的分词器。因为其使用了 Rust 实现,这种分词器类型的速度要快得多,尤其是在批量分词(batch tokenization)的时候。快速分词器还提供其他的方法,例如*偏移映射(offset mapping)*,它将标记(token)映射到其原始单词或字符。 + +这两种分词器都支持常用的方法,如编码和解码、添加新标记以及管理特殊标记。 + + + +并非每个模型都支持快速分词器。参照这张 [表格](index#supported-frameworks) 查看模型是否支持快速分词器。 + + + +如果您训练了自己的分词器,则可以从*词表*文件创建一个分词器: + +```py +>>> from transformers import DistilBertTokenizer + +>>> my_tokenizer = DistilBertTokenizer(vocab_file="my_vocab_file.txt", do_lower_case=False, padding_side="left") +``` + +请务必记住,自定义分词器生成的词表与预训练模型分词器生成的词表是不同的。如果使用预训练模型,则需要使用预训练模型的词表,否则输入将没有意义。 使用 [`DistilBertTokenizer`] 类创建具有预训练模型词表的分词器: + +```py +>>> from transformers import DistilBertTokenizer + +>>> slow_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") +``` + +使用 [`DistilBertTokenizerFast`] 类创建快速分词器: + +```py +>>> from transformers import DistilBertTokenizerFast + +>>> fast_tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased") +``` + + + +默认情况下,[`AutoTokenizer`] 将尝试加载快速标记生成器。你可以通过在 `from_pretrained` 中设置 `use_fast=False` 以禁用此行为。 + + + +## 图像处理器 + +图像处理器用于处理视觉输入。它继承自 [`~image_processing_utils.ImageProcessingMixin`] 基类。 + +要使用它,需要创建一个与你使用的模型关联的图像处理器。例如,如果你使用 [ViT](model_doc/vit) 进行图像分类,可以创建一个默认的 [`ViTImageProcessor`]: + +```py +>>> from transformers import ViTImageProcessor + +>>> vit_extractor = ViTImageProcessor() +>>> print(vit_extractor) +ViTImageProcessor { + "do_normalize": true, + "do_resize": true, + "image_processor_type": "ViTImageProcessor", + "image_mean": [ + 0.5, + 0.5, + 0.5 + ], + "image_std": [ + 0.5, + 0.5, + 0.5 + ], + "resample": 2, + "size": 224 +} +``` + + + +如果您不需要进行任何自定义,只需使用 `from_pretrained` 方法加载模型的默认图像处理器参数。 + + + +修改任何 [`ViTImageProcessor`] 参数以创建自定义图像处理器: + +```py +>>> from transformers import ViTImageProcessor + +>>> my_vit_extractor = ViTImageProcessor(resample="PIL.Image.BOX", do_normalize=False, image_mean=[0.3, 0.3, 0.3]) +>>> print(my_vit_extractor) +ViTImageProcessor { + "do_normalize": false, + "do_resize": true, + "image_processor_type": "ViTImageProcessor", + "image_mean": [ + 0.3, + 0.3, + 0.3 + ], + "image_std": [ + 0.5, + 0.5, + 0.5 + ], + "resample": "PIL.Image.BOX", + "size": 224 +} +``` + +## 特征提取器 + +特征提取器用于处理音频输入。它继承自 [`~feature_extraction_utils.FeatureExtractionMixin`] 基类,亦可继承 [`SequenceFeatureExtractor`] 类来处理音频输入。 + +要使用它,创建一个与你使用的模型关联的特征提取器。例如,如果你使用 [Wav2Vec2](model_doc/wav2vec2) 进行音频分类,可以创建一个默认的 [`Wav2Vec2FeatureExtractor`]: + +```py +>>> from transformers import Wav2Vec2FeatureExtractor + +>>> w2v2_extractor = Wav2Vec2FeatureExtractor() +>>> print(w2v2_extractor) +Wav2Vec2FeatureExtractor { + "do_normalize": true, + "feature_extractor_type": "Wav2Vec2FeatureExtractor", + "feature_size": 1, + "padding_side": "right", + "padding_value": 0.0, + "return_attention_mask": false, + "sampling_rate": 16000 +} +``` + + + +如果您不需要进行任何自定义,只需使用 `from_pretrained` 方法加载模型的默认特征提取器参数。 + + + +修改任何 [`Wav2Vec2FeatureExtractor`] 参数以创建自定义特征提取器: + +```py +>>> from transformers import Wav2Vec2FeatureExtractor + +>>> w2v2_extractor = Wav2Vec2FeatureExtractor(sampling_rate=8000, do_normalize=False) +>>> print(w2v2_extractor) +Wav2Vec2FeatureExtractor { + "do_normalize": false, + "feature_extractor_type": "Wav2Vec2FeatureExtractor", + "feature_size": 1, + "padding_side": "right", + "padding_value": 0.0, + "return_attention_mask": false, + "sampling_rate": 8000 +} +``` + + +## 处理器 + +对于支持多模式任务的模型,🤗 Transformers 提供了一个处理器类,可以方便地将特征提取器和分词器等处理类包装到单个对象中。例如,让我们使用 [`Wav2Vec2Processor`] 来执行自动语音识别任务 (ASR)。 ASR 将音频转录为文本,因此您将需要一个特征提取器和一个分词器。 + +创建一个特征提取器来处理音频输入: + +```py +>>> from transformers import Wav2Vec2FeatureExtractor + +>>> feature_extractor = Wav2Vec2FeatureExtractor(padding_value=1.0, do_normalize=True) +``` + +创建一个分词器来处理文本输入: + +```py +>>> from transformers import Wav2Vec2CTCTokenizer + +>>> tokenizer = Wav2Vec2CTCTokenizer(vocab_file="my_vocab_file.txt") +``` + +将特征提取器和分词器合并到 [`Wav2Vec2Processor`] 中: + +```py +>>> from transformers import Wav2Vec2Processor + +>>> processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) +``` + +通过两个基类 - 配置类和模型类 - 以及一个附加的预处理类(分词器、图像处理器、特征提取器或处理器),你可以创建 🤗 Transformers 支持的任何模型。 每个基类都是可配置的,允许你使用所需的特定属性。 你可以轻松设置模型进行训练或修改现有的预训练模型进行微调。 From ede051f1b85a33d2e0576b48042a58dc5332ed70 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 24 Oct 2023 09:55:14 +0200 Subject: [PATCH 19/38] Fix key dtype in GPTJ and CodeGen (#26836) * fix key dtype in gptj and codegen * delay the key cast to a later point * fix --- src/transformers/models/codegen/modeling_codegen.py | 4 +++- src/transformers/models/gptj/modeling_gptj.py | 4 +++- src/transformers/models/mistral/modeling_mistral.py | 1 + 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 172a45544bac0d..05699ef15c722f 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -224,7 +224,9 @@ def forward( value = torch.cat((past_value, value), dim=-2) if use_cache is True: - present = (key, value) + # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32. + # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38 + present = (key.to(hidden_states.dtype), value) else: present = None diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 6b5607f235b1a6..acdbb8c49211d5 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -249,7 +249,9 @@ def forward( value = torch.cat((past_value, value), dim=-2) if use_cache is True: - present = (key, value) + # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation. + # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128 + present = (key.to(hidden_states.dtype), value) else: present = None diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index f735c471268136..53667b6a82c327 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -306,6 +306,7 @@ def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) +# Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, From cc7803c0a6194ce795ab903979dde9216c82f5bc Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 24 Oct 2023 17:02:40 +0800 Subject: [PATCH 20/38] Register ModelOutput as supported torch pytree nodes (#26618) * Register ModelOutput as supported torch pytree nodes * Test ModelOutput as supported torch pytree nodes * Update type hints for pytree unflatten functions --- src/transformers/utils/generic.py | 27 +++++++++++++++++++++------ tests/utils/test_model_output.py | 16 +++++++++------- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index dc9ca4b51d0f18..34dac8bea70cfc 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -22,7 +22,7 @@ from contextlib import ExitStack, contextmanager from dataclasses import fields, is_dataclass from enum import Enum -from typing import Any, ContextManager, List, Tuple +from typing import Any, ContextManager, Iterable, List, Tuple import numpy as np @@ -306,12 +306,10 @@ def __init_subclass__(cls) -> None: `static_graph=True` with modules that output `ModelOutput` subclasses. """ if is_torch_available(): - import torch.utils._pytree - - torch.utils._pytree._register_pytree_node( + _torch_pytree._register_pytree_node( cls, - torch.utils._pytree._dict_flatten, - lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), + _model_output_flatten, + _model_output_unflatten, ) def __init__(self, *args, **kwargs): @@ -430,6 +428,23 @@ def to_tuple(self) -> Tuple[Any]: return tuple(self[k] for k in self.keys()) +if is_torch_available(): + import torch.utils._pytree as _torch_pytree + + def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]: + return list(output.values()), (type(output), list(output.keys())) + + def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Context") -> ModelOutput: + output_type, keys = context + return output_type(**dict(zip(keys, values))) + + _torch_pytree._register_pytree_node( + ModelOutput, + _model_output_flatten, + _model_output_unflatten, + ) + + class ExplicitEnum(str, Enum): """ Enum with more explicit error message for missing values. diff --git a/tests/utils/test_model_output.py b/tests/utils/test_model_output.py index abfc5427cf11ae..cabdc5fc2d6b0f 100644 --- a/tests/utils/test_model_output.py +++ b/tests/utils/test_model_output.py @@ -126,22 +126,24 @@ def test_instantiate_from_iterator(self): def test_torch_pytree(self): # ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves) # this is important for DistributedDataParallel gradient synchronization with static_graph=True - import torch - import torch.utils._pytree + import torch.utils._pytree as pytree + + x = ModelOutput({"a": 1.0, "c": 2.0}) + self.assertFalse(pytree._is_leaf(x)) x = ModelOutputTest(a=1.0, c=2.0) - self.assertFalse(torch.utils._pytree._is_leaf(x)) + self.assertFalse(pytree._is_leaf(x)) expected_flat_outs = [1.0, 2.0] - expected_tree_spec = torch.utils._pytree.TreeSpec( - ModelOutputTest, ["a", "c"], [torch.utils._pytree.LeafSpec(), torch.utils._pytree.LeafSpec()] + expected_tree_spec = pytree.TreeSpec( + ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()] ) - actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x) + actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x) self.assertEqual(expected_flat_outs, actual_flat_outs) self.assertEqual(expected_tree_spec, actual_tree_spec) - unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec) + unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec) self.assertEqual(x, unflattened_x) From fc142bd775ae4639f80a8b0085a5df33bd2853ce Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 24 Oct 2023 11:08:17 +0200 Subject: [PATCH 21/38] Add `default_to_square_for_size` to `CLIPImageProcessor` (#26965) * fix * fix * fix * fix * fix --------- Co-authored-by: ydshieh --- .../models/bit/image_processing_bit.py | 17 +++++++++++++---- .../models/clip/image_processing_clip.py | 17 +++++++++++++---- .../image_processing_mobilenet_v1.py | 17 +++++++++++++---- .../image_processing_mobilenet_v2.py | 17 +++++++++++++---- .../mobilevit/image_processing_mobilevit.py | 17 +++++++++++++---- .../vit_hybrid/image_processing_vit_hybrid.py | 17 +++++++++++++---- 6 files changed, 78 insertions(+), 24 deletions(-) diff --git a/src/transformers/models/bit/image_processing_bit.py b/src/transformers/models/bit/image_processing_bit.py index 1bd9d1095d7528..235f55afad7884 100644 --- a/src/transformers/models/bit/image_processing_bit.py +++ b/src/transformers/models/bit/image_processing_bit.py @@ -84,6 +84,10 @@ class BitImageProcessor(BaseImageProcessor): Can be overridden by the `image_std` parameter in the `preprocess` method. do_convert_rgb (`bool`, *optional*, defaults to `True`): Whether to convert the image to RGB. + use_square_size (`bool`, *optional*, defaults to `False`): + The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the + `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not. + Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`. """ model_input_names = ["pixel_values"] @@ -101,11 +105,12 @@ def __init__( image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, do_convert_rgb: bool = True, + use_square_size: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) size = size if size is not None else {"shortest_edge": 224} - size = get_size_dict(size, default_to_square=False) + size = get_size_dict(size, default_to_square=use_square_size) crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") @@ -120,6 +125,7 @@ def __init__( self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD self.do_convert_rgb = do_convert_rgb + self.use_square_size = use_square_size # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize def resize( @@ -147,11 +153,14 @@ def resize( input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format of the input image. If not provided, it will be inferred. """ - size = get_size_dict(size, default_to_square=False) + size = get_size_dict(size, default_to_square=self.use_square_size) if "shortest_edge" not in size: raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}") output_size = get_resize_output_image_size( - image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format + image, + size=size["shortest_edge"], + default_to_square=self.use_square_size, + input_data_format=input_data_format, ) return resize( image, @@ -234,7 +243,7 @@ def preprocess( """ do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size - size = get_size_dict(size, param_name="size", default_to_square=False) + size = get_size_dict(size, param_name="size", default_to_square=self.use_square_size) resample = resample if resample is not None else self.resample do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop crop_size = crop_size if crop_size is not None else self.crop_size diff --git a/src/transformers/models/clip/image_processing_clip.py b/src/transformers/models/clip/image_processing_clip.py index 811d69afce7072..df9628a66280b7 100644 --- a/src/transformers/models/clip/image_processing_clip.py +++ b/src/transformers/models/clip/image_processing_clip.py @@ -84,6 +84,10 @@ class CLIPImageProcessor(BaseImageProcessor): Can be overridden by the `image_std` parameter in the `preprocess` method. do_convert_rgb (`bool`, *optional*, defaults to `True`): Whether to convert the image to RGB. + use_square_size (`bool`, *optional*, defaults to `False`): + The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the + `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not. + Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`. """ model_input_names = ["pixel_values"] @@ -101,11 +105,12 @@ def __init__( image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, do_convert_rgb: bool = True, + use_square_size: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) size = size if size is not None else {"shortest_edge": 224} - size = get_size_dict(size, default_to_square=False) + size = get_size_dict(size, default_to_square=use_square_size) crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") @@ -120,6 +125,7 @@ def __init__( self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD self.do_convert_rgb = do_convert_rgb + self.use_square_size = use_square_size def resize( self, @@ -146,11 +152,14 @@ def resize( input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format of the input image. If not provided, it will be inferred. """ - size = get_size_dict(size, default_to_square=False) + size = get_size_dict(size, default_to_square=self.use_square_size) if "shortest_edge" not in size: raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}") output_size = get_resize_output_image_size( - image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format + image, + size=size["shortest_edge"], + default_to_square=self.use_square_size, + input_data_format=input_data_format, ) return resize( image, @@ -233,7 +242,7 @@ def preprocess( """ do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size - size = get_size_dict(size, param_name="size", default_to_square=False) + size = get_size_dict(size, param_name="size", default_to_square=self.use_square_size) resample = resample if resample is not None else self.resample do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop crop_size = crop_size if crop_size is not None else self.crop_size diff --git a/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py b/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py index c9b015c5c01fb7..824e4be5e82e7a 100644 --- a/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +++ b/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py @@ -79,6 +79,10 @@ class MobileNetV1ImageProcessor(BaseImageProcessor): image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): Standard deviation to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + use_square_size (`bool`, *optional*, defaults to `False`): + The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the + `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not. + Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`. """ model_input_names = ["pixel_values"] @@ -95,11 +99,12 @@ def __init__( do_normalize: bool = True, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, + use_square_size: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) size = size if size is not None else {"shortest_edge": 256} - size = get_size_dict(size, default_to_square=False) + size = get_size_dict(size, default_to_square=use_square_size) crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} crop_size = get_size_dict(crop_size) self.do_resize = do_resize @@ -112,6 +117,7 @@ def __init__( self.do_normalize = do_normalize self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + self.use_square_size = use_square_size # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize def resize( @@ -139,11 +145,14 @@ def resize( input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format of the input image. If not provided, it will be inferred. """ - size = get_size_dict(size, default_to_square=False) + size = get_size_dict(size, default_to_square=self.use_square_size) if "shortest_edge" not in size: raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}") output_size = get_resize_output_image_size( - image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format + image, + size=size["shortest_edge"], + default_to_square=self.use_square_size, + input_data_format=input_data_format, ) return resize( image, @@ -222,7 +231,7 @@ def preprocess( """ do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size - size = get_size_dict(size, default_to_square=False) + size = get_size_dict(size, default_to_square=self.use_square_size) resample = resample if resample is not None else self.resample do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop crop_size = crop_size if crop_size is not None else self.crop_size diff --git a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py index 9b015c88bf1d69..8f19869e47f9ad 100644 --- a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +++ b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py @@ -83,6 +83,10 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): Standard deviation to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + use_square_size (`bool`, *optional*, defaults to `False`): + The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the + `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not. + Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`. """ model_input_names = ["pixel_values"] @@ -99,11 +103,12 @@ def __init__( do_normalize: bool = True, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, + use_square_size: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) size = size if size is not None else {"shortest_edge": 256} - size = get_size_dict(size, default_to_square=False) + size = get_size_dict(size, default_to_square=use_square_size) crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} crop_size = get_size_dict(crop_size, param_name="crop_size") self.do_resize = do_resize @@ -116,6 +121,7 @@ def __init__( self.do_normalize = do_normalize self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + self.use_square_size = use_square_size # Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize def resize( @@ -143,11 +149,14 @@ def resize( input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format of the input image. If not provided, it will be inferred. """ - size = get_size_dict(size, default_to_square=False) + size = get_size_dict(size, default_to_square=self.use_square_size) if "shortest_edge" not in size: raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}") output_size = get_resize_output_image_size( - image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format + image, + size=size["shortest_edge"], + default_to_square=self.use_square_size, + input_data_format=input_data_format, ) return resize( image, @@ -226,7 +235,7 @@ def preprocess( """ do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size - size = get_size_dict(size, default_to_square=False) + size = get_size_dict(size, default_to_square=self.use_square_size) resample = resample if resample is not None else self.resample do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop crop_size = crop_size if crop_size is not None else self.crop_size diff --git a/src/transformers/models/mobilevit/image_processing_mobilevit.py b/src/transformers/models/mobilevit/image_processing_mobilevit.py index 0f3a422b30a07f..ecd8cac9e9506e 100644 --- a/src/transformers/models/mobilevit/image_processing_mobilevit.py +++ b/src/transformers/models/mobilevit/image_processing_mobilevit.py @@ -78,6 +78,10 @@ class MobileViTImageProcessor(BaseImageProcessor): do_flip_channel_order (`bool`, *optional*, defaults to `True`): Whether to flip the color channels from RGB to BGR. Can be overridden by the `do_flip_channel_order` parameter in the `preprocess` method. + use_square_size (`bool`, *optional*, defaults to `False`): + The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the + `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not. + Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`. """ model_input_names = ["pixel_values"] @@ -92,11 +96,12 @@ def __init__( do_center_crop: bool = True, crop_size: Dict[str, int] = None, do_flip_channel_order: bool = True, + use_square_size: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) size = size if size is not None else {"shortest_edge": 224} - size = get_size_dict(size, default_to_square=False) + size = get_size_dict(size, default_to_square=use_square_size) crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256} crop_size = get_size_dict(crop_size, param_name="crop_size") @@ -108,6 +113,7 @@ def __init__( self.do_center_crop = do_center_crop self.crop_size = crop_size self.do_flip_channel_order = do_flip_channel_order + self.use_square_size = use_square_size # Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize with PILImageResampling.BICUBIC->PILImageResampling.BILINEAR def resize( @@ -135,11 +141,14 @@ def resize( input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format of the input image. If not provided, it will be inferred. """ - size = get_size_dict(size, default_to_square=False) + size = get_size_dict(size, default_to_square=self.use_square_size) if "shortest_edge" not in size: raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}") output_size = get_resize_output_image_size( - image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format + image, + size=size["shortest_edge"], + default_to_square=self.use_square_size, + input_data_format=input_data_format, ) return resize( image, @@ -237,7 +246,7 @@ def preprocess( ) size = size if size is not None else self.size - size = get_size_dict(size, default_to_square=False) + size = get_size_dict(size, default_to_square=self.use_square_size) crop_size = crop_size if crop_size is not None else self.crop_size crop_size = get_size_dict(crop_size, param_name="crop_size") diff --git a/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py b/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py index ce6e3ffafe889b..81a07a9d79575f 100644 --- a/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py @@ -84,6 +84,10 @@ class ViTHybridImageProcessor(BaseImageProcessor): Can be overridden by the `image_std` parameter in the `preprocess` method. do_convert_rgb (`bool`, *optional*, defaults to `True`): Whether to convert the image to RGB. + use_square_size (`bool`, *optional*, defaults to `False`): + The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the + `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not. + Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`. """ model_input_names = ["pixel_values"] @@ -101,11 +105,12 @@ def __init__( image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, do_convert_rgb: bool = True, + use_square_size: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) size = size if size is not None else {"shortest_edge": 224} - size = get_size_dict(size, default_to_square=False) + size = get_size_dict(size, default_to_square=use_square_size) crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") @@ -120,6 +125,7 @@ def __init__( self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD self.do_convert_rgb = do_convert_rgb + self.use_square_size = use_square_size # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize def resize( @@ -147,11 +153,14 @@ def resize( input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format of the input image. If not provided, it will be inferred. """ - size = get_size_dict(size, default_to_square=False) + size = get_size_dict(size, default_to_square=self.use_square_size) if "shortest_edge" not in size: raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}") output_size = get_resize_output_image_size( - image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format + image, + size=size["shortest_edge"], + default_to_square=self.use_square_size, + input_data_format=input_data_format, ) return resize( image, @@ -234,7 +243,7 @@ def preprocess( """ do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size - size = get_size_dict(size, param_name="size", default_to_square=False) + size = get_size_dict(size, param_name="size", default_to_square=self.use_square_size) resample = resample if resample is not None else self.resample do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop crop_size = crop_size if crop_size is not None else self.crop_size From 576e2823a397942421e1724e79f51a12122ef49e Mon Sep 17 00:00:00 2001 From: JP Date: Tue, 24 Oct 2023 03:02:06 -0700 Subject: [PATCH 22/38] Add descriptive docstring to WhisperTimeStampLogitsProcessor (#25642) * adding in logit examples for Whisper processor * adding in updated logits processor for Whisper * adding in cleaned version of logits processor for Whisper * adding docstrings for whisper processor * making sure the formatting is correct * adding logits after doc builder * Update src/transformers/generation/logits_process.py Adding in suggested fix to the LogitProcessor description. Co-authored-by: Joao Gante * Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante * Update src/transformers/generation/logits_process.py Removing tip per suggestion. Co-authored-by: Joao Gante * Update src/transformers/generation/logits_process.py Removing redundant code per suggestion. Co-authored-by: Joao Gante * adding in revised version * adding in version with timestamp examples * Update src/transformers/generation/logits_process.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * enhanced paragraph on behavior of processor * fixing doc quality issue * removing the word poem from example * adding in updated docstring * adding in new version of file after doc-builder --------- Co-authored-by: Joao Gante Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/logits_process.py | 39 ++++++++++++++++++- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 3ab689b55a767d..ed69a163778dd1 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1457,8 +1457,15 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to class WhisperTimeStampLogitsProcessor(LogitsProcessor): r""" - Whisper specific Processor. This processor can be used to force a list of tokens. The processor will set their log - probs to `inf` so that they are sampled at their corresponding index. + + [`LogitsProcessor`] that modifies the logits for the generation of timestamps in the transcription. When the input + tokens are at a specific threshold, the processor sets the scores to negative infinity. The processor makes sure + that timestamp tokens appear in pairs, by masking out the logits that would break this pairing pattern. This is + done to maintain the consistency and structure of generated timestamps. It also ensures that when the predicted + probability of sampling any of the timestamp token is greater than any individual non-timestamp token, those + non-timestamp logits are set to negative infinity. This is done to ensure the generation of timestamps over other + potential tokens. + See [the paper](https://arxiv.org/abs/2212.04356) for more information. @@ -1472,6 +1479,34 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): max_initial_timestamp_index (`int`, *optional*, defaults to 1): Used to set the maximum value of the initial timestamp. This is used to prevent the model from predicting timestamps that are too far in the future. + + Examples: + ``` python + >>> import torch + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration,GenerationConfig + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = processor(ds[3]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + + >>> #Displaying timestamps + >>> generated_ids = model.generate(inputs=input_features, return_timestamps=True) + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] + >>> print("Transcription:", transcription) + Transcription: <|startoftranscript|><|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all, and can<|6.44|><|6.44|> discover in it but little of rocky Ithaca.<|9.44|><|endoftext|> + + + >>> #No timestamps & change EOS: + >>> #This allows the user to select a specific token to terminate the sequence on, in this case it's the word "can"(460) + >>> model.generation_config.eos_token_id = 460 + >>> generated_ids = model.generate(inputs=input_features,return_timestamps=False) + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> print("Transcription:", transcription) + Transcription: He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can + ``` """ def __init__(self, generate_config): # support for the kwargs From e2d6d5ce57373a453e8c6fb4293cf08c17b061e1 Mon Sep 17 00:00:00 2001 From: Michal Jamroz Date: Tue, 24 Oct 2023 14:32:03 +0200 Subject: [PATCH 23/38] Normalize only if needed (#26049) * Normalize only if needed * Update examples/pytorch/image-classification/run_image_classification.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * if else in one line * within block * one more place, sorry for mess * import order * Update examples/pytorch/image-classification/run_image_classification.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update examples/pytorch/image-classification/run_image_classification_no_trainer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../image-classification/run_image_classification.py | 7 ++++++- .../run_image_classification_no_trainer.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) mode change 100644 => 100755 examples/pytorch/image-classification/run_image_classification.py diff --git a/examples/pytorch/image-classification/run_image_classification.py b/examples/pytorch/image-classification/run_image_classification.py old mode 100644 new mode 100755 index 27a81f00948281..d00929754148c4 --- a/examples/pytorch/image-classification/run_image_classification.py +++ b/examples/pytorch/image-classification/run_image_classification.py @@ -28,6 +28,7 @@ from torchvision.transforms import ( CenterCrop, Compose, + Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, @@ -325,7 +326,11 @@ def compute_metrics(p): size = image_processor.size["shortest_edge"] else: size = (image_processor.size["height"], image_processor.size["width"]) - normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std) + normalize = ( + Normalize(mean=image_processor.image_mean, std=image_processor.image_std) + if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std") + else Lambda(lambda x: x) + ) _train_transforms = Compose( [ RandomResizedCrop(size), diff --git a/examples/pytorch/image-classification/run_image_classification_no_trainer.py b/examples/pytorch/image-classification/run_image_classification_no_trainer.py index 3e38a3e79a715d..52b5fabd8947b0 100644 --- a/examples/pytorch/image-classification/run_image_classification_no_trainer.py +++ b/examples/pytorch/image-classification/run_image_classification_no_trainer.py @@ -32,6 +32,7 @@ from torchvision.transforms import ( CenterCrop, Compose, + Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, @@ -331,7 +332,11 @@ def main(): size = image_processor.size["shortest_edge"] else: size = (image_processor.size["height"], image_processor.size["width"]) - normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std) + normalize = ( + Normalize(mean=image_processor.image_mean, std=image_processor.image_std) + if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std") + else Lambda(lambda x: x) + ) train_transforms = Compose( [ RandomResizedCrop(size), From 7bde5d634f535206e0bf4e279b32d65d1f7568f7 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 24 Oct 2023 14:33:05 +0200 Subject: [PATCH 24/38] [`TFxxxxForSequenceClassifciation`] Fix the eager mode after #25085 (#25751) * TODOS * Switch .shape -> shape_list --------- Co-authored-by: Matt --- src/transformers/models/gptj/modeling_tf_gptj.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py index f215adaaac0055..f5080f674c3e1b 100644 --- a/src/transformers/models/gptj/modeling_tf_gptj.py +++ b/src/transformers/models/gptj/modeling_tf_gptj.py @@ -870,7 +870,11 @@ def call( tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1) - 1 ) - sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1) + sequence_lengths = tf.where( + sequence_lengths >= 0, + sequence_lengths, + tf.cast(shape_list(input_ids[-1]), sequence_lengths.dtype) - 1, + ) in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1) else: sequence_lengths = -1 From cb0c68069d118c7ca23c70dcd22ac45aa8313d50 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Tue, 24 Oct 2023 13:40:16 +0100 Subject: [PATCH 25/38] Safe import of rgb_to_id from FE modules (#27037) Safe import from FE modules --- .../feature_extraction_conditional_detr.py | 10 ++++++++++ .../feature_extraction_deformable_detr.py | 10 ++++++++++ .../models/detr/feature_extraction_detr.py | 10 ++++++++++ .../models/yolos/feature_extraction_yolos.py | 10 ++++++++++ 4 files changed, 40 insertions(+) diff --git a/src/transformers/models/conditional_detr/feature_extraction_conditional_detr.py b/src/transformers/models/conditional_detr/feature_extraction_conditional_detr.py index 2af959e8a991f3..bfdec373f865c5 100644 --- a/src/transformers/models/conditional_detr/feature_extraction_conditional_detr.py +++ b/src/transformers/models/conditional_detr/feature_extraction_conditional_detr.py @@ -16,6 +16,7 @@ import warnings +from ...image_transforms import rgb_to_id as _rgb_to_id from ...utils import logging from .image_processing_conditional_detr import ConditionalDetrImageProcessor @@ -23,6 +24,15 @@ logger = logging.get_logger(__name__) +def rgb_to_id(x): + warnings.warn( + "rgb_to_id has moved and will not be importable from this module from v5. " + "Please import from transformers.image_transforms instead.", + FutureWarning, + ) + return _rgb_to_id(x) + + class ConditionalDetrFeatureExtractor(ConditionalDetrImageProcessor): def __init__(self, *args, **kwargs) -> None: warnings.warn( diff --git a/src/transformers/models/deformable_detr/feature_extraction_deformable_detr.py b/src/transformers/models/deformable_detr/feature_extraction_deformable_detr.py index 6f1ca003a00734..f04743e91ceefe 100644 --- a/src/transformers/models/deformable_detr/feature_extraction_deformable_detr.py +++ b/src/transformers/models/deformable_detr/feature_extraction_deformable_detr.py @@ -16,6 +16,7 @@ import warnings +from ...image_transforms import rgb_to_id as _rgb_to_id from ...utils import logging from .image_processing_deformable_detr import DeformableDetrImageProcessor @@ -23,6 +24,15 @@ logger = logging.get_logger(__name__) +def rgb_to_id(x): + warnings.warn( + "rgb_to_id has moved and will not be importable from this module from v5. " + "Please import from transformers.image_transforms instead.", + FutureWarning, + ) + return _rgb_to_id(x) + + class DeformableDetrFeatureExtractor(DeformableDetrImageProcessor): def __init__(self, *args, **kwargs) -> None: warnings.warn( diff --git a/src/transformers/models/detr/feature_extraction_detr.py b/src/transformers/models/detr/feature_extraction_detr.py index b94cf9ff804132..6ea33666466f9a 100644 --- a/src/transformers/models/detr/feature_extraction_detr.py +++ b/src/transformers/models/detr/feature_extraction_detr.py @@ -16,6 +16,7 @@ import warnings +from ...image_transforms import rgb_to_id as _rgb_to_id from ...utils import logging from .image_processing_detr import DetrImageProcessor @@ -23,6 +24,15 @@ logger = logging.get_logger(__name__) +def rgb_to_id(x): + warnings.warn( + "rgb_to_id has moved and will not be importable from this module from v5. " + "Please import from transformers.image_transforms instead.", + FutureWarning, + ) + return _rgb_to_id(x) + + class DetrFeatureExtractor(DetrImageProcessor): def __init__(self, *args, **kwargs) -> None: warnings.warn( diff --git a/src/transformers/models/yolos/feature_extraction_yolos.py b/src/transformers/models/yolos/feature_extraction_yolos.py index a19c87c503e571..5696ee65bbdfa2 100644 --- a/src/transformers/models/yolos/feature_extraction_yolos.py +++ b/src/transformers/models/yolos/feature_extraction_yolos.py @@ -16,6 +16,7 @@ import warnings +from ...image_transforms import rgb_to_id as _rgb_to_id from ...utils import logging from .image_processing_yolos import YolosImageProcessor @@ -23,6 +24,15 @@ logger = logging.get_logger(__name__) +def rgb_to_id(x): + warnings.warn( + "rgb_to_id has moved and will not be importable from this module from v5. " + "Please import from transformers.image_transforms instead.", + FutureWarning, + ) + return _rgb_to_id(x) + + class YolosFeatureExtractor(YolosImageProcessor): def __init__(self, *args, **kwargs) -> None: warnings.warn( From b18e31407c1e0933a51c0e53f7777b6a47d034f8 Mon Sep 17 00:00:00 2001 From: Leandro von Werra Date: Tue, 24 Oct 2023 14:56:00 +0200 Subject: [PATCH 26/38] add info on TRL docs (#27024) * add info on TRL docs * add TRL link * tweak text * tweak text --- docs/source/en/main_classes/trainer.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/en/main_classes/trainer.md b/docs/source/en/main_classes/trainer.md index f433b820af8b56..46341804ce756f 100644 --- a/docs/source/en/main_classes/trainer.md +++ b/docs/source/en/main_classes/trainer.md @@ -18,6 +18,12 @@ rendered properly in your Markdown viewer. The [`Trainer`] class provides an API for feature-complete training in PyTorch for most standard use cases. It's used in most of the [example scripts](https://github.com/huggingface/transformers/tree/main/examples). + + +If you're looking to fine-tune a language model like Llama-2 or Mistral on a text dataset using autoregressive techniques, consider using [`trl`](https://github.com/huggingface/trl)'s [`~trl.SFTTrainer`]. The [`~trl.SFTTrainer`] wraps the [`Trainer`] and is specially optimized for this particular task and supports sequence packing, LoRA, quantization, and DeepSpeed for efficient scaling to any model size. On the other hand, the [`Trainer`] is a more versatile option, suitable for a broader spectrum of tasks. + + + Before instantiating your [`Trainer`], create a [`TrainingArguments`] to access all the points of customization during training. The API supports distributed training on multiple GPUs/TPUs, mixed precision through [NVIDIA Apex](https://github.com/NVIDIA/apex) and Native AMP for PyTorch. From 41496b95da5d34368a13538548c85527377972a6 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Tue, 24 Oct 2023 15:10:23 +0200 Subject: [PATCH 27/38] Add fuyu device map (#26949) * add _no_split_modules * style * fix _no_split_modules * add doc --- src/transformers/modeling_utils.py | 36 +++++++++++++++---- src/transformers/models/fuyu/modeling_fuyu.py | 1 + 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0317695f2096de..cceb064bb44243 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1507,6 +1507,35 @@ def _tie_or_clone_weights(self, output_embeddings, input_embeddings): if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): output_embeddings.out_features = input_embeddings.num_embeddings + def _get_no_split_modules(self, device_map: str): + """ + Get the modules of the model that should not be spit when using device_map. We iterate through the modules to + get the underlying `_no_split_modules`. + + Args: + device_map (`str`): + The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"] + + Returns: + `List[str]`: List of modules that should not be split + """ + if self._no_split_modules is None: + raise ValueError( + f"{self.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model " + "class needs to implement the `_no_split_modules` attribute." + ) + _no_split_modules = set(self._no_split_modules) + for module in self.modules(): + if isinstance(module, PreTrainedModel): + if module._no_split_modules is None: + raise ValueError( + f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model " + "class needs to implement the `_no_split_modules` attribute." + ) + else: + _no_split_modules = _no_split_modules | set(module._no_split_modules) + return list(_no_split_modules) + def resize_token_embeddings( self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None ) -> nn.Embedding: @@ -3226,12 +3255,7 @@ def from_pretrained( elif load_in_8bit: target_dtype = torch.int8 - if model._no_split_modules is None: - raise ValueError( - f"{model.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model " - "class needs to implement the `_no_split_modules` attribute." - ) - no_split_modules = model._no_split_modules + no_split_modules = model._get_no_split_modules(device_map) if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: raise ValueError( "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or " diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index b14b1b0b871d3a..03312420ca6e6d 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -262,6 +262,7 @@ def forward( inputs_embeds = self.language_model.get_input_embeddings()(input_ids) if image_patches is not None and past_key_values is None: patch_embeddings = self.vision_embed_tokens(image_patches.to(self.vision_embed_tokens.weight.dtype)) + patch_embeddings = patch_embeddings.to(inputs_embeds.device) inputs_embeds = self.gather_continuous_embeddings( word_embeddings=inputs_embeds, continuous_embeddings=patch_embeddings, From 9da451713d5feac451505c09337ca07ff6d0dba0 Mon Sep 17 00:00:00 2001 From: Alex McKinney <44398246+vvvm23@users.noreply.github.com> Date: Tue, 24 Oct 2023 15:49:26 +0100 Subject: [PATCH 28/38] Device agnostic testing (#25870) * adds agnostic decorators and availability fns * renaming decorators and fixing imports * updating some representative example tests bloom, opt, and reformer for now * wip device agnostic functions * lru cache to device checking functions * adds `TRANSFORMERS_TEST_DEVICE_SPEC` if present, imports the target file and updates device to function mappings * comments `TRANSFORMERS_TEST_DEVICE_SPEC` code * extra checks on device name * `make style; make quality` * updates default functions for agnostic calls * applies suggestions from review * adds `is_torch_available` guard * Add spec file to docs, rename function dispatch names to backend_* * add backend import to docs example for spec file * change instances of to * Move register backend to before device check as per @statelesshz changes * make style * make opt test require fp16 to run --------- Co-authored-by: arsalanu Co-authored-by: arsalanu --- docs/source/en/testing.md | 19 +++ src/transformers/testing_utils.py | 127 ++++++++++++++++-- src/transformers/utils/__init__.py | 2 + src/transformers/utils/import_utils.py | 39 ++++++ tests/models/bloom/test_modeling_bloom.py | 10 +- tests/models/codegen/test_modeling_codegen.py | 5 +- tests/models/opt/test_modeling_opt.py | 6 +- .../models/reformer/test_modeling_reformer.py | 5 +- 8 files changed, 188 insertions(+), 25 deletions(-) diff --git a/docs/source/en/testing.md b/docs/source/en/testing.md index 25df6e295a130b..d704c84a628f93 100644 --- a/docs/source/en/testing.md +++ b/docs/source/en/testing.md @@ -525,6 +525,25 @@ Certain devices will require an additional import after importing `torch` for th ```bash TRANSFORMERS_TEST_BACKEND="torch_npu" pytest tests/utils/test_logging.py ``` +Alternative backends may also require the replacement of device-specific functions. For example `torch.cuda.manual_seed` may need to be replaced with a device-specific seed setter like `torch.npu.manual_seed` to correctly set a random seed on the device. To specify a new backend with backend-specific device functions when running the test suite, create a Python device specification file in the format: + +``` +import torch +import torch_npu +# !! Further additional imports can be added here !! + +# Specify the device name (eg. 'cuda', 'cpu', 'npu') +DEVICE_NAME = 'npu' + +# Specify device-specific backends to dispatch to. +# If not specified, will fallback to 'default' in 'testing_utils.py` +MANUAL_SEED_FN = torch.npu.manual_seed +EMPTY_CACHE_FN = torch.npu.empty_cache +DEVICE_COUNT_FN = torch.npu.device_count +``` +This format also allows for specification of any additional imports required. To use this file to replace equivalent methods in the test suite, set the environment variable `TRANSFORMERS_TEST_DEVICE_SPEC` to the path of the spec file. + +Currently, only `MANUAL_SEED_FN`, `EMPTY_CACHE_FN` and `DEVICE_COUNT_FN` are supported for device-specific dispatch. ### Distributed training diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 50f50c83c4e001..a5cac8ddfe9692 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -32,7 +32,7 @@ from collections.abc import Mapping from io import StringIO from pathlib import Path -from typing import Iterable, Iterator, List, Optional, Union +from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union from unittest import mock import huggingface_hub @@ -98,8 +98,10 @@ is_timm_available, is_tokenizers_available, is_torch_available, + is_torch_bf16_available_on_device, is_torch_bf16_cpu_available, is_torch_bf16_gpu_available, + is_torch_fp16_available_on_device, is_torch_neuroncore_available, is_torch_npu_available, is_torch_tensorrt_fx_available, @@ -713,6 +715,16 @@ def require_torch_multi_xpu(test_case): # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode import torch + if "TRANSFORMERS_TEST_BACKEND" in os.environ: + backend = os.environ["TRANSFORMERS_TEST_BACKEND"] + try: + _ = importlib.import_module(backend) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its" + f" traceback):\n{e}" + ) from e + if "TRANSFORMERS_TEST_DEVICE" in os.environ: torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"] try: @@ -730,17 +742,6 @@ def require_torch_multi_xpu(test_case): torch_device = "xpu" else: torch_device = "cpu" - - if "TRANSFORMERS_TEST_BACKEND" in os.environ: - backend = os.environ["TRANSFORMERS_TEST_BACKEND"] - try: - _ = importlib.import_module(backend) - except ModuleNotFoundError as e: - raise ModuleNotFoundError( - f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its" - f" traceback):\n{e}" - ) from e - else: torch_device = None @@ -770,6 +771,25 @@ def require_torch_gpu(test_case): return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) +def require_torch_accelerator(test_case): + """Decorator marking a test that requires an accessible accelerator and PyTorch.""" + return unittest.skipUnless(torch_device != "cpu", "test requires accelerator")(test_case) + + +def require_torch_fp16(test_case): + """Decorator marking a test that requires a device that supports fp16""" + return unittest.skipUnless( + is_torch_fp16_available_on_device(torch_device), "test requires device with fp16 support" + )(test_case) + + +def require_torch_bf16(test_case): + """Decorator marking a test that requires a device that supports bf16""" + return unittest.skipUnless( + is_torch_bf16_available_on_device(torch_device), "test requires device with bf16 support" + )(test_case) + + def require_torch_bf16_gpu(test_case): """Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0""" return unittest.skipUnless( @@ -2176,3 +2196,86 @@ def _find(self, tests, obj, name, module, source_lines, globs, seen) -> None: for test in finder.find(module, module.__name__): if test.examples: # skip empty doctests and cuda yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test) + + +def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs): + if device not in dispatch_table: + return dispatch_table["default"](*args, **kwargs) + + fn = dispatch_table[device] + + # Some device agnostic functions return values. Need to guard against `None` + # instead at user level. + if fn is None: + return None + return fn(*args, **kwargs) + + +if is_torch_available(): + # Mappings from device names to callable functions to support device agnostic + # testing. + BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed} + BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "default": None} + BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "default": lambda: 1} + + +def backend_manual_seed(device: str, seed: int): + return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed) + + +def backend_empty_cache(device: str): + return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE) + + +def backend_device_count(device: str): + return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT) + + +if is_torch_available(): + # If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries + # into device to function mappings. + if "TRANSFORMERS_TEST_DEVICE_SPEC" in os.environ: + device_spec_path = os.environ["TRANSFORMERS_TEST_DEVICE_SPEC"] + if not Path(device_spec_path).is_file(): + raise ValueError( + f"Specified path to device spec file is not a file or not found. Received '{device_spec_path}" + ) + + # Try to strip extension for later import – also verifies we are importing a + # python file. + try: + import_name = device_spec_path[: device_spec_path.index(".py")] + except ValueError as e: + raise ValueError(f"Provided device spec file was not a Python file! Received '{device_spec_path}") from e + + device_spec_module = importlib.import_module(import_name) + + # Imported file must contain `DEVICE_NAME`. If it doesn't, terminate early. + try: + device_name = device_spec_module.DEVICE_NAME + except AttributeError as e: + raise AttributeError("Device spec file did not contain `DEVICE_NAME`") from e + + if "TRANSFORMERS_TEST_DEVICE" in os.environ and torch_device != device_name: + msg = f"Mismatch between environment variable `TRANSFORMERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n" + msg += "Either unset `TRANSFORMERS_TEST_DEVICE` or ensure it matches device spec name." + raise ValueError(msg) + + torch_device = device_name + + def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str): + try: + # Try to import the function directly + spec_fn = getattr(device_spec_module, attribute_name) + device_fn_dict[torch_device] = spec_fn + except AttributeError as e: + # If the function doesn't exist, and there is no default, throw an error + if "default" not in device_fn_dict: + raise AttributeError( + f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found." + ) from e + + # Add one entry here for each `BACKEND_*` dictionary. + update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN") + update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN") + update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN") diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index b66f0db3892097..25dfa53d8d7e23 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -166,10 +166,12 @@ is_tokenizers_available, is_torch_available, is_torch_bf16_available, + is_torch_bf16_available_on_device, is_torch_bf16_cpu_available, is_torch_bf16_gpu_available, is_torch_compile_available, is_torch_cuda_available, + is_torch_fp16_available_on_device, is_torch_fx_available, is_torch_fx_proxy, is_torch_mps_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index c135034d02bb71..48e4e7023603bb 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -351,6 +351,45 @@ def is_torch_bf16_available(): return is_torch_bf16_gpu_available() +@lru_cache() +def is_torch_fp16_available_on_device(device): + if not is_torch_available(): + return False + + import torch + + try: + x = torch.zeros(2, 2, dtype=torch.float16).to(device) + _ = x @ x + except: # noqa: E722 + # TODO: more precise exception matching, if possible. + # most backends should return `RuntimeError` however this is not guaranteed. + return False + + return True + + +@lru_cache() +def is_torch_bf16_available_on_device(device): + if not is_torch_available(): + return False + + import torch + + if device == "cuda": + return is_torch_bf16_gpu_available() + + try: + x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device) + _ = x @ x + except: # noqa: E722 + # TODO: more precise exception matching, if possible. + # most backends should return `RuntimeError` however this is not guaranteed. + return False + + return True + + def is_torch_tf32_available(): if not is_torch_available(): return False diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py index c05d45ebecc2d7..5d090c5785fb1d 100644 --- a/tests/models/bloom/test_modeling_bloom.py +++ b/tests/models/bloom/test_modeling_bloom.py @@ -18,7 +18,7 @@ import unittest from transformers import BloomConfig, is_torch_available -from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device +from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -401,7 +401,7 @@ def test_model_from_pretrained(self): self.assertIsNotNone(model) @slow - @require_torch_gpu + @require_torch_accelerator def test_simple_generation(self): # This test is a bit flaky. For some GPU architectures, pytorch sets by default allow_fp16_reduced_precision_reduction = True and some operations # do not give the same results under this configuration, especially torch.baddmm and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200 @@ -440,7 +440,7 @@ def test_simple_generation(self): self.assertEqual(tokenizer.decode(greedy_output[0], skip_special_tokens=True), EXPECTED_OUTPUT) @slow - @require_torch_gpu + @require_torch_accelerator def test_batch_generation(self): path_560m = "bigscience/bloom-560m" model = BloomForCausalLM.from_pretrained(path_560m, use_cache=True, revision="gs555750").to(torch_device) @@ -460,7 +460,7 @@ def test_batch_generation(self): ) @slow - @require_torch_gpu + @require_torch_accelerator def test_batch_generation_padd(self): path_560m = "bigscience/bloom-560m" model = BloomForCausalLM.from_pretrained(path_560m, use_cache=True, revision="gs555750").to(torch_device) @@ -489,7 +489,7 @@ def test_batch_generation_padd(self): ) @slow - @require_torch_gpu + @require_torch_accelerator def test_batch_generated_text(self): path_560m = "bigscience/bloom-560m" diff --git a/tests/models/codegen/test_modeling_codegen.py b/tests/models/codegen/test_modeling_codegen.py index 34a32caa7ff8c0..e042ccac71dab0 100644 --- a/tests/models/codegen/test_modeling_codegen.py +++ b/tests/models/codegen/test_modeling_codegen.py @@ -19,7 +19,7 @@ from transformers import CodeGenConfig, is_torch_available from transformers.file_utils import cached_property -from transformers.testing_utils import is_flaky, require_torch, slow, torch_device +from transformers.testing_utils import backend_manual_seed, is_flaky, require_torch, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -498,8 +498,7 @@ def test_codegen_sample(self): model.to(torch_device) torch.manual_seed(0) - if torch_device == "cuda": - torch.cuda.manual_seed(0) + backend_manual_seed(torch_device, 0) tokenized = tokenizer("def hello_world():", return_tensors="pt", return_token_type_ids=True) input_ids = tokenized.input_ids.to(torch_device) diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index 232a29d4f50e6b..18c0c9a7efeeaf 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -22,7 +22,7 @@ import timeout_decorator # noqa from transformers import OPTConfig, is_torch_available -from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device +from transformers.testing_utils import require_torch, require_torch_fp16, require_torch_gpu, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -286,13 +286,13 @@ def test_inputs_embeds(self): with torch.no_grad(): model(**inputs)[0] + @require_torch_fp16 def test_generate_fp16(self): config, input_dict = self.model_tester.prepare_config_and_inputs() input_ids = input_dict["input_ids"] attention_mask = input_ids.ne(1).to(torch_device) model = OPTForCausalLM(config).eval().to(torch_device) - if torch_device == "cuda": - model.half() + model.half() model.generate(input_ids, attention_mask=attention_mask) model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) diff --git a/tests/models/reformer/test_modeling_reformer.py b/tests/models/reformer/test_modeling_reformer.py index c84f729633cc8a..11cd7e1a33b45a 100644 --- a/tests/models/reformer/test_modeling_reformer.py +++ b/tests/models/reformer/test_modeling_reformer.py @@ -20,6 +20,7 @@ require_sentencepiece, require_tokenizers, require_torch, + require_torch_fp16, require_torch_multi_gpu, slow, torch_device, @@ -563,12 +564,12 @@ def test_dropout_random_seed_is_changing(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_random_seed(*config_and_inputs) - @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + @require_torch_fp16 def test_reformer_model_fp16_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs) - @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + @require_torch_fp16 def test_reformer_model_fp16_generate(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs) From 13ef14e18e96656b31265836693b2d358894879b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 24 Oct 2023 19:05:37 +0200 Subject: [PATCH 29/38] Fix config silent copy in from_pretrained (#27043) * Fix config modeling utils * fix more * fix attn mask bug * Update src/transformers/modeling_utils.py --- src/transformers/modeling_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index cceb064bb44243..6f866f989a6777 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3135,6 +3135,9 @@ def from_pretrained( with ContextManagers(init_contexts): model = cls(config, *model_args, **model_kwargs) + # make sure we use the model's config since the __init__ call might have copied it + config = model.config + # Check first if we are `from_pt` if use_keep_in_fp32_modules: if is_accelerate_available(): @@ -3193,7 +3196,7 @@ def from_pretrained( importlib.metadata.version("bitsandbytes") ) >= version.parse("0.37.0") - model.config.quantization_config = quantization_config + config.quantization_config = quantization_config model.is_8bit_serializable = is_8bit_serializable if load_in_8bit and torch_dtype is None: @@ -3423,7 +3426,7 @@ def from_pretrained( if cls.main_input_name != "input_ids": raise RuntimeError("We can only quantize pure text model.") quantizer.quantize_model(model, quantization_config.tokenizer) - model.config.quantization_config = GPTQConfig.from_dict(quantizer.to_dict()) + config.quantization_config = GPTQConfig.from_dict(quantizer.to_dict()) model._is_quantized_training_enabled = True if quantization_method_from_config == QuantizationMethod.GPTQ: model = quantizer.post_init_model(model) From 9333bf0769561c048700377c2e0813221ab9d2c9 Mon Sep 17 00:00:00 2001 From: Maria Khalusova Date: Tue, 24 Oct 2023 13:10:06 -0400 Subject: [PATCH 30/38] [docs] Performance docs refactor p.2 (#26791) * initial edits * improvements for clarity and flow * improvements for clarity and flow, removed the repetead section * removed two docs that had no content * Revert "removed two docs that had no content" This reverts commit e98fa2fa0d8e171163f15cb8a04bdada1053543b. * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * feedback addressed * more feedback addressed * feedback addressed --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/glossary.md | 33 ++ docs/source/en/perf_train_gpu_many.md | 533 ++++++++++++++------------ 2 files changed, 324 insertions(+), 242 deletions(-) diff --git a/docs/source/en/glossary.md b/docs/source/en/glossary.md index ea1e53ab598710..548a56b7ed2894 100644 --- a/docs/source/en/glossary.md +++ b/docs/source/en/glossary.md @@ -112,6 +112,12 @@ A type of layer in a neural network where the input matrix is multiplied element ## D +### DataParallel (DP) + +Parallelism technique for training on multiple GPUs where the same setup is replicated multiple times, with each instance +receiving a distinct data slice. The processing is done in parallel and all setups are synchronized at the end of each training step. +Learn more about how DataParallel works [here](perf_train_gpu_many#dataparallel-vs-distributeddataparallel). + ### decoder input IDs This input is specific to encoder-decoder models, and contains the input IDs that will be fed to the decoder. These @@ -340,6 +346,12 @@ A pipeline in 🤗 Transformers is an abstraction referring to a series of steps For more details, see [Pipelines for inference](https://huggingface.co/docs/transformers/pipeline_tutorial). +### PipelineParallel (PP) + +Parallelism technique in which the model is split up vertically (layer-level) across multiple GPUs, so that only one or +several layers of the model are placed on a single GPU. Each GPU processes in parallel different stages of the pipeline +and working on a small chunk of the batch. Learn more about how PipelineParallel works [here](perf_train_gpu_many#from-naive-model-parallelism-to-pipeline-parallelism). + ### pixel values A tensor of the numerical representations of an image that is passed to a model. The pixel values have a shape of [`batch_size`, `num_channels`, `height`, `width`], and are generated from an image processor. @@ -410,6 +422,10 @@ An example of a semi-supervised learning approach is "self-training", in which a Models that generate a new sequence from an input, like translation models, or summarization models (such as [Bart](model_doc/bart) or [T5](model_doc/t5)). +### Sharded DDP + +Another name for the foundational [ZeRO](#zero-redundancy-optimizer--zero-) concept as used by various other implementations of ZeRO. + ### stride In [convolution](#convolution) or [pooling](#pooling), the stride refers to the distance the kernel is moved over a matrix. A stride of 1 means the kernel is moved one pixel over at a time, and a stride of 2 means the kernel is moved two pixels over at a time. @@ -420,6 +436,14 @@ A form of model training that directly uses labeled data to correct and instruct ## T +### Tensor Parallelism (TP) + +Parallelism technique for training on multiple GPUs in which each tensor is split up into multiple chunks, so instead of +having the whole tensor reside on a single GPU, each shard of the tensor resides on its designated GPU. Shards gets +processed separately and in parallel on different GPUs and the results are synced at the end of the processing step. +This is what is sometimes called horizontal parallelism, as the splitting happens on horizontal level. +Learn more about Tensor Parallelism [here](perf_train_gpu_many#tensor-parallelism). + ### token A part of a sentence, usually a word, but can also be a subword (non-common words are often split in subwords) or a @@ -489,3 +513,12 @@ Self-attention based deep learning model architecture. ### unsupervised learning A form of model training in which data provided to the model is not labeled. Unsupervised learning techniques leverage statistical information of the data distribution to find patterns useful for the task at hand. + +## Z + +### Zero Redundancy Optimizer (ZeRO) + +Parallelism technique which performs sharding of the tensors somewhat similar to [TensorParallel](#tensorparallel--tp-), +except the whole tensor gets reconstructed in time for a forward or backward computation, therefore the model doesn't need +to be modified. This method also supports various offloading techniques to compensate for limited GPU memory. +Learn more about ZeRO [here](perf_train_gpu_many#zero-data-parallelism). \ No newline at end of file diff --git a/docs/source/en/perf_train_gpu_many.md b/docs/source/en/perf_train_gpu_many.md index fc93f763d8156e..ecabbcd06f3638 100644 --- a/docs/source/en/perf_train_gpu_many.md +++ b/docs/source/en/perf_train_gpu_many.md @@ -15,143 +15,154 @@ rendered properly in your Markdown viewer. # Efficient Training on Multiple GPUs -When training on a single GPU is too slow or the model weights don't fit in a single GPUs memory we use a multi-GPU setup. Switching from a single GPU to multiple requires some form of parallelism as the work needs to be distributed. There are several techniques to achieve parallism such as data, tensor, or pipeline parallism. However, there is no one solution to fit them all and which settings works best depends on the hardware you are running on. While the main concepts most likely will apply to any other framework, this article is focused on PyTorch-based implementations. +If training a model on a single GPU is too slow or if the model's weights do not fit in a single GPU's memory, transitioning +to a multi-GPU setup may be a viable option. Prior to making this transition, thoroughly explore all the strategies covered +in the [Methods and tools for efficient training on a single GPU](perf_train_gpu_one) as they are universally applicable +to model training on any number of GPUs. Once you have employed those strategies and found them insufficient for your +case on a single GPU, consider moving to multiple GPUs. + +Transitioning from a single GPU to multiple GPUs requires the introduction of some form of parallelism, as the workload +must be distributed across the resources. Multiple techniques can be employed to achieve parallelism, such as data +parallelism, tensor parallelism, and pipeline parallelism. It's important to note that there isn't a one-size-fits-all +solution, and the optimal settings depend on the specific hardware configuration you are using. + +This guide offers an in-depth overview of individual types of parallelism, as well as guidance on ways to combine +techniques and choosing an appropriate approach. For step-by-step tutorials on distributed training, please refer to +the [🤗 Accelerate documentation](https://huggingface.co/docs/accelerate/index). - Note: Most of the strategies introduced in the [single GPU section](perf_train_gpu_one) (such as mixed precision training or gradient accumulation) are generic and apply to training models in general so make sure to have a look at it before diving into the following sections such as multi-GPU or CPU training. +While the main concepts discussed in this guide are likely applicable across frameworks, here we focus on +PyTorch-based implementations. -We will first discuss in depth various 1D parallelism techniques and their pros and cons and then look at how they can be combined into 2D and 3D parallelism to enable an even faster training and to support even bigger models. Various other powerful alternative approaches will be presented. +Before diving deeper into the specifics of each technique, let's go over the rough decision process when training +large models on a large infrastructure. -## Concepts +## Scalability strategy -The following is the brief description of the main concepts that will be described later in depth in this document. +Begin by estimating how much vRAM is required to train your model. For models hosted on the 🤗 Hub, use our +[Model Memory Calculator](https://huggingface.co/spaces/hf-accelerate/model-memory-usage), which gives you +accurate calculations within a few percent margin. -1. **DataParallel (DP)** - the same setup is replicated multiple times, and each being fed a slice of the data. The processing is done in parallel and all setups are synchronized at the end of each training step. -2. **TensorParallel (TP)** - each tensor is split up into multiple chunks, so instead of having the whole tensor reside on a single gpu, each shard of the tensor resides on its designated gpu. During processing each shard gets processed separately and in parallel on different GPUs and the results are synced at the end of the step. This is what one may call horizontal parallelism, as the splitting happens on horizontal level. -3. **PipelineParallel (PP)** - the model is split up vertically (layer-level) across multiple GPUs, so that only one or several layers of the model are places on a single gpu. Each gpu processes in parallel different stages of the pipeline and working on a small chunk of the batch. -4. **Zero Redundancy Optimizer (ZeRO)** - Also performs sharding of the tensors somewhat similar to TP, except the whole tensor gets reconstructed in time for a forward or backward computation, therefore the model doesn't need to be modified. It also supports various offloading techniques to compensate for limited GPU memory. -5. **Sharded DDP** - is another name for the foundational ZeRO concept as used by various other implementations of ZeRO. +**Parallelization strategy for a single Node / multi-GPU setup** -Before diving deeper into the specifics of each concept we first have a look at the rough decision process when training large models on a large infrastructure. +When training a model on a single node with multiple GPUs, your choice of parallelization strategy can significantly +impact performance. Here's a breakdown of your options: -## Scalability Strategy +**Case 1: Your model fits onto a single GPU** -**⇨ Single Node / Multi-GPU** -* Model fits onto a single GPU: +If your model can comfortably fit onto a single GPU, you have two primary options: - 1. DDP - Distributed DP - 2. ZeRO - may or may not be faster depending on the situation and configuration used +1. DDP - Distributed DataParallel +2. ZeRO - depending on the situation and configuration used, this method may or may not be faster, however, it's worth experimenting with it. -* Model doesn't fit onto a single GPU: +**Case 2: Your model doesn't fit onto a single GPU:** - 1. PP - 2. ZeRO - 3. TP +If your model is too large for a single GPU, you have several alternatives to consider: - With very fast intra-node connectivity of NVLINK or NVSwitch all three should be mostly on par, without these PP will be faster than TP or ZeRO. The degree of TP may also make a difference. Best to experiment to find the winner on your particular setup. +1. PipelineParallel (PP) +2. ZeRO +3. TensorParallel (TP) - TP is almost always used within a single node. That is TP size <= gpus per node. +With very fast inter-node connectivity (e.g., NVLINK or NVSwitch) all three strategies (PP, ZeRO, TP) should result in +similar performance. However, without these, PP will be faster than TP or ZeRO. The degree of TP may also +make a difference. It's best to experiment with your specific setup to determine the most suitable strategy. -* Largest Layer not fitting into a single GPU: +TP is almost always used within a single node. That is TP size <= GPUs per node. - 1. If not using ZeRO - must use TP, as PP alone won't be able to fit. - 2. With ZeRO see the same entry for "Single GPU" above +**Case 3: Largest layer of your model does not fit onto a single GPU** +1. If you are not using ZeRO, you have to use TensorParallel (TP), because PipelineParallel (PP) alone won't be sufficient to accommodate the large layer. +2. If you are using ZeRO, additionally adopt techniques from the [Methods and tools for efficient training on a single GPU](perf_train_gpu_one). -**⇨ Multi-Node / Multi-GPU** +**Parallelization strategy for a multi-Node / multi-GPU setup** -* When you have fast inter-node connectivity: +* When you have fast inter-node connectivity (e.g., NVLINK or NVSwitch) consider using one of these options: 1. ZeRO - as it requires close to no modifications to the model - 2. PP+TP+DP - less communications, but requires massive changes to the model - -* when you have slow inter-node connectivity and still low on GPU memory: + 2. A combination of PipelineParallel(PP) with TensorParallel(TP) and DataParallel(DP) - this approach will result in fewer communications, but requires significant changes to the model - 1. DP+PP+TP+ZeRO-1 +* When you have slow inter-node connectivity and still low on GPU memory: + 1. Employ a combination of DataParallel(DP) with PipelineParallel(PP), TensorParallel(TP), and ZeRO. +In the following sections of this guide we dig deeper into how these different parallelism methods work. ## Data Parallelism -Most users with just 2 GPUs already enjoy the increased training speed up thanks to `DataParallel` (DP) and `DistributedDataParallel` (DDP) that are almost trivial to use. This is a built-in feature of Pytorch. Note that in general it is advised to use DDP as it is better maintained and works for all models while DP might fail for some models. [PyTorch documentation](https://pytorch.org/docs/master/generated/torch.nn.DataParallel.html) itself recommends the use of DDP. +Even with only 2 GPUs, you can readily leverage the accelerated training capabilities offered by PyTorch's built-in features, +such as `DataParallel` (DP) and `DistributedDataParallel` (DDP). Note that +[PyTorch documentation](https://pytorch.org/docs/master/generated/torch.nn.DataParallel.html) recommends to prefer +`DistributedDataParallel` (DDP) over `DataParallel` (DP) for multi-GPU training as it works for all models. +Let's take a look at how these two methods work and what makes them different. -### DP vs DDP +### DataParallel vs DistributedDataParallel -`DistributedDataParallel` (DDP) is typically faster than `DataParallel` (DP), but it is not always the case: -* while DP is python threads-based, DDP is multiprocess-based - and as such it has no python threads limitations, such as GIL -* on the other hand a slow inter-connectivity between the GPU cards could lead to an actual slower outcome with DDP - -Here are the main differences in the inter-GPU communication overhead between the two modes: +To understand the key differences in inter-GPU communication overhead between the two methods, let's review the processes per batch: [DDP](https://pytorch.org/docs/master/notes/ddp.html): -- At the start time the main process replicates the model once from gpu 0 to the rest of gpus +- At the start time the main process replicates the model once from GPU 0 to the rest of GPUs - Then for each batch: - 1. each gpu consumes each own mini-batch of data directly - 2. during `backward`, once the local gradients are ready, they are then averaged across all processes + 1. Each GPU directly consumes its mini-batch of data. + 2. During `backward`, once the local gradients are ready, they are averaged across all processes. [DP](https://pytorch.org/docs/master/generated/torch.nn.DataParallel.html): For each batch: - 1. gpu 0 reads the batch of data and then sends a mini-batch to each gpu - 2. replicates the up-to-date model from gpu 0 to each gpu - 3. runs `forward` and sends output from each gpu to gpu 0, computes loss - 4. scatters loss from gpu 0 to all gpus, runs `backward` - 5. sends gradients from each gpu to gpu 0 and averages those - -The only communication DDP performs per batch is sending gradients, whereas DP does 5 different data exchanges per batch. - -DP copies data within the process via python threads, whereas DDP copies data via [torch.distributed](https://pytorch.org/docs/master/distributed.html). + 1. GPU 0 reads the batch of data and then sends a mini-batch to each GPU. + 2. The up-to-date model is replicated from GPU 0 to each GPU. + 3. `forward` is executed, and output from each GPU is sent to GPU 0 to compute the loss. + 4. The loss is distributed from GPU 0 to all GPUs, and `backward` is run. + 5. Gradients from each GPU are sent to GPU 0 and averaged. -Under DP gpu 0 performs a lot more work than the rest of the gpus, thus resulting in under-utilization of gpus. - -You can use DDP across multiple machines, but this is not the case with DP. - -There are other differences between DP and DDP but they aren't relevant to this discussion. - -If you want to go really deep into understanding these 2 modes, this [article](https://www.telesens.co/2019/04/04/distributed-data-parallel-training-using-pytorch-on-aws/) is highly recommended, as it has great diagrams, includes multiple benchmarks and profiler outputs on various hardware, explains all the nuances that you may need to know. - -Let's look at an actual benchmark: - -| Type | NVlink | Time | -| :----- | ----- | ---: | -| 2:DP | Y | 110s | -| 2:DDP | Y | 101s | -| 2:DDP | N | 131s | +Key differences include: +1. DDP performs only a single communication per batch - sending gradients, while DP performs five different data exchanges per batch. +DDP copies data using [torch.distributed](https://pytorch.org/docs/master/distributed.html), while DP copies data within +the process via Python threads (which introduces limitations associated with GIL). As a result, **`DistributedDataParallel` (DDP) is generally faster than `DataParallel` (DP)** unless you have slow GPU card inter-connectivity. +2. Under DP, GPU 0 performs significantly more work than other GPUs, resulting in GPU under-utilization. +3. DDP supports distributed training across multiple machines, whereas DP does not. +This is not an exhaustive list of differences between DP and DDP, however, other nuances are out of scope of this guide. +You can get a deeper understanding of these methods by reading this [article](https://www.telesens.co/2019/04/04/distributed-data-parallel-training-using-pytorch-on-aws/). -Analysis: +Let's illustrate the differences between DP and DDP with an experiment. We'll benchmark the differences between DP and +DDP with an added context of NVLink presence: -Here DP is ~10% slower than DDP w/ NVlink, but ~15% faster than DDP w/o NVlink +* Hardware: 2x TITAN RTX 24GB each + NVlink with 2 NVLinks (`NV2` in `nvidia-smi topo -m`). +* Software: `pytorch-1.8-to-be` + `cuda-11.0` / `transformers==4.3.0.dev0`. -The real difference will depend on how much data each GPU needs to sync with the others - the more there is to sync, the more a slow link will slow down the total runtime. +To disable the NVLink feature on one of the benchmarks, we use `NCCL_P2P_DISABLE=1`. -Here is the full benchmark code and outputs: +Here is the benchmarking code and outputs: -`NCCL_P2P_DISABLE=1` was used to disable the NVLink feature on the corresponding benchmark. +**DP** ``` - -# DP rm -r /tmp/test-clm; CUDA_VISIBLE_DEVICES=0,1 \ python examples/pytorch/language-modeling/run_clm.py \ --model_name_or_path gpt2 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \ --do_train --output_dir /tmp/test-clm --per_device_train_batch_size 4 --max_steps 200 {'train_runtime': 110.5948, 'train_samples_per_second': 1.808, 'epoch': 0.69} +``` -# DDP w/ NVlink +**DDP w/ NVlink** + +``` rm -r /tmp/test-clm; CUDA_VISIBLE_DEVICES=0,1 \ python -m torch.distributed.launch --nproc_per_node 2 examples/pytorch/language-modeling/run_clm.py \ --model_name_or_path gpt2 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \ --do_train --output_dir /tmp/test-clm --per_device_train_batch_size 4 --max_steps 200 {'train_runtime': 101.9003, 'train_samples_per_second': 1.963, 'epoch': 0.69} +``` -# DDP w/o NVlink +**DDP w/o NVlink** + +``` rm -r /tmp/test-clm; NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1 \ python -m torch.distributed.launch --nproc_per_node 2 examples/pytorch/language-modeling/run_clm.py \ --model_name_or_path gpt2 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \ @@ -160,17 +171,34 @@ python -m torch.distributed.launch --nproc_per_node 2 examples/pytorch/language- {'train_runtime': 131.4367, 'train_samples_per_second': 1.522, 'epoch': 0.69} ``` -Hardware: 2x TITAN RTX 24GB each + NVlink with 2 NVLinks (`NV2` in `nvidia-smi topo -m`) -Software: `pytorch-1.8-to-be` + `cuda-11.0` / `transformers==4.3.0.dev0` +Here are the same benchmarking results gathered in a table for convenience: + +| Type | NVlink | Time | +| :----- | ----- | ---: | +| 2:DP | Y | 110s | +| 2:DDP | Y | 101s | +| 2:DDP | N | 131s | + +As you can see, in this case DP is ~10% slower than DDP with NVlink, but ~15% faster than DDP without NVlink. +The real difference will depend on how much data each GPU needs to sync with the others - the more there is to sync, +the more a slow link will impede the overall runtime. ## ZeRO Data Parallelism -ZeRO-powered data parallelism (ZeRO-DP) is described on the following diagram from this [blog post](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/) -![DeepSpeed-Image-1](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/parallelism-zero.png) +ZeRO-powered data parallelism (ZeRO-DP) is illustrated in the following diagram from this [blog post](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/). -It can be difficult to wrap one's head around it, but in reality the concept is quite simple. This is just the usual `DataParallel` (DP), except, instead of replicating the full model params, gradients and optimizer states, each GPU stores only a slice of it. And then at run-time when the full layer params are needed just for the given layer, all GPUs synchronize to give each other parts that they miss - this is it. +
+ DeepSpeed-Image-1 +
+ +While it may appear complex, it is a very similar concept to `DataParallel` (DP). The difference is that instead of +replicating the full model parameters, gradients and optimizer states, each GPU stores only a slice of it. Then, at +run-time when the full layer parameters are needed just for the given layer, all GPUs synchronize to give each other +parts that they miss. + +To illustrate this idea, consider a simple model with 3 layers (La, Lb, and Lc), where each layer has 3 parameters. +Layer La, for example, has weights a0, a1 and a2: -Consider this simple model with 3 layers, where each layer has 3 params: ``` La | Lb | Lc ---|----|--- @@ -178,9 +206,8 @@ a0 | b0 | c0 a1 | b1 | c1 a2 | b2 | c2 ``` -Layer La has weights a0, a1 and a2. -If we have 3 GPUs, the Sharded DDP (= Zero-DP) splits the model onto 3 GPUs like so: +If we have 3 GPUs, ZeRO-DP splits the model onto 3 GPUs like so: ``` GPU0: @@ -199,165 +226,213 @@ La | Lb | Lc a2 | b2 | c2 ``` -In a way this is the same horizontal slicing, as tensor parallelism, if you imagine the typical DNN diagram. Vertical slicing is where one puts whole layer-groups on different GPUs. But it's just the starting point. +In a way, this is the same horizontal slicing as tensor parallelism, as opposed to Vertical +slicing, where one puts whole layer-groups on different GPUs. Now let's see how this works: + +Each of these GPUs will get the usual mini-batch as it works in DP: -Now each of these GPUs will get the usual mini-batch as it works in DP: ``` x0 => GPU0 x1 => GPU1 x2 => GPU2 ``` -The inputs are unmodified - they think they are going to be processed by the normal model. - -First, the inputs hit the layer La. - -Let's focus just on GPU0: x0 needs a0, a1, a2 params to do its forward path, but GPU0 has only a0 - it gets sent a1 from GPU1 and a2 from GPU2, bringing all pieces of the model together. - -In parallel, GPU1 gets mini-batch x1 and it only has a1, but needs a0 and a2 params, so it gets those from GPU0 and GPU2. +The inputs are passed without modifications as if they would be processed by the original model. -Same happens to GPU2 that gets input x2. It gets a0 and a1 from GPU0 and GPU1, and with its a2 it reconstructs the full tensor. +First, the inputs get to the layer `La`. What happens at this point? -All 3 GPUs get the full tensors reconstructed and a forward happens. +On GPU0: the x0 mini-batch requires the a0, a1, a2 parameters to do its forward path through the layer, but the GPU0 has only a0. +It will get a1 from GPU1 and a2 from GPU2, bringing all the pieces of the model together. -As soon as the calculation is done, the data that is no longer needed gets dropped - it's only used during the calculation. The reconstruction is done efficiently via a pre-fetch. +In parallel, GPU1 gets another mini-batch - x1. GPU1 has the a1 parameter, but needs a0 and a2, so it gets those from GPU0 and GPU2. +Same happens to GPU2 that gets the mini-batch x2. It gets a0 and a1 from GPU0 and GPU1. -And the whole process is repeated for layer Lb, then Lc forward-wise, and then backward Lc -> Lb -> La. +This way each of the 3 GPUs gets the full tensors reconstructed and makes a forward pass with its own mini-batch. +As soon as the calculation is done, the data that is no longer needed gets dropped - it's only used during the calculation. +The reconstruction is done efficiently via a pre-fetch. -To me this sounds like an efficient group backpacking weight distribution strategy: +Then the whole process is repeated for layer Lb, then Lc forward-wise, and then backward Lc -> Lb -> La. -1. person A carries the tent -2. person B carries the stove -3. person C carries the axe + -Now each night they all share what they have with others and get from others what they don't have, and in the morning they pack up their allocated type of gear and continue on their way. This is Sharded DDP / Zero DP. +This mechanism is similar to an efficient group backpacking strategy: person A carries the tent, person B carries the stove, +and person C carries the axe. Each night they all share what they have with others and get from others what they don't have, +and in the morning they pack up their allocated type of gear and continue on their way. This is what ZeRO DP/Sharded DDP is. +Compare this strategy to the simple one where each person has to carry their own tent, stove and axe (similar to +DataParallel (DP and DDP) in PyTorch), which would be far more inefficient. -Compare this strategy to the simple one where each person has to carry their own tent, stove and axe, which would be far more inefficient. This is DataParallel (DP and DDP) in Pytorch. + While reading the literature on this topic you may encounter the following synonyms: Sharded, Partitioned. - -If you pay close attention the way ZeRO partitions the model's weights - it looks very similar to tensor parallelism which will be discussed later. This is because it partitions/shards each layer's weights, unlike vertical model parallelism which is discussed next. +If you pay close attention the way ZeRO partitions the model's weights - it looks very similar to tensor parallelism +which will be discussed later. This is because it partitions/shards each layer's weights, unlike vertical model parallelism +which is discussed next. Implementations: - [DeepSpeed](https://www.deepspeed.ai/features/#the-zero-redundancy-optimizer) ZeRO-DP stages 1+2+3 +- [`Accelerate` integration](https://huggingface.co/docs/accelerate/en/usage_guides/deepspeed) - [`transformers` integration](main_classes/trainer#trainer-integrations) -## Naive Model Parallelism (Vertical) and Pipeline Parallelism +## From Naive Model Parallelism to Pipeline Parallelism -Naive Model Parallelism (MP) is where one spreads groups of model layers across multiple GPUs. The mechanism is relatively simple - switch the desired layers `.to()` the desired devices and now whenever the data goes in and out those layers switch the data to the same device as the layer and leave the rest unmodified. +To explain Pipeline parallelism, we'll first look into Naive Model Parallelism (MP), also known as Vertical MP. This approach +involves distributing groups of model layers across multiple GPUs by assigning specific layers to specific GPUs with `.to()`. +As data flows through these layers, it is moved to the same GPU as the layer, while the other layers remain untouched. -We refer to it as Vertical MP, because if you remember how most models are drawn, we slice the layers vertically. For example, if the following diagram shows an 8-layer model: +We refer to this Model parallelism as "Vertical" because of how models are typically visualized. For example, the +following diagram shows an 8-layer model split vertically into two slices, placing layers 0-3 onto +GPU0 and 4-7 to GPU1: ``` =================== =================== | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | =================== =================== - gpu0 gpu1 + GPU0 GPU1 ``` -we just sliced it in 2 vertically, placing layers 0-3 onto GPU0 and 4-7 to GPU1. - -Now while data travels from layer 0 to 1, 1 to 2 and 2 to 3 this is just the normal model. But when data needs to pass from layer 3 to layer 4 it needs to travel from GPU0 to GPU1 which introduces a communication overhead. If the participating GPUs are on the same compute node (e.g. same physical machine) this copying is pretty fast, but if the GPUs are located on different compute nodes (e.g. multiple machines) the communication overhead could be significantly larger. - -Then layers 4 to 5 to 6 to 7 are as a normal model would have and when the 7th layer completes we often need to send the data back to layer 0 where the labels are (or alternatively send the labels to the last layer). Now the loss can be computed and the optimizer can do its work. - -Problems: -- the main deficiency and why this one is called "naive" MP, is that all but one GPU is idle at any given moment. So if 4 GPUs are used, it's almost identical to quadrupling the amount of memory of a single GPU, and ignoring the rest of the hardware. Plus there is the overhead of copying the data between devices. So 4x 6GB cards will be able to accommodate the same size as 1x 24GB card using naive MP, except the latter will complete the training faster, since it doesn't have the data copying overhead. But, say, if you have 40GB cards and need to fit a 45GB model you can with 4x 40GB cards (but barely because of the gradient and optimizer states) -- shared embeddings may need to get copied back and forth between GPUs. - -Pipeline Parallelism (PP) is almost identical to a naive MP, but it solves the GPU idling problem, by chunking the incoming batch into micro-batches and artificially creating a pipeline, which allows different GPUs to concurrently participate in the computation process. - -The following illustration from the [GPipe paper](https://ai.googleblog.com/2019/03/introducing-gpipe-open-source-library.html) shows the naive MP on the top, and PP on the bottom: - -![mp-pp](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/parallelism-gpipe-bubble.png) - -It's easy to see from the bottom diagram how PP has less dead zones, where GPUs are idle. The idle parts are referred to as the "bubble". - -Both parts of the diagram show a parallelism that is of degree 4. That is 4 GPUs are participating in the pipeline. So there is the forward path of 4 pipe stages F0, F1, F2 and F3 and then the return reverse order backward path of B3, B2, B1 and B0. -PP introduces a new hyper-parameter to tune and it's `chunks` which defines how many chunks of data are sent in a sequence through the same pipe stage. For example, in the bottom diagram you can see that `chunks=4`. GPU0 performs the same forward path on chunk 0, 1, 2 and 3 (F0,0, F0,1, F0,2, F0,3) and then it waits for other GPUs to do their work and only when their work is starting to be complete, GPU0 starts to work again doing the backward path for chunks 3, 2, 1 and 0 (B0,3, B0,2, B0,1, B0,0). - -Note that conceptually this is the same concept as gradient accumulation steps (GAS). Pytorch uses `chunks`, whereas DeepSpeed refers to the same hyper-parameter as GAS. - -Because of the chunks, PP introduces the concept of micro-batches (MBS). DP splits the global data batch size into mini-batches, so if you have a DP degree of 4, a global batch size of 1024 gets split up into 4 mini-batches of 256 each (1024/4). And if the number of `chunks` (or GAS) is 32 we end up with a micro-batch size of 8 (256/32). Each Pipeline stage works with a single micro-batch at a time. - -To calculate the global batch size of the DP + PP setup we then do: `mbs*chunks*dp_degree` (`8*32*4=1024`). - -Let's go back to the diagram. - -With `chunks=1` you end up with the naive MP, which is very inefficient. With a very large `chunks` value you end up with tiny micro-batch sizes which could be not every efficient either. So one has to experiment to find the value that leads to the highest efficient utilization of the gpus. - -While the diagram shows that there is a bubble of "dead" time that can't be parallelized because the last `forward` stage has to wait for `backward` to complete the pipeline, the purpose of finding the best value for `chunks` is to enable a high concurrent GPU utilization across all participating GPUs which translates to minimizing the size of the bubble. - -There are 2 groups of solutions - the traditional Pipeline API and the more modern solutions that make things much easier for the end user. - -Traditional Pipeline API solutions: +In this example, when data moves from layer 0 to 3, it's no different from regular forward pass. However, passing data +from layer 3 to 4 requires moving it from GPU0 to GPU1, introducing a communication overhead. If the participating +GPUs are on the same compute node (e.g. same physical machine) this copying is fast, but if the GPUs are distributed +across different compute nodes (e.g. multiple machines), the communication overhead could be substantially greater. + +Following that, layers 4 to 7 work as they would in the original model. Upon completion of the 7th layer, there is often +a need to send the data back to layer 0 where the labels are (or alternatively send the labels to the last layer). Now the loss can be +computed and the optimizer can do its work. + +Naive Model Parallelism comes several shortcomings: +- **All but one GPU are idle at any given moment**: if 4 GPUs are used, it's nearly identical to quadrupling the amount of memory of a single GPU, and ignoring the rest of the hardware. +- **Overhead in data transfer between devices**: E.g. 4x 6GB cards will be able to accommodate the same size as 1x 24GB card using naive MP, but a single 24GB card will complete the training faster, because it doesn't have the data copying overhead. But, say, if you have 40GB cards and need to fit a 45GB model you can with 4x 40GB cards (but barely because of the gradient and optimizer states) +- **Copying shared embeddings**: Shared embeddings may need to get copied back and forth between GPUs. + +Now that you are familiar with how the naive approach to model parallelism works and its shortcomings, let's look at Pipeline Parallelism (PP). +PP is almost identical to a naive MP, but it solves the GPU idling problem by chunking the incoming batch into micro-batches +and artificially creating a pipeline, which allows different GPUs to concurrently participate in the computation process. + +The following illustration from the [GPipe paper](https://ai.googleblog.com/2019/03/introducing-gpipe-open-source-library.html) +shows the naive MP on the top, and PP on the bottom: + +
+ MP vs PP +
+ +At the bottom of the diagram, you can observe that the Pipeline Parallelism (PP) approach minimizes the number of idle +GPU zones, referred to as 'bubbles'. Both parts of the diagram show a parallelism level of degree 4, meaning that 4 GPUs +are involved in the pipeline. You can see that there's a forward path of 4 pipe stages (F0, F1, F2 and F3) followed by +a backward path in reverse order (B3, B2, B1, and B0). + +PP introduces a new hyperparameter to tune - `chunks`, which determines how many data chunks are sent in a sequence +through the same pipe stage. For example, in the bottom diagram you can see `chunks=4`. GPU0 performs the same +forward path on chunk 0, 1, 2 and 3 (F0,0, F0,1, F0,2, F0,3) and then it waits for other GPUs to do complete their work. +Only when the other GPUs begin to complete their work, GPU0 starts to work again doing the backward path for chunks +3, 2, 1 and 0 (B0,3, B0,2, B0,1, B0,0). + +Note that this is the same concept as gradient accumulation steps. PyTorch uses `chunks`, while DeepSpeed refers +to the same hyperparameter as gradient accumulation steps. + +Because of the chunks, PP introduces the notion of micro-batches (MBS). DP splits the global data batch size into +mini-batches, so if you have a DP degree of 4, a global batch size of 1024 gets split up into 4 mini-batches of +256 each (1024/4). And if the number of `chunks` (or GAS) is 32 we end up with a micro-batch size of 8 (256/32). Each +Pipeline stage works with a single micro-batch at a time. To calculate the global batch size of the DP + PP setup, +use the formula: `mbs * chunks * dp_degree` (`8 * 32 * 4 = 1024`). +With `chunks=1` you end up with the naive MP, which is inefficient. With a large `chunks` value you end up with +tiny micro-batch sizes which is also inefficient. For this reason, we encourage to experiment with the `chunks` value to +find the one that leads to the most efficient GPUs utilization. + +You may notice a bubble of "dead" time on the diagram that can't be parallelized because the last `forward` stage +has to wait for `backward` to complete the pipeline. The purpose of finding the best value for `chunks` is to enable a high +concurrent GPU utilization across all participating GPUs which translates to minimizing the size of the bubble. + +Pipeline API solutions have been implemented in: - PyTorch - DeepSpeed - Megatron-LM -Modern solutions: +These come with some shortcomings: +- They have to modify the model quite heavily, because Pipeline requires one to rewrite the normal flow of modules into a `nn.Sequential` sequence of the same, which may require changes to the design of the model. +- Currently the Pipeline API is very restricted. If you had a bunch of Python variables being passed in the very first stage of the Pipeline, you will have to find a way around it. Currently, the pipeline interface requires either a single Tensor or a tuple of Tensors as the only input and output. These tensors must have a batch size as the very first dimension, since pipeline is going to chunk the mini batch into micro-batches. Possible improvements are being discussed here https://github.com/pytorch/pytorch/pull/50693 +- Conditional control flow at the level of pipe stages is not possible - e.g., Encoder-Decoder models like T5 require special workarounds to handle a conditional encoder stage. +- They have to arrange each layer so that the output of one layer becomes an input to the other layer. + +More recent solutions include: - Varuna - Sagemaker -Problems with traditional Pipeline API solutions: -- have to modify the model quite heavily, because Pipeline requires one to rewrite the normal flow of modules into a `nn.Sequential` sequence of the same, which may require changes to the design of the model. -- currently the Pipeline API is very restricted. If you had a bunch of python variables being passed in the very first stage of the Pipeline, you will have to find a way around it. Currently, the pipeline interface requires either a single Tensor or a tuple of Tensors as the only input and output. These tensors must have a batch size as the very first dimension, since pipeline is going to chunk the mini batch into micro-batches. Possible improvements are being discussed here https://github.com/pytorch/pytorch/pull/50693 -- conditional control flow at the level of pipe stages is not possible - e.g., Encoder-Decoder models like T5 require special workarounds to handle a conditional encoder stage. -- have to arrange each layer so that the output of one model becomes an input to the other model. - -We are yet to experiment with Varuna and SageMaker but their papers report that they have overcome the list of problems mentioned above and that they require much smaller changes to the user's model. +We have not experimented with Varuna and SageMaker but their papers report that they have overcome the list of problems +mentioned above and that they require smaller changes to the user's model. Implementations: -- [Pytorch](https://pytorch.org/docs/stable/pipeline.html) (initial support in pytorch-1.8, and progressively getting improved in 1.9 and more so in 1.10). Some [examples](https://github.com/pytorch/pytorch/blob/master/benchmarks/distributed/pipeline/pipe.py) +- [PyTorch](https://pytorch.org/docs/stable/pipeline.html) (initial support in pytorch-1.8, and progressively getting improved in 1.9 and more so in 1.10). Some [examples](https://github.com/pytorch/pytorch/blob/master/benchmarks/distributed/pipeline/pipe.py) - [DeepSpeed](https://www.deepspeed.ai/tutorials/pipeline/) - [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) has an internal implementation - no API. - [Varuna](https://github.com/microsoft/varuna) - [SageMaker](https://arxiv.org/abs/2111.05972) - this is a proprietary solution that can only be used on AWS. - [OSLO](https://github.com/tunib-ai/oslo) - this is implemented based on the Hugging Face Transformers. -🤗 Transformers status: as of this writing none of the models supports full-PP. GPT2 and T5 models have naive MP support. The main obstacle is being unable to convert the models to `nn.Sequential` and have all the inputs to be Tensors. This is because currently the models include many features that make the conversion very complicated, and will need to be removed to accomplish that. +🤗 Transformers status: as of this writing none of the models supports full-PP. GPT2 and T5 models have naive MP support. +The main obstacle is being unable to convert the models to `nn.Sequential` and have all the inputs to be Tensors. This +is because currently the models include many features that make the conversion very complicated, and will need to be removed to accomplish that. + +DeepSpeed and Megatron-LM integrations are available in [🤗 Accelerate](https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed) Other approaches: DeepSpeed, Varuna and SageMaker use the concept of an [Interleaved Pipeline](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-core-features.html) -![interleaved-pipeline-execution](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/parallelism-sagemaker-interleaved-pipeline.png) -Here the bubble (idle time) is further minimized by prioritizing backward passes. +
+ Interleaved pipeline execution +
-Varuna further tries to improve the schedule by using simulations to discover the most efficient scheduling. +Here the bubble (idle time) is further minimized by prioritizing backward passes. Varuna further attempts to improve the +schedule by using simulations to discover the most efficient scheduling. -OSLO has pipeline parallelism implementation based on the Transformers without `nn.Sequential` converting. +OSLO has pipeline parallelism implementation based on the Transformers without `nn.Sequential` conversion. ## Tensor Parallelism -In Tensor Parallelism each GPU processes only a slice of a tensor and only aggregates the full tensor for operations that require the whole thing. - -In this section we use concepts and diagrams from the [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) paper: [Efficient Large-Scale Language Model Training on GPU Clusters](https://arxiv.org/abs/2104.04473). +In Tensor Parallelism, each GPU processes a slice of a tensor and only aggregates the full tensor for operations requiring it. +To describe this method, this section of the guide relies on the concepts and diagrams from the [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) +paper: [Efficient Large-Scale Language Model Training on GPU Clusters](https://arxiv.org/abs/2104.04473). The main building block of any transformer is a fully connected `nn.Linear` followed by a nonlinear activation `GeLU`. +The dot dot-product part of it, following the Megatron's paper notation, can be written as `Y = GeLU(XA)`, where `X` is +an input vector, `Y` is the output vector, and `A` is the weight matrix. -Following the Megatron's paper notation, we can write the dot-product part of it as `Y = GeLU(XA)`, where `X` and `Y` are the input and output vectors, and `A` is the weight matrix. +If we look at the computation in matrix form, you can see how the matrix multiplication can be split between multiple GPUs: -If we look at the computation in matrix form, it's easy to see how the matrix multiplication can be split between multiple GPUs: -![Parallel GEMM](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/parallelism-tp-parallel_gemm.png) +
+ Parallel GEMM +
-If we split the weight matrix `A` column-wise across `N` GPUs and perform matrix multiplications `XA_1` through `XA_n` in parallel, then we will end up with `N` output vectors `Y_1, Y_2, ..., Y_n` which can be fed into `GeLU` independently: -![independent GeLU](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/parallelism-tp-independent-gelu.png) +If we split the weight matrix `A` column-wise across `N` GPUs and perform matrix multiplications `XA_1` through `XA_n` in parallel, +then we will end up with `N` output vectors `Y_1, Y_2, ..., Y_n` which can be fed into `GeLU` independently: -Using this principle, we can update an MLP of arbitrary depth, without the need for any synchronization between GPUs until the very end, where we need to reconstruct the output vector from shards. The Megatron-LM paper authors provide a helpful illustration for that: -![parallel shard processing](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/parallelism-tp-parallel_shard_processing.png) +
+ Independent GeLU +
-Parallelizing the multi-headed attention layers is even simpler, since they are already inherently parallel, due to having multiple independent heads! -![parallel self-attention](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/parallelism-tp-parallel_self_attention.png) +Using this principle, we can update a multi-layer perceptron of arbitrary depth, without the need for any synchronization +between GPUs until the very end, where we need to reconstruct the output vector from shards. The Megatron-LM paper authors +provide a helpful illustration for that: -Special considerations: TP requires very fast network, and therefore it's not advisable to do TP across more than one node. Practically, if a node has 4 GPUs, the highest TP degree is therefore 4. If you need a TP degree of 8, you need to use nodes that have at least 8 GPUs. +
+ Parallel shard processing +
+ +Parallelizing the multi-headed attention layers is even simpler, since they are already inherently parallel, due to having +multiple independent heads! + +
+ Parallel self-attention +
+ +Special considerations: TP requires very fast network, and therefore it's not advisable to do TP across more than one node. +Practically, if a node has 4 GPUs, the highest TP degree is therefore 4. If you need a TP degree of 8, you need to use +nodes that have at least 8 GPUs. This section is based on the original much more [detailed TP overview](https://github.com/huggingface/transformers/issues/10321#issuecomment-783543530). by [@anton-l](https://github.com/anton-l). -SageMaker combines TP with DP for a more efficient processing. - Alternative names: - DeepSpeed calls it [tensor slicing](https://www.deepspeed.ai/features/#model-parallelism) @@ -367,18 +442,27 @@ Implementations: - [SageMaker](https://arxiv.org/abs/2111.05972) - this is a proprietary solution that can only be used on AWS. - [OSLO](https://github.com/tunib-ai/oslo) has the tensor parallelism implementation based on the Transformers. +SageMaker combines TP with DP for a more efficient processing. + 🤗 Transformers status: - core: not yet implemented in the core - but if you want inference [parallelformers](https://github.com/tunib-ai/parallelformers) provides this support for most of our models. So until this is implemented in the core you can use theirs. And hopefully training mode will be supported too. - Deepspeed-Inference also supports our BERT, GPT-2, and GPT-Neo models in their super-fast CUDA-kernel-based inference mode, see more [here](https://www.deepspeed.ai/tutorials/inference-tutorial/) -## DP+PP +🤗 Accelerate integrates with [TP from Megatron-LM](https://huggingface.co/docs/accelerate/v0.23.0/en/usage_guides/megatron_lm). + +## Data Parallelism + Pipeline Parallelism -The following diagram from the DeepSpeed [pipeline tutorial](https://www.deepspeed.ai/tutorials/pipeline/) demonstrates how one combines DP with PP. +The following diagram from the DeepSpeed [pipeline tutorial](https://www.deepspeed.ai/tutorials/pipeline/) demonstrates +how one can combine DP with PP. -![dp-pp-2d](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/parallelism-zero-dp-pp.png) +
+ DP + PP-2d +
-Here it's important to see how DP rank 0 doesn't see GPU2 and DP rank 1 doesn't see GPU3. To DP there is just GPUs 0 and 1 where it feeds data as if there were just 2 GPUs. GPU0 "secretly" offloads some of its load to GPU2 using PP. And GPU1 does the same by enlisting GPU3 to its aid. +Here it's important to see how DP rank 0 doesn't see GPU2 and DP rank 1 doesn't see GPU3. To DP there is just GPUs 0 +and 1 where it feeds data as if there were just 2 GPUs. GPU0 "secretly" offloads some of its load to GPU2 using PP. +And GPU1 does the same by enlisting GPU3 to its aid. Since each dimension requires at least 2 GPUs, here you'd need at least 4 GPUs. @@ -391,11 +475,13 @@ Implementations: 🤗 Transformers status: not yet implemented -## DP+PP+TP +## Data Parallelism + Pipeline Parallelism + Tensor Parallelism To get an even more efficient training a 3D parallelism is used where PP is combined with TP and DP. This can be seen in the following diagram. -![dp-pp-tp-3d](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/parallelism-deepspeed-3d.png) +
+ dp-pp-tp-3d +
This diagram is from a blog post [3D parallelism: Scaling to trillion-parameter models](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/), which is a good read as well. @@ -410,15 +496,22 @@ Implementations: 🤗 Transformers status: not yet implemented, since we have no PP and TP. -## ZeRO DP+PP+TP +## ZeRO Data Parallelism + Pipeline Parallelism + Tensor Parallelism -One of the main features of DeepSpeed is ZeRO, which is a super-scalable extension of DP. It has already been discussed in [ZeRO Data Parallelism](#zero-data-parallelism). Normally it's a standalone feature that doesn't require PP or TP. But it can be combined with PP and TP. +One of the main features of DeepSpeed is ZeRO, which is a super-scalable extension of DP. It has already been +discussed in [ZeRO Data Parallelism](#zero-data-parallelism). Normally it's a standalone feature that doesn't require PP or TP. +But it can be combined with PP and TP. When ZeRO-DP is combined with PP (and optionally TP) it typically enables only ZeRO stage 1 (optimizer sharding). -While it's theoretically possible to use ZeRO stage 2 (gradient sharding) with Pipeline Parallelism, it will have bad performance impacts. There would need to be an additional reduce-scatter collective for every micro-batch to aggregate the gradients before sharding, which adds a potentially significant communication overhead. By nature of Pipeline Parallelism, small micro-batches are used and instead the focus is on trying to balance arithmetic intensity (micro-batch size) with minimizing the Pipeline bubble (number of micro-batches). Therefore those communication costs are going to hurt. +While it's theoretically possible to use ZeRO stage 2 (gradient sharding) with Pipeline Parallelism, it will have negative +performance impacts. There would need to be an additional reduce-scatter collective for every micro-batch to aggregate +the gradients before sharding, which adds a potentially significant communication overhead. By nature of Pipeline Parallelism, +small micro-batches are used and instead the focus is on trying to balance arithmetic intensity (micro-batch size) with +minimizing the Pipeline bubble (number of micro-batches). Therefore those communication costs are going to impact the performance. -In addition, There are already fewer layers than normal due to PP and so the memory savings won't be huge. PP already reduces gradient size by ``1/PP``, and so gradient sharding savings on top of that are less significant than pure DP. +In addition, there are already fewer layers than normal due to PP and so the memory savings won't be huge. PP already +reduces gradient size by ``1/PP``, and so gradient sharding savings on top of that are less significant than pure DP. ZeRO stage 3 is not a good choice either for the same reason - more inter-node communications required. @@ -455,7 +548,9 @@ Let's take 10 batches of sequence length 512. If we parallelize them by sample d * Operator -If we perform layer normalization, we compute std first and mean second, and then we can normalize data. Operator parallelism allows computing std and mean in parallel. So if we parallelize them by operator dimension into 2 devices (cuda:0, cuda:1), first we copy input data into both devices, and cuda:0 computes std, cuda:1 computes mean at the same time. +If we perform layer normalization, we compute std first and mean second, and then we can normalize data. +Operator parallelism allows computing std and mean in parallel. So if we parallelize them by operator dimension into 2 +devices (cuda:0, cuda:1), first we copy input data into both devices, and cuda:0 computes std, cuda:1 computes mean at the same time. * Attribute @@ -465,66 +560,20 @@ We have 10 batches of 512 length. If we parallelize them by attribute dimension It is similar with tensor model parallelism or naive layer-wise model parallelism. -![flex-flow-soap](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/parallelism-flexflow.jpeg) - -The significance of this framework is that it takes resources like (1) GPU/TPU/CPU vs. (2) RAM/DRAM vs. (3) fast-intra-connect/slow-inter-connect and it automatically optimizes all these algorithmically deciding which parallelisation to use where. - -One very important aspect is that FlexFlow is designed for optimizing DNN parallelizations for models with static and fixed workloads, since models with dynamic behavior may prefer different parallelization strategies across iterations. - -So the promise is very attractive - it runs a 30min simulation on the cluster of choice and it comes up with the best strategy to utilise this specific environment. If you add/remove/replace any parts it'll run and re-optimize the plan for that. And then you can train. A different setup will have its own custom optimization. +
+ flex-flow-soap +
-🤗 Transformers status: not yet integrated. We already have our models FX-trace-able via [transformers.utils.fx](https://github.com/huggingface/transformers/blob/master/src/transformers/utils/fx.py), which is a prerequisite for FlexFlow, so someone needs to figure out what needs to be done to make FlexFlow work with our models. +The significance of this framework is that it takes resources like (1) GPU/TPU/CPU vs. (2) RAM/DRAM vs. (3) +fast-intra-connect/slow-inter-connect and it automatically optimizes all these algorithmically deciding which +parallelisation to use where. +One very important aspect is that FlexFlow is designed for optimizing DNN parallelizations for models with static and +fixed workloads, since models with dynamic behavior may prefer different parallelization strategies across iterations. -## Which Strategy To Use When - -Here is a very rough outline at which parallelism strategy to use when. The first on each list is typically faster. - -**⇨ Single GPU** - -* Model fits onto a single GPU: - - 1. Normal use - -* Model doesn't fit onto a single GPU: - - 1. ZeRO + Offload CPU and optionally NVMe - 2. as above plus Memory Centric Tiling (see below for details) if the largest layer can't fit into a single GPU - -* Largest Layer not fitting into a single GPU: - -1. ZeRO - Enable [Memory Centric Tiling](https://deepspeed.readthedocs.io/en/latest/zero3.html#memory-centric-tiling) (MCT). It allows you to run arbitrarily large layers by automatically splitting them and executing them sequentially. MCT reduces the number of parameters that are live on a GPU, but it does not affect the activation memory. As this need is very rare as of this writing a manual override of `torch.nn.Linear` needs to be done by the user. - -**⇨ Single Node / Multi-GPU** - -* Model fits onto a single GPU: - - 1. DDP - Distributed DP - 2. ZeRO - may or may not be faster depending on the situation and configuration used - -* Model doesn't fit onto a single GPU: - - 1. PP - 2. ZeRO - 3. TP - - With very fast intra-node connectivity of NVLINK or NVSwitch all three should be mostly on par, without these PP will be faster than TP or ZeRO. The degree of TP may also make a difference. Best to experiment to find the winner on your particular setup. - - TP is almost always used within a single node. That is TP size <= gpus per node. - -* Largest Layer not fitting into a single GPU: - - 1. If not using ZeRO - must use TP, as PP alone won't be able to fit. - 2. With ZeRO see the same entry for "Single GPU" above - - -**⇨ Multi-Node / Multi-GPU** - -* When you have fast inter-node connectivity: - - 1. ZeRO - as it requires close to no modifications to the model - 2. PP+TP+DP - less communications, but requires massive changes to the model - -* when you have slow inter-node connectivity and still low on GPU memory: +So the promise is very attractive - it runs a 30min simulation on the cluster of choice and it comes up with the best +strategy to utilise this specific environment. If you add/remove/replace any parts it'll run and re-optimize the plan +for that. And then you can train. A different setup will have its own custom optimization. - 1. DP+PP+TP+ZeRO-1 +🤗 Transformers status: Transformers models are FX-trace-able via [transformers.utils.fx](https://github.com/huggingface/transformers/blob/master/src/transformers/utils/fx.py), +which is a prerequisite for FlexFlow, however, changes are required on the FlexFlow side to make it work with Transformers models. From a0fd34483febc7f3bda04f1e5ce1cea98221e847 Mon Sep 17 00:00:00 2001 From: "JB (Don)" <1557853+hackyon@users.noreply.github.com> Date: Wed, 25 Oct 2023 01:26:16 +0800 Subject: [PATCH 31/38] Add a default decoder_attention_mask for EncoderDecoderModel during training (#26752) * Add a default decoder_attention_mask for EncoderDecoderModel during training Since we are already creating the default decoder_input_ids from the labels, we should also create a default decoder_attention_mask to go with it. * Fix test constant that relied on manual_seed() The test was changed to use a decoder_attention_mask that ignores padding instead (which is the default one created by BERT when attention_mask is None). * Create the decoder_attention_mask using decoder_input_ids instead of labels * Fix formatting in test --- .../modeling_encoder_decoder.py | 2 + .../test_modeling_encoder_decoder.py | 54 ++++++++++++++++++- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 3548e48c595a4a..787db7272642dd 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -620,6 +620,8 @@ def forward( decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) + if decoder_attention_mask is None: + decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id) # Decode decoder_outputs = self.decoder( diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py index c476744057e89d..25444d7d32ffa6 100644 --- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py @@ -17,8 +17,8 @@ import tempfile import unittest -from transformers import is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device +from transformers import is_torch_available, logging +from transformers.testing_utils import CaptureLogger, require_torch, slow, torch_device from ...test_modeling_common import ids_tensor from ..bart.test_modeling_bart import BartStandaloneDecoderModelTester @@ -766,6 +766,56 @@ def test_bert2bert_summarization(self): self.assertEqual(summary, [EXPECTED_SUMMARY_SIGMA, EXPECTED_SUMMARY_AMERICA]) + def test_bert2bert_default_decoder_attention_mask(self): + torch.manual_seed(0) + test_dict = self.prepare_config_and_inputs() + encoder_config, decoder_config = test_dict["config"], test_dict["decoder_config"] + + encoder_config.pad_token_id = 5 + encoder_config.decoder_start_token_id = 2 + decoder_config.pad_token_id = 5 + decoder_config.decoder_start_token_id = 2 + + config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config) + config.pad_token_id = 5 + config.decoder_start_token_id = 2 + + encoder_model, decoder_model = self.get_encoder_decoder_model(encoder_config, decoder_config) + model = EncoderDecoderModel(config=config, encoder=encoder_model, decoder=decoder_model) + + input_ids = torch.tensor( + [ + [10, 55, 89, 11, 57, 32, 36, 78, 46, 28, 5, 5, 5], + [10, 21, 97, 71, 63, 19, 12, 57, 5, 5, 5, 5, 5], + ] + ) + attention_mask = input_ids.new_tensor(input_ids != 5) + labels = torch.tensor( + [ + [33, 23, 91, 12, 19, 96, 5, 5], + [87, 85, 13, 31, 5, 5, 5, 5], + ] + ) + + logger = logging.get_logger("transformers.modeling_utils") + logger.warning_once.cache_clear() + + with CaptureLogger(logger) as cl: + torch.manual_seed(0) + output = model(input_ids, attention_mask, labels=labels) + + # Assert that the warning does not show up since a default decoder_attention_mask should have been created. + self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out) + + # Create a new attention mask that ignores padding, and test that the loss differs for this new attention mask + # and the default attention mask. + attention_mask_ignoring_padding = torch.ones(labels.shape, dtype=torch.long) + torch.manual_seed(0) + ignore_pad_tokens_output = model( + input_ids, attention_mask, labels=labels, decoder_attention_mask=attention_mask_ignoring_padding + ) + self.assertNotAlmostEqual(output.loss.item(), ignore_pad_tokens_output.loss.item()) + @require_torch class BertGenerationEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): From 6cbc1369a330860c128a1ba365f246751382c9e5 Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Tue, 24 Oct 2023 19:37:09 +0200 Subject: [PATCH 32/38] Fix RoPE config validation for FalconConfig + various config typos (#26929) * Resolve incorrect ValueError in RoPE config for Falcon * Add broken codeblock tag in Falcon Config * Fix typo: an float -> a float * Implement copy functionality for Fuyu and Persimmon for RoPE scaling validation * Make style --- .../deprecated/open_llama/configuration_open_llama.py | 6 +++--- src/transformers/models/falcon/configuration_falcon.py | 8 ++++---- src/transformers/models/fuyu/configuration_fuyu.py | 7 ++++--- .../models/gpt_neox/configuration_gpt_neox.py | 6 +++--- src/transformers/models/llama/configuration_llama.py | 6 +++--- .../models/persimmon/configuration_persimmon.py | 7 ++++--- 6 files changed, 21 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/deprecated/open_llama/configuration_open_llama.py b/src/transformers/models/deprecated/open_llama/configuration_open_llama.py index 93e1394ab6d9d6..1d5756cd38a399 100644 --- a/src/transformers/models/deprecated/open_llama/configuration_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/configuration_open_llama.py @@ -69,8 +69,8 @@ class OpenLlamaConfig(PretrainedConfig): Whether to tie weight embeddings rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an @@ -164,4 +164,4 @@ def _rope_scaling_validation(self): f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/falcon/configuration_falcon.py b/src/transformers/models/falcon/configuration_falcon.py index e43340275771c6..251dd7f2d571c9 100644 --- a/src/transformers/models/falcon/configuration_falcon.py +++ b/src/transformers/models/falcon/configuration_falcon.py @@ -79,8 +79,8 @@ class FalconConfig(PretrainedConfig): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an @@ -173,7 +173,7 @@ def _rope_scaling_validation(self): if self.rope_scaling is None: return - if self.rotary: + if self.alibi: raise ValueError("`rope_scaling` is not supported when `alibi` is `True`.") if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: @@ -188,4 +188,4 @@ def _rope_scaling_validation(self): f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/fuyu/configuration_fuyu.py b/src/transformers/models/fuyu/configuration_fuyu.py index b031fff45a7e33..d3c447469c7175 100644 --- a/src/transformers/models/fuyu/configuration_fuyu.py +++ b/src/transformers/models/fuyu/configuration_fuyu.py @@ -72,8 +72,8 @@ class FuyuConfig(PretrainedConfig): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalFuyu/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an @@ -189,6 +189,7 @@ def __init__( **kwargs, ) + # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. @@ -208,4 +209,4 @@ def _rope_scaling_validation(self): f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py index 896bda51317713..e263a8a9fee66c 100644 --- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -80,8 +80,8 @@ class GPTNeoXConfig(PretrainedConfig): speedup at large scales (e.g. 20B). rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an @@ -173,4 +173,4 @@ def _rope_scaling_validation(self): f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index f3da8ab4cdc242..f37e52ca9ca539 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -87,8 +87,8 @@ class LlamaConfig(PretrainedConfig): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an @@ -184,4 +184,4 @@ def _rope_scaling_validation(self): f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/persimmon/configuration_persimmon.py b/src/transformers/models/persimmon/configuration_persimmon.py index 8606e4febffe80..c48070a9da0cbb 100644 --- a/src/transformers/models/persimmon/configuration_persimmon.py +++ b/src/transformers/models/persimmon/configuration_persimmon.py @@ -65,8 +65,8 @@ class PersimmonConfig(PretrainedConfig): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This @@ -141,6 +141,7 @@ def __init__( **kwargs, ) + # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. @@ -160,4 +161,4 @@ def _rope_scaling_validation(self): f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") From 9286f0ac3939a7081773fc66480f651a7d6a8404 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Wed, 25 Oct 2023 10:47:33 +0200 Subject: [PATCH 33/38] Skip-test (#27062) * skip plbart test * nits * update --- tests/models/plbart/test_modeling_plbart.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/plbart/test_modeling_plbart.py b/tests/models/plbart/test_modeling_plbart.py index 4cd8ecd14fe875..4fbbb6835bfa9b 100644 --- a/tests/models/plbart/test_modeling_plbart.py +++ b/tests/models/plbart/test_modeling_plbart.py @@ -314,6 +314,10 @@ def test_generate_fp16(self): model.generate(input_ids, attention_mask=attention_mask) model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) + @unittest.skip("Failing since #26752") + def test_sample_generate(self): + pass + def assert_tensors_close(a, b, atol=1e-12, prefix=""): """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" From 06e782da4e58f93a60c6bedc84b5991abaae58f5 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 25 Oct 2023 12:16:15 +0200 Subject: [PATCH 34/38] [`core`] Refactor of `gradient_checkpointing` (#27020) * v1 * fix * remove `create_custom_forward` * fixup * fixup * add test and fix all failing GC tests * remove all remaining `create_custom_forward` methods * fix idefics bug * fixup * replace with `__call__` * add comment * quality --- src/transformers/modeling_utils.py | 22 +++++++-- .../models/align/modeling_align.py | 20 ++++---- .../models/altclip/modeling_altclip.py | 33 +++++-------- .../modeling_audio_spectrogram_transformer.py | 17 +++---- .../models/autoformer/modeling_autoformer.py | 31 ++++-------- src/transformers/models/bark/modeling_bark.py | 19 +++----- src/transformers/models/bart/modeling_bart.py | 31 ++++-------- src/transformers/models/beit/modeling_beit.py | 17 +++---- src/transformers/models/bert/modeling_bert.py | 18 +++---- .../modeling_bert_generation.py | 18 +++---- .../models/big_bird/modeling_big_bird.py | 18 +++---- .../modeling_bigbird_pegasus.py | 31 ++++-------- .../models/biogpt/modeling_biogpt.py | 19 +++----- src/transformers/models/bit/modeling_bit.py | 5 +- .../models/blenderbot/modeling_blenderbot.py | 31 ++++-------- .../modeling_blenderbot_small.py | 31 ++++-------- src/transformers/models/blip/modeling_blip.py | 21 ++++---- .../models/blip/modeling_blip_text.py | 13 ++--- .../models/blip_2/modeling_blip_2.py | 34 +++++-------- .../models/bloom/modeling_bloom.py | 19 +++----- .../bridgetower/modeling_bridgetower.py | 13 ++--- src/transformers/models/bros/modeling_bros.py | 12 ++--- .../models/camembert/modeling_camembert.py | 18 +++---- .../models/canine/modeling_canine.py | 17 +++---- .../chinese_clip/modeling_chinese_clip.py | 30 ++++-------- src/transformers/models/clap/modeling_clap.py | 29 ++++------- src/transformers/models/clip/modeling_clip.py | 17 +++---- .../models/clipseg/modeling_clipseg.py | 17 +++---- .../models/codegen/modeling_codegen.py | 19 +++----- .../modeling_conditional_detr.py | 16 ++----- .../models/convbert/modeling_convbert.py | 17 +++---- .../models/convnext/modeling_convnext.py | 5 +- .../models/convnextv2/modeling_convnextv2.py | 5 +- .../models/cpmant/modeling_cpmant.py | 5 +- .../data2vec/modeling_data2vec_audio.py | 28 ++++------- .../models/data2vec/modeling_data2vec_text.py | 18 +++---- .../data2vec/modeling_data2vec_vision.py | 17 +++---- .../models/deberta/modeling_deberta.py | 17 +++---- .../models/deberta_v2/modeling_deberta_v2.py | 17 +++---- .../modeling_decision_transformer.py | 19 +++----- .../modeling_deformable_detr.py | 16 ++----- src/transformers/models/deit/modeling_deit.py | 17 +++---- .../models/deprecated/mctct/modeling_mctct.py | 17 +++---- .../open_llama/modeling_open_llama.py | 19 +++----- .../modeling_trajectory_transformer.py | 16 ++----- .../models/deprecated/van/modeling_van.py | 5 +- src/transformers/models/deta/modeling_deta.py | 16 ++----- src/transformers/models/detr/modeling_detr.py | 16 ++----- .../models/dinat/modeling_dinat.py | 2 +- .../models/dinov2/modeling_dinov2.py | 17 +++---- .../models/distilbert/modeling_distilbert.py | 17 +++---- .../models/donut/modeling_donut_swin.py | 16 ++----- src/transformers/models/dpr/modeling_dpr.py | 5 +- src/transformers/models/dpt/modeling_dpt.py | 17 +++---- .../efficientnet/modeling_efficientnet.py | 5 +- .../models/electra/modeling_electra.py | 18 +++---- .../models/encodec/modeling_encodec.py | 5 +- .../modeling_encoder_decoder.py | 6 +-- .../models/ernie/modeling_ernie.py | 18 +++---- .../models/ernie_m/modeling_ernie_m.py | 5 +- src/transformers/models/esm/modeling_esm.py | 18 +++---- .../models/falcon/modeling_falcon.py | 20 ++++---- .../models/flava/modeling_flava.py | 17 +++---- src/transformers/models/fnet/modeling_fnet.py | 14 ++---- .../models/focalnet/modeling_focalnet.py | 16 ++----- src/transformers/models/fuyu/modeling_fuyu.py | 5 +- src/transformers/models/git/modeling_git.py | 30 ++++-------- src/transformers/models/gpt2/modeling_gpt2.py | 20 +++----- .../gpt_bigcode/modeling_gpt_bigcode.py | 19 +++----- .../models/gpt_neo/modeling_gpt_neo.py | 19 +++----- .../models/gpt_neox/modeling_gpt_neox.py | 20 ++++---- .../modeling_gpt_neox_japanese.py | 5 +- src/transformers/models/gptj/modeling_gptj.py | 19 +++----- .../modeling_gptsan_japanese.py | 5 +- .../models/graphormer/modeling_graphormer.py | 5 +- .../models/groupvit/modeling_groupvit.py | 18 +++---- .../models/hubert/modeling_hubert.py | 42 +++++----------- .../models/idefics/modeling_idefics.py | 12 ++--- src/transformers/models/idefics/vision.py | 12 ++--- .../models/imagegpt/modeling_imagegpt.py | 19 +++----- .../models/informer/modeling_informer.py | 33 +++++-------- .../instructblip/modeling_instructblip.py | 34 +++++-------- .../models/layoutlm/modeling_layoutlm.py | 18 +++---- .../models/layoutlmv2/modeling_layoutlmv2.py | 17 +++---- .../models/layoutlmv3/modeling_layoutlmv3.py | 15 +----- src/transformers/models/led/modeling_led.py | 32 +++++-------- .../models/levit/modeling_levit.py | 5 +- src/transformers/models/lilt/modeling_lilt.py | 17 +++---- .../models/llama/modeling_llama.py | 23 +++++---- .../models/longformer/modeling_longformer.py | 18 +++---- .../models/longt5/modeling_longt5.py | 20 +++----- src/transformers/models/luke/modeling_luke.py | 17 +++---- .../models/m2m_100/modeling_m2m_100.py | 31 ++++-------- .../models/marian/modeling_marian.py | 31 ++++-------- .../models/markuplm/modeling_markuplm.py | 13 ++--- .../mask2former/modeling_mask2former.py | 12 ++--- .../models/maskformer/modeling_maskformer.py | 20 ++++---- .../maskformer/modeling_maskformer_swin.py | 19 ++++---- .../models/mbart/modeling_mbart.py | 33 +++++-------- .../megatron_bert/modeling_megatron_bert.py | 18 +++---- .../models/mgp_str/modeling_mgp_str.py | 5 +- .../models/mistral/modeling_mistral.py | 20 ++++---- .../models/mobilevit/modeling_mobilevit.py | 16 ++----- .../mobilevitv2/modeling_mobilevitv2.py | 16 ++----- src/transformers/models/mpt/modeling_mpt.py | 19 +++----- src/transformers/models/mra/modeling_mra.py | 16 ++----- src/transformers/models/mt5/modeling_mt5.py | 16 +++---- .../models/musicgen/modeling_musicgen.py | 23 ++++----- src/transformers/models/mvp/modeling_mvp.py | 31 ++++-------- src/transformers/models/nat/modeling_nat.py | 2 +- .../models/nezha/modeling_nezha.py | 18 +++---- .../models/nllb_moe/modeling_nllb_moe.py | 28 ++++------- .../nystromformer/modeling_nystromformer.py | 17 +++---- .../models/oneformer/modeling_oneformer.py | 2 +- src/transformers/models/opt/modeling_opt.py | 19 +++----- .../models/owlv2/modeling_owlv2.py | 18 +++---- .../models/owlvit/modeling_owlvit.py | 17 +++---- .../models/pegasus/modeling_pegasus.py | 31 ++++-------- .../models/pegasus_x/modeling_pegasus_x.py | 31 ++++-------- .../models/persimmon/modeling_persimmon.py | 19 +++----- .../models/pix2struct/modeling_pix2struct.py | 35 +++++--------- .../models/plbart/modeling_plbart.py | 31 ++++-------- .../models/poolformer/modeling_poolformer.py | 5 +- .../models/pop2piano/modeling_pop2piano.py | 16 +++---- .../models/prophetnet/modeling_prophetnet.py | 31 ++++-------- src/transformers/models/pvt/modeling_pvt.py | 5 +- .../models/qdqbert/modeling_qdqbert.py | 18 +++---- .../models/realm/modeling_realm.py | 13 ++--- .../models/regnet/modeling_regnet.py | 5 +- .../models/rembert/modeling_rembert.py | 18 +++---- .../models/resnet/modeling_resnet.py | 5 +- .../models/roberta/modeling_roberta.py | 18 +++---- .../modeling_roberta_prelayernorm.py | 18 +++---- .../models/roc_bert/modeling_roc_bert.py | 18 +++---- .../models/roformer/modeling_roformer.py | 18 +++---- src/transformers/models/rwkv/modeling_rwkv.py | 17 ++----- src/transformers/models/sam/modeling_sam.py | 11 +---- .../seamless_m4t/modeling_seamless_m4t.py | 42 +++++----------- src/transformers/models/sew/modeling_sew.py | 28 ++++------- .../models/sew_d/modeling_sew_d.py | 30 ++++-------- .../modeling_speech_encoder_decoder.py | 6 +-- .../speech_to_text/modeling_speech_to_text.py | 31 ++++-------- .../modeling_speech_to_text_2.py | 17 ++----- .../models/speecht5/modeling_speecht5.py | 48 +++++-------------- .../models/splinter/modeling_splinter.py | 18 +++---- .../swiftformer/modeling_swiftformer.py | 5 +- src/transformers/models/swin/modeling_swin.py | 16 ++----- .../models/swin/modeling_tf_swin.py | 5 -- .../models/swin2sr/modeling_swin2sr.py | 16 ++----- .../models/swinv2/modeling_swinv2.py | 16 ++----- .../modeling_switch_transformers.py | 16 +++---- src/transformers/models/t5/modeling_t5.py | 16 +++---- .../modeling_table_transformer.py | 16 ++----- .../models/tapas/modeling_tapas.py | 18 +++---- .../modeling_time_series_transformer.py | 31 ++++-------- .../timesformer/modeling_timesformer.py | 17 +++---- .../models/trocr/modeling_trocr.py | 19 +++----- src/transformers/models/tvlt/modeling_tvlt.py | 31 ++++-------- src/transformers/models/umt5/modeling_umt5.py | 16 +++---- .../models/unispeech/modeling_unispeech.py | 40 +++++----------- .../unispeech_sat/modeling_unispeech_sat.py | 40 +++++----------- .../models/upernet/modeling_upernet.py | 5 +- .../models/videomae/modeling_videomae.py | 31 ++++-------- src/transformers/models/vilt/modeling_vilt.py | 17 +++---- .../modeling_vision_encoder_decoder.py | 6 +-- .../visual_bert/modeling_visual_bert.py | 17 +++---- src/transformers/models/vit/modeling_vit.py | 17 +++---- .../models/vit_hybrid/modeling_vit_hybrid.py | 17 +++---- .../models/vit_mae/modeling_vit_mae.py | 31 ++++-------- .../models/vit_msn/modeling_vit_msn.py | 17 +++---- .../models/vitdet/modeling_vitdet.py | 17 +++---- .../models/vitmatte/modeling_vitmatte.py | 10 +++- src/transformers/models/vits/modeling_vits.py | 19 +++----- .../models/vivit/modeling_vivit.py | 17 +++---- .../models/wav2vec2/modeling_wav2vec2.py | 40 +++++----------- .../modeling_wav2vec2_conformer.py | 28 ++++------- .../models/wavlm/modeling_wavlm.py | 40 +++++----------- .../models/whisper/modeling_whisper.py | 31 ++++-------- .../models/x_clip/modeling_x_clip.py | 29 ++++------- src/transformers/models/xglm/modeling_xglm.py | 19 +++----- .../xlm_prophetnet/modeling_xlm_prophetnet.py | 31 ++++-------- .../xlm_roberta/modeling_xlm_roberta.py | 18 +++---- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 13 ++--- src/transformers/models/xmod/modeling_xmod.py | 18 +++---- .../models/yolos/modeling_yolos.py | 17 +++---- src/transformers/models/yoso/modeling_yoso.py | 17 +++---- ...ng_{{cookiecutter.lowercase_modelname}}.py | 47 +++++++----------- tests/test_modeling_common.py | 21 ++++++++ 188 files changed, 1276 insertions(+), 2296 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6f866f989a6777..47e9cb2f23e0f1 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections +import functools import gc import importlib.metadata import inspect @@ -1848,16 +1849,31 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]): self.base_model._prune_heads(heads_to_prune) - def gradient_checkpointing_enable(self): + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): """ Activates gradient checkpointing for the current model. Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint activations". + + We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of + the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + + Args: + gradient_checkpointing_kwargs (dict, *optional*): + Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. """ if not self.supports_gradient_checkpointing: raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") - self.apply(partial(self._set_gradient_checkpointing, value=True)) + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} + + gradient_checkpointing_func = functools.partial( + torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs + ) + + self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=gradient_checkpointing_func)) if getattr(self, "_hf_peft_config_loaded", False): # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True @@ -1874,7 +1890,7 @@ def gradient_checkpointing_disable(self): activations". """ if self.supports_gradient_checkpointing: - self.apply(partial(self._set_gradient_checkpointing, value=False)) + self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None)) if getattr(self, "_hf_peft_config_loaded", False): self.disable_input_require_grads() diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 6cbf01a3432ccb..58dc2a89200930 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -1095,20 +1095,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -1197,9 +1192,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (AlignTextModel, AlignVisionModel)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (AlignTextModel, AlignVisionModel, AlignTextEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index c4e32de55d9c03..e6229165aace86 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -646,20 +646,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -960,18 +955,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1089,11 +1078,13 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, AltCLIPEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None if isinstance(module, AltRobertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 28969f50b67291..a1f85e2a09eb12 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -336,17 +336,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -395,9 +389,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.weight.data.fill_(1.0) # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST - def _set_gradient_checkpointing(self, module: ASTEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ASTEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ASTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 96298c77a344e7..40e3002310852b 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -946,9 +946,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (AutoformerDecoder, AutoformerEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None AUTOFORMER_START_DOCSTRING = r""" @@ -1207,18 +1208,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1425,16 +1420,8 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1442,6 +1429,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 649719e0eefa5d..2708b00d05c49c 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -313,9 +313,10 @@ def device(self) -> torch.device: return get_parameter_device(self) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BarkCausalModel) or isinstance(module, BarkFineModel) or isinstance(module, BarkModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BARK_MODEL_START_DOCSTRING = """ @@ -637,20 +638,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 9e7763ca23d885..73eca72e5d1288 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -521,9 +521,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BartDecoder, BartEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -854,18 +855,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1110,16 +1105,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1127,6 +1114,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index d698cff88b146e..3ba3d4911b0fb5 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -510,17 +510,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: relative_position_bias = ( @@ -572,9 +566,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BeitEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BEIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 1b0fad3f9d6546..91380e13a05522 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -593,20 +593,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -762,9 +757,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index abe2d828b28bb9..123cb2212e19c6 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -401,20 +401,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -607,9 +602,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BERT_GENERATION_START_DOCSTRING = r""" diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index e266b1a67b7d41..0ba2119e684492 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1617,15 +1617,8 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, @@ -1635,6 +1628,8 @@ def custom_forward(*inputs): from_mask, to_mask, blocked_encoder_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -1784,9 +1779,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BigBirdEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BIG_BIRD_START_DOCSTRING = r""" diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 4e279f9dc059fe..98ff51032bad5e 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1609,9 +1609,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BigBirdPegasusDecoder, BigBirdPegasusEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -1943,15 +1944,8 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), @@ -1960,6 +1954,7 @@ def custom_forward(*inputs): to_mask, blocked_encoder_mask, blocked_encoder_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2289,16 +2284,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -2306,6 +2293,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index ca084db5c7d0b9..2bbdbed348a196 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -376,9 +376,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BioGptModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BIOGPT_START_DOCSTRING = r""" @@ -590,20 +591,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py index 12a5ecd42b74cf..d02861d6343da5 100644 --- a/src/transformers/models/bit/modeling_bit.py +++ b/src/transformers/models/bit/modeling_bit.py @@ -669,9 +669,10 @@ def _init_weights(self, module): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BitModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 1db81905210b63..51a947af0a8324 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -483,9 +483,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BlenderbotDecoder, BlenderbotEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -777,18 +778,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1032,16 +1027,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1049,6 +1036,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 129de3dd1456e3..88a9b52de90966 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -480,9 +480,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (BlenderbotSmallDecoder, BlenderbotSmallEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -775,18 +776,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1029,16 +1024,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1046,6 +1033,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 9fca7c28a1a07d..efd986299c292b 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -34,7 +34,7 @@ replace_return_docstrings, ) from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig -from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel +from .modeling_blip_text import BlipTextEncoder, BlipTextLMHeadModel, BlipTextModel logger = logging.get_logger(__name__) @@ -461,9 +461,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, BlipEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (BlipEncoder, BlipTextEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None BLIP_START_DOCSTRING = r""" @@ -622,17 +623,11 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 49b958afc2ebae..e0aa4e17f146fe 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -422,20 +422,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index bd56b17e55c21b..2f7f00b3dd5970 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -297,9 +297,14 @@ def _init_weights(self, module): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, Blip2Encoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (Blip2Encoder, Blip2QFormerEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None + + # Enable / disable GC for the language model as well + if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"): + self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func) BLIP_2_START_DOCSTRING = r""" @@ -473,17 +478,11 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -944,15 +943,8 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions, query_length) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index d90bb6ad8fdfd5..583367c9ab5571 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -496,9 +496,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): + def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): if isinstance(module, BloomModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @staticmethod def _convert_to_standard_cache( @@ -761,21 +762,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, alibi, causal_mask, layer_past, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index ce569157b811c2..0f272a21e21da7 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -804,20 +804,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index a8ea8d49195b88..c10f8350567611 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -651,21 +651,15 @@ def forward( "`use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, bbox_pos_emb, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 8d7d279579e3e9..2e0a6c12fe6463 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -524,20 +524,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -625,9 +620,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CamembertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CAMEMBERT_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 657104ad696535..adc875910320eb 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -795,18 +795,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -919,9 +913,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CanineEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CANINE_START_DOCSTRING = r""" diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 7bab0aea6eb95d..ef1c265723b6f0 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -742,9 +742,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ChineseCLIPVisionEncoder) or isinstance(module, ChineseCLIPTextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CHINESE_CLIP_START_DOCSTRING = r""" @@ -909,20 +910,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -1018,16 +1014,10 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 1d17a518838734..025b59ae4b9743 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -939,15 +939,8 @@ def forward( input_dimensions = self.input_resolutions[i] if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -1595,20 +1588,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -1701,9 +1689,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ClapTextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class ClapAudioModel(ClapPreTrainedModel): diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 3a894b9727c92b..56f24c157f831c 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -467,9 +467,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CLIPEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CLIP_START_DOCSTRING = r""" @@ -639,18 +640,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 96f13217aaf821..7a0e5292698393 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -479,9 +479,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CLIPSegEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CLIPSEG_START_DOCSTRING = r""" @@ -648,18 +649,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 05699ef15c722f..9a5509a9ed866e 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -339,9 +339,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CodeGenModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CODEGEN_START_DOCSTRING = r""" @@ -542,21 +543,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, position_ids, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 69937afefce453..01dbf8ecd59cc7 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -1171,9 +1171,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConditionalDetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CONDITIONAL_DETR_START_DOCSTRING = r""" @@ -1518,15 +1519,8 @@ def forward( # apply transformation query_sine_embed = query_sine_embed_before_transformation * pos_transformation if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, object_queries, diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index a6fccf5b72b443..da577a58961430 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -264,9 +264,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConvBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class SeparableConv1D(nn.Module): @@ -632,20 +633,14 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index e6cf336517a563..e11112b5322266 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -296,9 +296,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConvNextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CONVNEXT_START_DOCSTRING = r""" diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py index 3a268c713d502a..f1ff89bb124398 100644 --- a/src/transformers/models/convnextv2/modeling_convnextv2.py +++ b/src/transformers/models/convnextv2/modeling_convnextv2.py @@ -317,9 +317,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ConvNextV2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CONVNEXTV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index 6d2dc596fa65ff..8a6c744ed69ecc 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -556,9 +556,10 @@ def _init_weights(self, module): elif isinstance(module, CpmAntSegmentPositionEmbedding): module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, CpmAntEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None CPMANT_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 4435e9b8d01754..a99b6f3a6dc384 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -293,15 +293,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -593,17 +586,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -761,9 +748,10 @@ def _get_feature_vector_attention_mask( attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Data2VecAudioEncoder, Data2VecAudioFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DATA2VEC_AUDIO_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index a521ccb39aaf0c..507c2fc464d8b7 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -510,20 +510,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -613,9 +608,10 @@ def _init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Data2VecTextEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DATA2VECTEXT_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index f8fe59587af0cc..2742d5ffc37bbe 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -522,17 +522,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: relative_position_bias = ( @@ -585,9 +579,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Data2VecVisionEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DATA2VEC_VISION_START_DOCSTRING = r""" diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 6f6c2af63a672e..65ec497cecd8c0 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -457,20 +457,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + hidden_states = self.gradient_checkpointing_func( + layer_module.__call__, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, + output_attentions, ) else: hidden_states = layer_module( @@ -839,9 +833,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DebertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index eda4f406cb316d..2245ac549ada94 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -501,20 +501,14 @@ def forward( all_hidden_states = all_hidden_states + (output_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - output_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + output_states = self.gradient_checkpointing_func( + layer_module.__call__, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, + output_attentions, ) else: output_states = layer_module( @@ -938,9 +932,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DebertaV2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 8e5053a4160d12..19c2731a50a745 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -469,9 +469,10 @@ def _init_weights(self, module): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DecisionTransformerGPT2Model): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): @@ -631,22 +632,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index ea4555d5ae217b..220fcf0d066003 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1088,9 +1088,10 @@ def _init_weights(self, module): if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DeformableDetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEFORMABLE_DETR_START_DOCSTRING = r""" @@ -1383,15 +1384,8 @@ def forward( all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, encoder_hidden_states, encoder_attention_mask, diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 38c28dbbedc669..6e97e932b533a7 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -357,17 +357,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -415,9 +409,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: DeiTEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: DeiTEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, DeiTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DEIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index eca5ba014e51a6..9e7a73c5880be7 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -504,9 +504,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length, attention_ma attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MCTCTEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MCTCT_START_DOCSTRING = r""" @@ -616,18 +617,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 3ace323e822414..fb1cc7f0fb84d5 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -456,9 +456,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, OpenLlamaModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OPEN_LLAMA_INPUTS_DOCSTRING = r""" @@ -665,20 +666,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, position_ids, None, + output_attentions, + None, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py index 75415dbe77bf07..c9f31c714446e6 100644 --- a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -163,9 +163,10 @@ class TrajectoryTransformerPreTrainedModel(PreTrainedModel): main_input_name = "trajectories" supports_gradient_checkpointing = True - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TrajectoryTransformerModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Embedding)): @@ -550,15 +551,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, layer_past, use_cache, diff --git a/src/transformers/models/deprecated/van/modeling_van.py b/src/transformers/models/deprecated/van/modeling_van.py index 4ef18f54158f91..52c9e1242422c6 100644 --- a/src/transformers/models/deprecated/van/modeling_van.py +++ b/src/transformers/models/deprecated/van/modeling_van.py @@ -387,9 +387,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VanModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VAN_START_DOCSTRING = r""" diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 1aab38c2891384..a6f979eaeea6e1 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -979,9 +979,10 @@ def _init_weights(self, module): if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DetaDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DETA_START_DOCSTRING = r""" @@ -1275,15 +1276,8 @@ def forward( all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, encoder_hidden_states, encoder_attention_mask, diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index d2b6ea07d7b7ca..1c09e3e3d7b213 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -927,9 +927,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DETR_START_DOCSTRING = r""" @@ -1253,15 +1254,8 @@ def forward( continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py index 89c6ed2e2a88e9..eb4d3f2ff29646 100644 --- a/src/transformers/models/dinat/modeling_dinat.py +++ b/src/transformers/models/dinat/modeling_dinat.py @@ -660,7 +660,7 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: DinatEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: DinatEncoder, gradient_checkpointing_func=None) -> None: pass diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 6e4446faddd5e9..1440b6d615fb4a 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -447,17 +447,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -516,9 +510,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: Dinov2Encoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: Dinov2Encoder, gradient_checkpointing_func=None) -> None: if isinstance(module, Dinov2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DINOV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index f26b5846972d64..3768dd6e91ca7e 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -358,18 +358,12 @@ def forward( all_hidden_states = all_hidden_states + (hidden_state,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_state, attn_mask, head_mask[i], + output_attentions, ) else: layer_outputs = layer_module( @@ -430,9 +424,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Transformer): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DISTILBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 0d833406e259e6..76d525717f8cf3 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -749,15 +749,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -826,9 +819,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DonutSwinEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index 944ce142b0ad02..c258343f6cfdc8 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ b/src/transformers/models/dpr/modeling_dpr.py @@ -164,9 +164,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class DPREncoder(DPRPreTrainedModel): diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 187a6c36656a8e..2621fa338015a8 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -528,17 +528,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -818,9 +812,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, DPTViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None DPT_START_DOCSTRING = r""" diff --git a/src/transformers/models/efficientnet/modeling_efficientnet.py b/src/transformers/models/efficientnet/modeling_efficientnet.py index 478aeecee02bc1..d1b2c99403437b 100644 --- a/src/transformers/models/efficientnet/modeling_efficientnet.py +++ b/src/transformers/models/efficientnet/modeling_efficientnet.py @@ -500,9 +500,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, EfficientNetBlock): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index da3ee8e51d3602..fde5632c09c3f2 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -571,20 +571,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -692,9 +687,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ElectraEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/encodec/modeling_encodec.py b/src/transformers/models/encodec/modeling_encodec.py index 697fb3c94fbb1d..28c20da3d5eb15 100644 --- a/src/transformers/models/encodec/modeling_encodec.py +++ b/src/transformers/models/encodec/modeling_encodec.py @@ -473,9 +473,10 @@ def _init_weights(self, module): elif "bias" in name: nn.init.constant_(param, 0.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (EncodecEncoder, EncodecDecoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ENCODEC_START_DOCSTRING = r""" diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 787db7272642dd..a13fd19a900c4f 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -265,10 +265,10 @@ def tie_weights(self): self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_encoder(self): return self.encoder diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index d55155f80093bc..330cb5033160e5 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -506,20 +506,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -680,9 +675,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ErnieEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/ernie_m/modeling_ernie_m.py b/src/transformers/models/ernie_m/modeling_ernie_m.py index 9c53ddd73c8540..b26ee0fcafd19f 100755 --- a/src/transformers/models/ernie_m/modeling_ernie_m.py +++ b/src/transformers/models/ernie_m/modeling_ernie_m.py @@ -429,9 +429,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ErnieMEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ERNIE_M_START_DOCSTRING = r""" diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 7a07495ba7e501..86bd20a46480d0 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -605,20 +605,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -710,9 +705,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, EsmEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ESM_START_DOCSTRING = r""" diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 6eaeed4199771c..642e60a72f91df 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1097,9 +1097,10 @@ def _init_weights(self, module: nn.Module): module.weight.data.fill_(1.0) # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel - def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): + def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): if isinstance(module, FalconModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @staticmethod def _convert_cache_to_standard_format( @@ -1278,21 +1279,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, alibi, attention_mask, position_ids, head_mask[i], + layer_past, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 8de647c8299a09..1fbf49f9e127e6 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -663,18 +663,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -879,9 +873,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: FlavaEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: FlavaEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, FlavaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 45042147761d56..b84761536bac46 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -292,14 +292,7 @@ def forward(self, hidden_states, output_hidden_states=False, return_dict=True): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module), hidden_states) + layer_outputs = self.gradient_checkpointing_func(layer_module.__call__, hidden_states) else: layer_outputs = layer_module(hidden_states) @@ -431,9 +424,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, FNetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 8d18a8c63fda1b..87ec98169626b8 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -586,15 +586,8 @@ def forward( for i, stage_module in enumerate(self.stages): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - stage_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(stage_module), + stage_outputs = self.gradient_checkpointing_func( + stage_module.__call__, hidden_states, input_dimensions, ) @@ -659,9 +652,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, FocalNetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None FOCALNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 03312420ca6e6d..37f9890ee3dd03 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -70,9 +70,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, FuyuForCausalLM): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None FUYU_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 00707e42dd085a..293b9c789d56b9 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -452,18 +452,13 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -533,9 +528,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (GitEncoder, GitVisionEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GIT_START_DOCSTRING = r""" @@ -878,18 +874,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 838e7ca2992520..24826a76bc0446 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -480,9 +480,10 @@ def _init_weights(self, module): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPT2Model): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass @@ -877,22 +878,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( @@ -1623,7 +1618,6 @@ def __init__(self, config): # Model parallel self.model_parallel = False self.device_map = None - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index be90f61e45bf1b..37c51b40c9a78d 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -405,9 +405,10 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->GPTBigCode - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTBigCodeModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GPT_BIGCODE_START_DOCSTRING = r""" @@ -650,22 +651,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 3ad49554c0ac8f..ed1e62bf175fb1 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -384,9 +384,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTNeoModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GPT_NEO_START_DOCSTRING = r""" @@ -604,20 +605,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 9391805a77b851..cf0aa0645ae07f 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -78,9 +78,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTNeoXModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class GPTNeoXAttention(nn.Module): @@ -641,20 +642,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for layer_past - return module(*inputs, use_cache, None, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, position_ids, head_mask[i], + use_cache, + None, + output_attentions, ) else: outputs = layer( diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 98753edeb544f8..c1c5527a465531 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -66,9 +66,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTNeoXJapaneseModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class GPTNeoXJapaneseAttention(nn.Module): diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index acdbb8c49211d5..2910f9535f66ea 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -363,9 +363,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GPTJModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GPTJ_START_DOCSTRING = r""" @@ -669,21 +670,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, position_ids, head_mask[i], + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py index 24917fcfdb075d..84d956c9f57e60 100644 --- a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py @@ -759,9 +759,10 @@ def _init_weights(self, module): module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (GPTSanJapaneseAttention,)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right def _shift_right(self, input_ids): diff --git a/src/transformers/models/graphormer/modeling_graphormer.py b/src/transformers/models/graphormer/modeling_graphormer.py index 8247745a3bc3ef..68ed6d265e70ae 100755 --- a/src/transformers/models/graphormer/modeling_graphormer.py +++ b/src/transformers/models/graphormer/modeling_graphormer.py @@ -772,9 +772,10 @@ def _init_weights( module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, GraphormerModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class GraphormerModel(GraphormerPreTrainedModel): diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 59ff60ed765a51..a9de67143846b7 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -492,7 +492,6 @@ def __init__( self.group_token = nn.Parameter(torch.zeros(1, num_group_token, config.hidden_size)) else: self.group_token = None - self.gradient_checkpointing = False self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(depth)]) if num_group_token > 0: @@ -805,9 +804,10 @@ def _init_weights(self, module): nn.init.normal_(module.fc1.weight, std=fc_std) nn.init.normal_(module.fc2.weight, std=in_proj_std) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (GroupViTTextEncoder, GroupViTVisionEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None GROUPVIT_START_DOCSTRING = r""" @@ -1031,18 +1031,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 1a7bde45efc128..732e6be2f8ddd1 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -346,15 +346,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -731,17 +724,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -821,17 +808,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -895,9 +876,10 @@ def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (HubertEncoder, HubertEncoderStableLayerNorm)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (HubertFeatureEncoder, HubertEncoder, HubertEncoderStableLayerNorm)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 316f36561308f0..28841903a1a3bb 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -40,7 +40,7 @@ ) from .configuration_idefics import IdeficsConfig from .perceiver import IdeficsPerceiverResampler -from .vision import IdeficsVisionTransformer +from .vision import IdeficsVisionEncoder, IdeficsVisionTransformer logger = logging.get_logger(__name__) @@ -978,9 +978,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, IdeficsModel): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (IdeficsModel, IdeficsVisionEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LLAMA_INPUTS_DOCSTRING = r""" @@ -1098,7 +1099,6 @@ def __init__(self, config: IdeficsConfig): self.norm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1339,7 +1339,7 @@ def vblock( ) use_cache = False - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = self.gradient_checkpointing_func( vblock, decoder_layer, hidden_states, diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index d4966a240d84eb..24dc3e9396aa79 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -401,18 +401,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 54edcd30fc870d..a365731ed53d0e 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -525,9 +525,10 @@ def _init_weights(self, module): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ImageGPTModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None IMAGEGPT_START_DOCSTRING = r""" @@ -816,22 +817,16 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index e7b35174ca7e60..53518760cc003b 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -924,9 +924,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (InformerDecoder, InformerEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None INFORMER_START_DOCSTRING = r""" @@ -1215,21 +1216,15 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) if conv_layer is not None: - output = torch.utils.checkpoint.checkpoint(conv_layer, layer_outputs[0]) + output = self.gradient_checkpointing_func(conv_layer, layer_outputs[0]) layer_outputs = (output,) + layer_outputs[1:] else: layer_outputs = encoder_layer( @@ -1438,16 +1433,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1455,6 +1442,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 082900a6652f80..d4cb7a1fa00bcf 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -304,9 +304,14 @@ def _init_weights(self, module): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, InstructBlipEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (InstructBlipEncoder, InstructBlipQFormerEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None + + # Enable / disable GC for the language model as well + if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"): + self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func) INSTRUCTBLIP_START_DOCSTRING = r""" @@ -462,17 +467,11 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -939,15 +938,8 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions, query_length) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 884a2799728b47..ce6d4302bccc19 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -487,20 +487,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -638,9 +633,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LayoutLMEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LAYOUTLM_START_DOCSTRING = r""" diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index ef970edfdc9103..8f6260fdda4931 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -439,18 +439,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, rel_pos=rel_pos, rel_2d_pos=rel_2d_pos, ) @@ -514,9 +508,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LayoutLMv2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def my_convert_sync_batchnorm(module, process_group=None): diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 30ab0a5e8620c3..e387707e52da4f 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -657,19 +657,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - # return module(*inputs, past_key_value, output_attentions, rel_pos, rel_2d_pos) - # The above line will cause error: - # RuntimeError: Trying to backward through the graph a second time - # (or directly access saved tensors after they have already been freed). - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index f0c22ed9502c26..61bbd4156b4603 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1155,9 +1155,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (LEDDecoder, LEDEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -1876,20 +1877,15 @@ def forward( layer_outputs = (None, None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, is_global_attn, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, is_index_masked, is_index_global_attn, + is_global_attn, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2142,16 +2138,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, @@ -2159,6 +2147,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/levit/modeling_levit.py b/src/transformers/models/levit/modeling_levit.py index 0accc28391bde6..5acaaeba90048a 100644 --- a/src/transformers/models/levit/modeling_levit.py +++ b/src/transformers/models/levit/modeling_levit.py @@ -507,9 +507,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LevitModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LEVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 46fe2d3e9cd779..4fd7a85affd76c 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -514,19 +514,13 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layout_inputs, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module( @@ -607,9 +601,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LiltEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LILT_START_DOCSTRING = r""" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 541455d86afde8..279884dc164f04 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -827,9 +827,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LlamaModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LLAMA_INPUTS_DOCSTRING = r""" @@ -1013,16 +1014,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 33bf9a6f92684c..b4f20b4525585e 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1304,20 +1304,15 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, is_global_attn, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, is_index_masked, is_index_global_attn, + is_global_attn, + output_attentions, ) else: layer_outputs = layer_module( @@ -1439,9 +1434,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LongformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LONGFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index c80d2349832c43..9abbfa2f2001f6 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -775,7 +775,6 @@ def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = Fal if self.has_relative_attention_bias: self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) self.pruned_heads = set() - self.gradient_checkpointing = False # Relativen attention bias & Layer norm for global attention if self.has_relative_attention_bias: @@ -1340,10 +1339,10 @@ def _init_weights(self, module): mean=0.0, std=factor * ((d_model) ** -0.5) ) - # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._set_gradient_checkpointing with T5->LongT5 - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (LongT5Attention, LongT5Stack)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (LongT5Attention, LongT5Stack, LongT5LocalAttention)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5 def _shift_right(self, input_ids): @@ -1510,15 +1509,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1528,6 +1520,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index 6913ede09d1c7b..3b5f4d0bf71dc0 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -788,19 +788,13 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, word_hidden_states, entity_hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module( @@ -920,9 +914,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, LukeEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None LUKE_START_DOCSTRING = r""" diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 6db8bbb5213b14..4ebe11f3f3b3bc 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -552,9 +552,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (M2M100Decoder, M2M100Encoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None M2M_100_START_DOCSTRING = r""" @@ -820,18 +821,12 @@ def forward( # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1066,16 +1061,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, @@ -1083,6 +1070,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 69de5b2e7d0e6f..e2e09b564b0e48 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -500,9 +500,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalP if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MarianDecoder, MarianEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -788,18 +789,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1037,16 +1032,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1054,6 +1041,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 530c66a0c80b36..80498efb3cadd2 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -648,20 +648,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index e839b16f625777..86eccc478753f3 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -1864,20 +1864,14 @@ def forward( continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, None, None, + output_attentions, ) else: diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 87b91ed64b62d3..7df8b60792a054 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -848,20 +848,14 @@ def forward( continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, encoder_attention_mask, None, + output_attentions, ) else: layer_outputs = decoder_layer( @@ -1619,11 +1613,13 @@ def _init_weights(self, module: nn.Module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MaskFormerPixelLevelModule): - module.encoder.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.encoder.gradient_checkpointing = gradient_checkpointing_func is not None if isinstance(module, DetrDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 357ac9d4aaca36..89c6a0c0e0b4c1 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -688,15 +688,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, layer_head_mask + layer_hidden_states, output_dimensions, layer_all_hidden_states = self.gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, ) else: layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module( @@ -752,9 +748,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MaskFormerSwinEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel): diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index b53ad8848dd3c2..7c4c9bdf959801 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -516,9 +516,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (MBartDecoder, MBartDecoder)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (MBartDecoder, MBartEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -828,18 +829,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1086,16 +1081,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1103,6 +1090,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 5d0ad6e3410c8f..c23666f10b725f 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -551,20 +551,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -728,9 +723,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MegatronBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/mgp_str/modeling_mgp_str.py b/src/transformers/models/mgp_str/modeling_mgp_str.py index 5d1f5bea7bfd35..1257b4df39c015 100644 --- a/src/transformers/models/mgp_str/modeling_mgp_str.py +++ b/src/transformers/models/mgp_str/modeling_mgp_str.py @@ -333,9 +333,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: MgpstrEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: MgpstrEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, MgpstrEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MGP_STR_START_DOCSTRING = r""" diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 53667b6a82c327..fbedb20dbc1539 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -816,9 +816,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MistralModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MISTRAL_INPUTS_DOCSTRING = r""" @@ -1020,19 +1021,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, position_ids, + past_key_value, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index c3accb21e05e42..c664c02a883ba0 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -626,15 +626,8 @@ def forward( for i, layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + hidden_states = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, ) else: @@ -672,9 +665,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MobileViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MOBILEVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index 5a0e08d7344dc7..b88925f41b83c6 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -582,15 +582,8 @@ def forward( for i, layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + hidden_states = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, ) else: @@ -629,9 +622,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MobileViTV2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MOBILEVITV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index d760bec9854a8e..ede306e71b867e 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -294,9 +294,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): + def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): if isinstance(module, MptModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @staticmethod def _convert_to_mpt_cache( @@ -523,20 +524,14 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + outputs = self.gradient_checkpointing_func( + block.__call__, hidden_states, alibi, causal_mask, layer_past, + use_cache, + output_attentions, ) else: outputs = block( diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index d400fea6d23dda..f6cb65889a371c 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -766,15 +766,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, ) @@ -871,9 +864,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MraEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MRA_START_DOCSTRING = r""" diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 186db94dad7f20..2951ffc889dcdb 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -845,9 +845,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MT5Attention, MT5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1073,15 +1074,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1091,6 +1085,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index bcc6bc82a2f5f4..a740ed47074bc2 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -475,9 +475,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, MusicgenDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None MUSICGEN_START_DOCSTRING = r""" @@ -826,16 +827,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, attention_mask, encoder_hidden_states, @@ -843,6 +836,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( @@ -1562,10 +1557,10 @@ def tie_weights(self): self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.text_encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.text_encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_audio_encoder(self): return self.audio_encoder diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 5c1ed05249ef5c..122b49287872ae 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -563,9 +563,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (MvpDecoder, MvpEncoder, MvpPrompt)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @property def dummy_inputs(self): @@ -949,19 +950,13 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), (self_attn_prompt[idx] if self.use_prompt else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1227,16 +1222,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1246,6 +1233,8 @@ def custom_forward(*inputs): self_attn_prompt[idx] if self.use_prompt else None, cross_attn_prompt[idx] if self.use_prompt else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/nat/modeling_nat.py b/src/transformers/models/nat/modeling_nat.py index ecc745b558dd71..4f7206a5e8ecf5 100644 --- a/src/transformers/models/nat/modeling_nat.py +++ b/src/transformers/models/nat/modeling_nat.py @@ -639,7 +639,7 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: NatEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: NatEncoder, gradient_checkpointing_func=None) -> None: pass diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py index fa31e94f4d2e6b..cd43688e3f741a 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -577,20 +577,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -752,9 +747,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, NezhaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index a88d53a340f4f6..cbed1e1b153011 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -874,9 +874,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (NllbMoeDecoder, NllbMoeEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None NLLB_MOE_START_DOCSTRING = r""" @@ -1153,18 +1154,12 @@ def forward( layer_outputs = (None, None, None) else: if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1426,15 +1421,8 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(decoder_layer), + decoder_layer.forward, hidden_states, combined_attention_mask, encoder_hidden_states, @@ -1442,6 +1430,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 51ee73ab72d317..9b2052eb6ca4ae 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -370,17 +370,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) @@ -477,9 +471,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, NystromformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None NYSTROMFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 5b6220f8816949..165684542d859d 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -2616,7 +2616,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor): for layer in self.layers: if self.use_checkpoint: - hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states) + hidden_states = self.gradient_checkpointing_func(layer, hidden_states) else: hidden_states = layer(hidden_states) return hidden_states diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 8f3f246524348d..9925e7b4a46b4a 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -411,9 +411,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (OPTDecoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OPT_INPUTS_DOCSTRING = r""" @@ -691,20 +692,14 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, causal_attention_mask, head_mask[idx] if head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 451cc4a69126a5..a1491d15ea5527 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -584,9 +584,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Owlv2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OWLV2_START_DOCSTRING = r""" @@ -764,18 +765,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1378,6 +1373,7 @@ def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor): def objectness_predictor(self, image_features: torch.FloatTensor) -> torch.FloatTensor: """Predicts the probability that each image feature token is an object. + Args: image_features (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_dim)`)): Features extracted from the image. diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index 66cfb8092df5df..68037d13950ed6 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -576,9 +576,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, OwlViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None OWLVIT_START_DOCSTRING = r""" @@ -753,18 +754,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 55856f7b06b6be..058ecd1775a9bf 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -500,9 +500,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (PegasusDecoder, PegasusEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PEGASUS_START_DOCSTRING = r""" @@ -803,18 +804,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1087,16 +1082,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1104,6 +1091,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index e87e9c7164ab44..6eaddf642a8b6b 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -780,9 +780,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (PegasusXDecoder, PegasusXEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PEGASUS_X_START_DOCSTRING = r""" @@ -1071,18 +1072,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, global_hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1331,21 +1326,15 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index d73cc4484484fa..8043fc8699a655 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -467,9 +467,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, PersimmonModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PERSIMMON_INPUTS_DOCSTRING = r""" @@ -668,19 +669,13 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, position_ids, + past_key_value, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 58041820c1fb83..cfc2b137c579cf 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -343,18 +343,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -563,9 +557,10 @@ def __init__(self, config: Pix2StructConfig): # Initialize weights and apply final processing self.post_init() - def _set_gradient_checkpointing(self, module: Pix2StructVisionEncoder, value: bool = False) -> None: - if isinstance(module, Pix2StructVisionEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module: Pix2StructVisionEncoder, gradient_checkpointing_func=None) -> None: + if isinstance(module, (Pix2StructVisionEncoder, Pix2StructVisionAttention)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def get_input_embeddings(self): return self.embeddings.patch_projection @@ -1320,9 +1315,10 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] supports_gradient_checkpointing = True - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Pix2StructTextAttention, Pix2StructTextModel)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def __init__(self, config): super().__init__(config) @@ -1495,15 +1491,8 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1513,6 +1502,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 3a880839236d43..1e047fd372676e 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -517,9 +517,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (PLBartDecoder, PLBartEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PLBART_START_DOCSTRING = r""" @@ -807,18 +808,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1064,16 +1059,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1081,6 +1068,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index 6acc8ec98e6939..209533e31990e2 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -282,9 +282,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, PoolFormerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None POOLFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 5a67b8044b0999..5cf7039e9f0c2e 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -739,9 +739,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Pop2PianoAttention, Pop2PianoStack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -902,15 +903,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -920,6 +914,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 241a9efea36aaf..e4c28659cb489b 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -557,9 +557,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (ProphetNetDecoder, ProphetNetEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1329,18 +1330,12 @@ def forward( encoder_hidden_states = encoder_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, extended_attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1569,16 +1564,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, extended_attention_mask, encoder_hidden_states, @@ -1590,6 +1577,8 @@ def custom_forward(*inputs): predict_relative_position_buckets, position_ids, None, + use_cache, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/pvt/modeling_pvt.py b/src/transformers/models/pvt/modeling_pvt.py index 2dd452ec1df153..356b7c14afa8e0 100755 --- a/src/transformers/models/pvt/modeling_pvt.py +++ b/src/transformers/models/pvt/modeling_pvt.py @@ -489,9 +489,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ) - def _set_gradient_checkpointing(self, module: PvtEncoder, value: bool = False): + def _set_gradient_checkpointing(self, module: PvtEncoder, gradient_checkpointing_func=None): if isinstance(module, PvtEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None PVT_START_DOCSTRING = r""" diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index fead8fc0cf7f42..0a2546a9b64e86 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -581,20 +581,15 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -757,9 +752,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, QDQBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None QDQBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index aa738d782b7b6d..86b37b21560ba4 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -586,20 +586,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/regnet/modeling_regnet.py b/src/transformers/models/regnet/modeling_regnet.py index 07ef29fd33320b..21050f07fda441 100644 --- a/src/transformers/models/regnet/modeling_regnet.py +++ b/src/transformers/models/regnet/modeling_regnet.py @@ -293,9 +293,10 @@ def _init_weights(self, module): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RegNetModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None REGNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 235bff89f8a354..e5e662a9b5564d 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -543,20 +543,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -673,9 +668,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RemBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None REMBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index f2d207c2189f27..e6b1d85b2a46e8 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -283,9 +283,10 @@ def _init_weights(self, module): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ResNetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None RESNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 6d4cc991d22ca0..32a19c08831777 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -510,20 +510,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -612,9 +607,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RobertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index da1cd6331bc314..78ca20684540bd 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -512,20 +512,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -615,9 +610,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RobertaPreLayerNormEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROBERTA_PRELAYERNORM_START_DOCSTRING = r""" diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index a5b1b63050b1ef..3a58efa9140c00 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -644,20 +644,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -796,9 +791,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RoCBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROC_BERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index b9c36a305ff1cd..3893e27b028f2e 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -578,21 +578,16 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, sinusoidal_pos, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -715,9 +710,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RoFormerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None ROFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index db41bd3c9538c0..275233321372a2 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -466,9 +466,10 @@ def _init_weights(self, module): module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, RwkvModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass @@ -676,16 +677,8 @@ def forward( all_hidden_states = () if output_hidden_states else None for idx, block in enumerate(self.blocks): if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - hidden_states, state, attentions = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), hidden_states, state + hidden_states, state, attentions = self.gradient_checkpointing_func( + block.__call__, hidden_states, state, use_cache, output_attentions ) else: hidden_states, state, attentions = block( diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index abf5544a5b4de6..1bd6fcdc2a8f14 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1042,15 +1042,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, ) else: diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 91ec6a8f9b8795..ea79c734188391 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -892,15 +892,8 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, relative_position_embeddings, @@ -1547,9 +1540,10 @@ def _init_weights(self, module): k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (SeamlessM4TDecoder, SeamlessM4TEncoder, SeamlessM4TSpeechEncoder)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (SeamlessM4TDecoder, SeamlessM4TEncoder, SeamlessM4TConformerEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): kernel_size, stride = self.config.adaptor_kernel_size, self.config.adaptor_stride @@ -1856,18 +1850,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + encoder_layer.forward, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2130,16 +2118,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -2147,6 +2127,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 34f9c84235cc08..36416c168c3600 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -360,15 +360,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -673,17 +666,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -756,9 +743,10 @@ def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SEWEncoder, SEWFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 661a8c03b1a5d9..39c9641b9489fe 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -453,15 +453,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -1134,20 +1127,14 @@ def forward( all_hidden_states = all_hidden_states + (output_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - output_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + output_states = self.gradient_checkpointing_func( + layer_module.__call__, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, + output_attentions, ) else: output_states = layer_module( @@ -1322,9 +1309,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, SEWDTransformerEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (SEWDEncoder, SEWDFeatureEncoder, SEWDTransformerEncoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SEWD_START_DOCSTRING = r""" diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index e80c26e2698d73..ec255fab9bc766 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -249,10 +249,10 @@ def __init__( f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_encoder(self): return self.encoder diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 31c9b6cfe93552..73a02fe66df7ba 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -559,9 +559,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Speech2TextDecoder, Speech2TextEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ @@ -817,18 +818,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1065,16 +1060,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1082,6 +1069,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index f9b5dec4209273..acee2b15a44f2c 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -437,9 +437,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Speech2Text2Decoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SPEECH_TO_TEXT_2_START_DOCSTRING = r""" @@ -669,16 +670,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 9b8ab3d3805a05..b8fea796647b70 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -520,15 +520,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -1281,9 +1274,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SpeechT5Encoder, SpeechT5Decoder, SpeechT5FeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None class SpeechT5Encoder(SpeechT5PreTrainedModel): @@ -1386,19 +1380,13 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), position_bias, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1439,7 +1427,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5SpeechEncoderPrenet(config) self.wrapped_encoder = SpeechT5Encoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1476,7 +1463,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5TextEncoderPrenet(config) self.wrapped_encoder = SpeechT5Encoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1519,7 +1505,6 @@ class SpeechT5EncoderWithoutPrenet(SpeechT5PreTrainedModel): def __init__(self, config: SpeechT5Config): super().__init__(config) self.wrapped_encoder = SpeechT5Encoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1715,16 +1700,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1732,6 +1709,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( @@ -1788,7 +1767,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5SpeechDecoderPrenet(config) self.wrapped_decoder = SpeechT5Decoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1836,7 +1814,6 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.prenet = SpeechT5TextDecoderPrenet(config) self.wrapped_decoder = SpeechT5Decoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -1889,7 +1866,6 @@ class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel): def __init__(self, config: SpeechT5Config): super().__init__(config) self.wrapped_decoder = SpeechT5Decoder(config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index f72ffb10111bc7..1bdf8f3f5f9194 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -459,20 +459,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -544,9 +539,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, SplinterEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SPLINTER_START_DOCSTRING = r""" diff --git a/src/transformers/models/swiftformer/modeling_swiftformer.py b/src/transformers/models/swiftformer/modeling_swiftformer.py index ff72f87506d36a..4170ce153bbfeb 100644 --- a/src/transformers/models/swiftformer/modeling_swiftformer.py +++ b/src/transformers/models/swiftformer/modeling_swiftformer.py @@ -442,9 +442,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No nn.init.constant_(module.bias, 0) nn.init.constant_(module.weight, 1.0) - def _set_gradient_checkpointing(self, module: SwiftFormerEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: SwiftFormerEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, SwiftFormerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIFTFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 45a7aa718cf026..c2f15dbbf27394 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -825,15 +825,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -901,9 +894,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, SwinEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py index 02ec39edb0fe14..5d53561442457f 100644 --- a/src/transformers/models/swin/modeling_tf_swin.py +++ b/src/transformers/models/swin/modeling_tf_swin.py @@ -951,11 +951,6 @@ class TFSwinPreTrainedModel(TFPreTrainedModel): config_class = SwinConfig base_model_prefix = "swin" main_input_name = "pixel_values" - supports_gradient_checkpointing = True - - def _set_gradient_checkpointing(self, module, value=False) -> None: - if isinstance(module, TFSwinEncoder): - module.gradient_checkpointing = value SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index a8a17bdf584b00..47ce01d1691668 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -746,15 +746,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(stage_module), hidden_states, input_dimensions, layer_head_mask + layer_outputs = self.gradient_checkpointing_func( + stage_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = stage_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) @@ -802,9 +795,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Swin2SREncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWIN2SR_START_DOCSTRING = r""" diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index a4224e16df3c25..6daad938a623ab 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -906,15 +906,8 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: layer_outputs = layer_module( @@ -983,9 +976,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, Swinv2Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None SWINV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 0a402ea2d6af87..32d030728de579 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -865,9 +865,10 @@ def _init_weights(self, module): module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (SwitchTransformersAttention, SwitchTransformersStack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1039,15 +1040,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1057,6 +1051,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 0e7237ea36b647..c796a9cf24cfb0 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -873,9 +873,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (T5Attention, T5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1100,15 +1101,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, position_bias, @@ -1118,6 +1112,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index b6012700ee7de0..e1da557b0017cc 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -837,9 +837,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TableTransformerDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TABLE_TRANSFORMER_START_DOCSTRING = r""" @@ -1149,15 +1150,8 @@ def forward( continue if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, combined_attention_mask, encoder_hidden_states, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index cdaa4b3e2725f7..de05d77ec94358 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -646,20 +646,15 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_values, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_values, + output_attentions, ) else: layer_outputs = layer_module( @@ -778,9 +773,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TapasEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TAPAS_START_DOCSTRING = r""" diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 2caca5bd105131..1fa6a963f58f50 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -663,9 +663,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (TimeSeriesTransformerDecoder, TimeSeriesTransformerEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TIME_SERIES_TRANSFORMER_START_DOCSTRING = r""" @@ -946,18 +947,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1163,16 +1158,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1180,6 +1167,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index 676bcf7a5e27a0..044705c35e5410 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -439,16 +439,10 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, output_attentions) @@ -494,9 +488,10 @@ def _init_weights(self, module): nn.init.trunc_normal_(module.position_embeddings, std=self.config.initializer_range) module.patch_embeddings.apply(self._init_weights) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TimesformerEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TIMESFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index c0541814be466e..ada8638a03b608 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -454,9 +454,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, TrOCRDecoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TROCR_START_DOCSTRING = r""" @@ -701,16 +702,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -718,6 +711,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index 464c3e76a11f94..a37265f37c7ad8 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -560,18 +560,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -616,9 +610,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, TvltEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (TvltEncoder, TvltDecoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None TVLT_START_DOCSTRING = r""" @@ -877,17 +872,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, None, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index ffafd158114074..a5b58444fe4edf 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -556,9 +556,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (UMT5Attention, UMT5Stack)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -709,15 +710,8 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return tuple(module(*inputs, use_cache, output_attentions)) - - return custom_forward - layer_outputs = checkpoint( - create_custom_forward(layer_module), + layer_module.forward, hidden_states, extended_attention_mask, encoder_hidden_states, @@ -725,6 +719,8 @@ def custom_forward(*inputs): layer_head_mask, cross_attn_layer_head_mask, None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index c475ab7f80f87d..db14d5bca51f5c 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -384,15 +384,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -767,17 +760,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -857,17 +844,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -1039,9 +1020,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (UniSpeechEncoder, UniSpeechEncoderStableLayerNorm, UniSpeechFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None UNISPEECH_START_DOCSTRING = r""" diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 3fcc9549bbdcd6..8a9a63804b56a4 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -398,15 +398,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -781,17 +774,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -871,17 +858,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -1053,9 +1034,10 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (UniSpeechSatEncoder, UniSpeechSatEncoderStableLayerNorm, UniSpeechSatFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None UNISPEECH_SAT_START_DOCSTRING = r""" diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index b56b508d14ae63..04b8c94e1351bd 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -315,9 +315,10 @@ def init_weights(self): if self.auxiliary_head is not None: self.auxiliary_head.init_weights() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BackboneMixin): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None UPERNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 07c32d14929037..277280954fd6f1 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -434,17 +434,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -489,9 +483,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, VideoMAEEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (VideoMAEEncoder, VideoMAEDecoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIDEOMAE_START_DOCSTRING = r""" @@ -726,17 +721,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, None, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index a36d58bd235bb5..482bd08359bd4a 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -531,18 +531,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -591,9 +585,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ViltEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VILT_START_DOCSTRING = r""" diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index d3e464cbfffa08..84275cc33a767b 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -225,10 +225,10 @@ def __init__( f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" ) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, value=value) - self.decoder._set_gradient_checkpointing(module, value=value) + self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) + self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) def get_encoder(self): return self.encoder diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 81ad1068483a80..425a125a0b89dd 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -418,18 +418,12 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) @@ -547,9 +541,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VisualBertEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @dataclass diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 8fdacdddf04cba..67dbddf8766a41 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -397,17 +397,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -467,9 +461,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ViTEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ViTEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index 008f6b3c9db536..959522843f7a67 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -415,17 +415,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -486,9 +480,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: ViTHybridEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ViTHybridEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ViTHybridEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index ef0c7c9f36869e..e156fdc3292c4b 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -536,17 +536,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -591,9 +585,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, ViTMAEEncoder): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, (ViTMAEEncoder, ViTMAEDecoder)): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_MAE_START_DOCSTRING = r""" @@ -793,17 +788,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, None, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 46639e7d622cb7..b727c331cfb4d7 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -387,17 +387,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -444,9 +438,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: ViTMSNEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: ViTMSNEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, ViTMSNEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIT_MSN_START_DOCSTRING = r""" diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index e89fdbd7a33631..9bb3991fabf186 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -565,17 +565,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -666,9 +660,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.norm3.weight.data.zero_() module.norm3.bias.data.zero_() - def _set_gradient_checkpointing(self, module: VitDetEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: VitDetEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, VitDetEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VITDET_START_DOCSTRING = r""" diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index b23bdd21d56b85..f5025a37e71c15 100644 --- a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -86,9 +86,15 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, BackboneMixin): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None + + for backbone_module in module.modules(): + if hasattr(backbone_module, "gradient_checkpointing"): + backbone_module.gradient_checkpointing_func = gradient_checkpointing_func + backbone_module.gradient_checkpointing = gradient_checkpointing_func is not None class VitMatteBasicConv3x3(nn.Module): diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 49b9a1f1ae1551..b621bde35e61da 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1167,18 +1167,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, padding_mask, attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1296,9 +1290,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (VitsTextEncoder)): - module.gradient_checkpointing = value + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): + if isinstance(module, VitsEncoder): + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VITS_START_DOCSTRING = r""" diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index fd35668572a776..50cb82fb4e188f 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -338,17 +338,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -414,9 +408,10 @@ def _init_weights(self, module): elif isinstance(module, nn.Parameter): module.data.normal_(mean=0.0, std=self.config.initializer_range) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, VivitEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None VIVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index a6e02a0476f15d..9f48e529627e8e 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -451,15 +451,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -803,17 +796,11 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -892,17 +879,11 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer( @@ -1173,9 +1154,10 @@ def _get_feature_vector_attention_mask( attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_adapters(self): if self.config.adapter_attn_dim is None: diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index f162c514297067..5fba773ee0cb4f 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -518,15 +518,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -911,18 +904,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, relative_position_embeddings, + output_attentions, ) else: layer_outputs = layer( @@ -1178,9 +1165,10 @@ def _get_feature_vector_attention_mask( attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None WAV2VEC2_CONFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 5013837cbdcefd..55b19e4c414341 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -354,15 +354,8 @@ def forward(self, input_values): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(conv_layer), + hidden_states = self.gradient_checkpointing_func( + conv_layer.__call__, hidden_states, ) else: @@ -713,18 +706,12 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, position_bias, + output_attentions, ) else: layer_outputs = layer( @@ -804,18 +791,12 @@ def forward( # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - # create gradient checkpointing function - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), + layer_outputs = self.gradient_checkpointing_func( + layer.__call__, hidden_states, attention_mask, position_bias, + output_attentions, ) else: layer_outputs = layer( @@ -1052,9 +1033,10 @@ def _get_feature_vector_attention_mask( attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (WavLMEncoder, WavLMEncoderStableLayerNorm, WavLMFeatureEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None WAVLM_START_DOCSTRING = r""" diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 8962324471cafc..d6d0302727cb09 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -685,9 +685,10 @@ def _init_weights(self, module): embed_positions = module.embed_positions.weight embed_positions.copy_(sinusoids(*embed_positions.shape)) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (WhisperDecoder, WhisperEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ @@ -942,18 +943,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, None, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1174,16 +1169,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -1191,6 +1178,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, # past_key_value + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index da7eddff8df838..6c9cc02db9c83f 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -534,9 +534,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (XCLIPEncoder, XCLIPVisionEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None X_CLIP_START_DOCSTRING = r""" @@ -703,18 +704,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( @@ -950,18 +945,12 @@ def forward( if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, causal_attention_mask, + output_attentions, ) else: layer_outputs = encoder_layer( diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 0c769dbbb5f324..1880a7832193a8 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -503,9 +503,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, XGLMModel): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None @add_start_docstrings( @@ -674,16 +675,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -691,6 +684,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index cde05cfe8a8a68..9a9f02b74a65a4 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -570,9 +570,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, (XLMProphetNetDecoder, XLMProphetNetEncoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1349,18 +1350,12 @@ def forward( encoder_hidden_states = encoder_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, extended_attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -1592,16 +1587,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, extended_attention_mask, encoder_hidden_states, @@ -1613,6 +1600,8 @@ def custom_forward(*inputs): predict_relative_position_buckets, position_ids, None, + use_cache, + output_attentions, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index da454b1e3331f9..da99b2806fb6f8 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -511,20 +511,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -614,9 +609,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, XLMRobertaEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None XLM_ROBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 26e0361abdb523..49f7c07517211c 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -499,20 +499,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 28fddc2fdbd6b5..5f7b42f266fb1e 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -573,21 +573,16 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, lang_ids, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -680,9 +675,10 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) # Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel._set_gradient_checkpointing with Roberta->Xmod - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, XmodEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None def set_default_language(self, language: str): """ diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index e3cb02ceae6ec0..f6cbaecd014e61 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -492,17 +492,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, layer_head_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) @@ -551,9 +545,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: YolosEncoder, value: bool = False) -> None: + def _set_gradient_checkpointing(self, module: YolosEncoder, gradient_checkpointing_func=None) -> None: if isinstance(module, YolosEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None YOLOS_START_DOCSTRING = r""" diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index 5edd7f8835422a..8db66d22106160 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -561,17 +561,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, + output_attentions, ) else: layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) @@ -668,9 +662,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, YosoEncoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None YOSO_START_DOCSTRING = r""" diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 02fcb7d2f511e1..0b5af845c9aaa3 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -544,19 +544,15 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), + layer_outputs = self.gradient_checkpointing_func( + layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, + output_attentions, ) else: layer_outputs = layer_module( @@ -679,9 +675,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" @@ -2024,9 +2021,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}Encoder)): - module.gradient_checkpointing = value + module.gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = gradient_checkpointing_func is not None {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" @@ -2312,18 +2310,12 @@ def forward( layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self.gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + output_attentions, ) else: layer_outputs = encoder_layer( @@ -2551,15 +2543,8 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), + layer_outputs = self.gradient_checkpointing_func( + decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, @@ -2567,6 +2552,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + output_attentions, + use_cache, ) else: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 34f5bae3746f03..7e1c471badf417 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -349,10 +349,24 @@ def test_gradient_checkpointing_enable_disable(self): model.gradient_checkpointing_enable() self.assertTrue(model.is_gradient_checkpointing) + # Loop over all modules and check that relevant modules have gradient_checkpointing set to True + for n, m in model.named_modules(): + if hasattr(m, "gradient_checkpointing"): + self.assertTrue( + m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True" + ) + # check disable works model.gradient_checkpointing_disable() self.assertFalse(model.is_gradient_checkpointing) + # Loop over all modules and check that relevant modules have gradient_checkpointing set to False + for n, m in model.named_modules(): + if hasattr(m, "gradient_checkpointing"): + self.assertFalse( + m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False" + ) + def test_save_load_fast_init_from_base(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() if config.__class__ not in MODEL_MAPPING: @@ -569,6 +583,13 @@ def test_training_gradient_checkpointing(self): loss = model(**inputs).loss loss.backward() + model.gradient_checkpointing_disable() + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + def test_attention_outputs(self): if not self.has_attentions: self.skipTest(reason="Model does not output attentions") From 0baa9246cb1ddac355db1df7824a521426599eb7 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 25 Oct 2023 03:36:43 -0700 Subject: [PATCH 35/38] Fix TypicalLogitsWarper tensor OOB indexing edge case (#26579) * Fix TypicalLogitsWarper tensor OOB indexing edge case This can be triggerd fairly quickly with low precision e.g. bfloat16 and typical_p = 0.99. * Shift threshold index by one * Use explicit named arg for clamp min --- src/transformers/generation/logits_process.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index ed69a163778dd1..0ea1d2bdbae53a 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -492,8 +492,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) # Remove tokens with cumulative mass above the threshold - last_ind = (cumulative_probs < self.mass).sum(dim=1) - last_ind[last_ind < 0] = 0 + last_ind = (cumulative_probs < self.mass).sum(dim=1) - 1 + last_ind.clamp_(min=0) sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) From a64f8c1f878c2730a5b4048adf4b5c2dd7cabd8f Mon Sep 17 00:00:00 2001 From: Jing Hua Date: Thu, 26 Oct 2023 00:09:04 +0800 Subject: [PATCH 36/38] [docstring] fix incorrect llama docstring: encoder -> decoder (#27071) fix incorrect docstring: encoder -> decoder --- src/transformers/models/llama/configuration_llama.py | 4 ++-- src/transformers/models/llama/modeling_llama.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index f37e52ca9ca539..7cce21671e9786 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -47,9 +47,9 @@ class LlamaConfig(PretrainedConfig): intermediate_size (`int`, *optional*, defaults to 11008): Dimension of the MLP representations. num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. + Number of hidden layers in the Transformer decoder. num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. + Number of attention heads for each attention layer in the Transformer decoder. num_key_value_heads (`int`, *optional*): This is the number of key_value heads that should be used to implement Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 279884dc164f04..207302a10e1683 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -871,7 +871,7 @@ def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. From ba073ea9e34a43f0ca02c77f3b92b5480a163cc7 Mon Sep 17 00:00:00 2001 From: Akash Kundu <112017800+Akash190104@users.noreply.github.com> Date: Wed, 25 Oct 2023 22:51:13 +0530 Subject: [PATCH 37/38] [DOCS] minor fixes in README.md (#27048) minor fixes --- README.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 20813d78d3428a..dd14101f438506 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ limitations under the License. These models can be applied on: -* 📝 Text, for tasks like text classification, information extraction, question answering, summarization, translation, text generation, in over 100 languages. +* 📝 Text, for tasks like text classification, information extraction, question answering, summarization, translation, and text generation, in over 100 languages. * 🖼️ Images, for tasks like image classification, object detection, and segmentation. * 🗣️ Audio, for tasks like speech recognition and audio classification. @@ -147,7 +147,7 @@ To immediately use a model on a given input (text, image, audio, ...), we provid [{'label': 'POSITIVE', 'score': 0.9996980428695679}] ``` -The second line of code downloads and caches the pretrained model used by the pipeline, while the third evaluates it on the given text. Here the answer is "positive" with a confidence of 99.97%. +The second line of code downloads and caches the pretrained model used by the pipeline, while the third evaluates it on the given text. Here, the answer is "positive" with a confidence of 99.97%. Many tasks have a pre-trained `pipeline` ready to go, in NLP but also in computer vision and speech. For example, we can easily extract detected objects in an image: @@ -181,7 +181,7 @@ Many tasks have a pre-trained `pipeline` ready to go, in NLP but also in compute 'box': {'xmin': 345, 'ymin': 23, 'xmax': 640, 'ymax': 368}}] ``` -Here we get a list of objects detected in the image, with a box surrounding the object and a confidence score. Here is the original image on the left, with the predictions displayed on the right: +Here, we get a list of objects detected in the image, with a box surrounding the object and a confidence score. Here is the original image on the left, with the predictions displayed on the right:

@@ -212,7 +212,7 @@ And here is the equivalent code for TensorFlow: >>> outputs = model(**inputs) ``` -The tokenizer is responsible for all the preprocessing the pretrained model expects, and can be called directly on a single string (as in the above examples) or a list. It will output a dictionary that you can use in downstream code or simply directly pass to your model using the ** argument unpacking operator. +The tokenizer is responsible for all the preprocessing the pretrained model expects and can be called directly on a single string (as in the above examples) or a list. It will output a dictionary that you can use in downstream code or simply directly pass to your model using the ** argument unpacking operator. The model itself is a regular [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) or a [TensorFlow `tf.keras.Model`](https://www.tensorflow.org/api_docs/python/tf/keras/Model) (depending on your backend) which you can use as usual. [This tutorial](https://huggingface.co/docs/transformers/training) explains how to integrate such a model into a classic PyTorch or TensorFlow training loop, or how to use our `Trainer` API to quickly fine-tune on a new dataset. @@ -232,7 +232,7 @@ The model itself is a regular [Pytorch `nn.Module`](https://pytorch.org/docs/sta 1. Choose the right framework for every part of a model's lifetime: - Train state-of-the-art models in 3 lines of code. - Move a single model between TF2.0/PyTorch/JAX frameworks at will. - - Seamlessly pick the right framework for training, evaluation and production. + - Seamlessly pick the right framework for training, evaluation, and production. 1. Easily customize a model or an example to your needs: - We provide examples for each architecture to reproduce the results published by its original authors. @@ -249,13 +249,13 @@ The model itself is a regular [Pytorch `nn.Module`](https://pytorch.org/docs/sta ### With pip -This repository is tested on Python 3.8+, Flax 0.4.1+, PyTorch 1.10+ and TensorFlow 2.6+. +This repository is tested on Python 3.8+, Flax 0.4.1+, PyTorch 1.10+, and TensorFlow 2.6+. You should install 🤗 Transformers in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). First, create a virtual environment with the version of Python you're going to use and activate it. -Then, you will need to install at least one of Flax, PyTorch or TensorFlow. +Then, you will need to install at least one of Flax, PyTorch, or TensorFlow. Please refer to [TensorFlow installation page](https://www.tensorflow.org/install/), [PyTorch installation page](https://pytorch.org/get-started/locally/#start-locally) and/or [Flax](https://github.com/google/flax#quick-install) and [Jax](https://github.com/google/jax#installation) installation pages regarding the specific installation command for your platform. When one of those backends has been installed, 🤗 Transformers can be installed using pip as follows: @@ -294,11 +294,11 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer)** (from MIT) released with the paper [AST: Audio Spectrogram Transformer](https://arxiv.org/abs/2104.01778) by Yuan Gong, Yu-An Chung, James Glass. 1. **[Autoformer](https://huggingface.co/docs/transformers/model_doc/autoformer)** (from Tsinghua University) released with the paper [Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting](https://arxiv.org/abs/2106.13008) by Haixu Wu, Jiehui Xu, Jianmin Wang, Mingsheng Long. 1. **[Bark](https://huggingface.co/docs/transformers/model_doc/bark)** (from Suno) released in the repository [suno-ai/bark](https://github.com/suno-ai/bark) by Suno AI team. -1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer. +1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov, and Luke Zettlemoyer. 1. **[BARThez](https://huggingface.co/docs/transformers/model_doc/barthez)** (from École polytechnique) released with the paper [BARThez: a Skilled Pretrained French Sequence-to-Sequence Model](https://arxiv.org/abs/2010.12321) by Moussa Kamal Eddine, Antoine J.-P. Tixier, Michalis Vazirgiannis. 1. **[BARTpho](https://huggingface.co/docs/transformers/model_doc/bartpho)** (from VinAI Research) released with the paper [BARTpho: Pre-trained Sequence-to-Sequence Models for Vietnamese](https://arxiv.org/abs/2109.09701) by Nguyen Luong Tran, Duong Minh Le and Dat Quoc Nguyen. 1. **[BEiT](https://huggingface.co/docs/transformers/model_doc/beit)** (from Microsoft) released with the paper [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) by Hangbo Bao, Li Dong, Furu Wei. -1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. +1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 1. **[BERT For Sequence Generation](https://huggingface.co/docs/transformers/model_doc/bert-generation)** (from Google) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. 1. **[BERTweet](https://huggingface.co/docs/transformers/model_doc/bertweet)** (from VinAI Research) released with the paper [BERTweet: A pre-trained language model for English Tweets](https://aclanthology.org/2020.emnlp-demos.2/) by Dat Quoc Nguyen, Thanh Vu and Anh Tuan Nguyen. 1. **[BigBird-Pegasus](https://huggingface.co/docs/transformers/model_doc/bigbird_pegasus)** (from Google Research) released with the paper [Big Bird: Transformers for Longer Sequences](https://arxiv.org/abs/2007.14062) by Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed. @@ -520,7 +520,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. 1. **[YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos)** (from Huazhong University of Science & Technology) released with the paper [You Only Look at One Sequence: Rethinking Transformer in Vision through Object Detection](https://arxiv.org/abs/2106.00666) by Yuxin Fang, Bencheng Liao, Xinggang Wang, Jiemin Fang, Jiyang Qi, Rui Wu, Jianwei Niu, Wenyu Liu. 1. **[YOSO](https://huggingface.co/docs/transformers/model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh. -1. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR. +1. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedback before starting your PR. To check if each model has an implementation in Flax, PyTorch or TensorFlow, or has an associated tokenizer backed by the 🤗 Tokenizers library, refer to [this table](https://huggingface.co/docs/transformers/index#supported-frameworks). From c34c50cdc01ed5a5dc73a2f36466d07e6fe9fb4e Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 25 Oct 2023 19:31:36 +0200 Subject: [PATCH 38/38] [`docs`] Add `MaskGenerationPipeline` in docs (#27063) * add `MaskGenerationPipeline` in docs * Update __init__.py * fix repo consistency and clarify docstring * add on check docstirngs * actually we do have a tf sam * oops --- docs/source/en/main_classes/pipelines.md | 6 ++++++ src/transformers/__init__.py | 2 ++ utils/check_docstrings.py | 1 + 3 files changed, 9 insertions(+) diff --git a/docs/source/en/main_classes/pipelines.md b/docs/source/en/main_classes/pipelines.md index 6a1175d1f6ca74..b105cb544ffcbb 100644 --- a/docs/source/en/main_classes/pipelines.md +++ b/docs/source/en/main_classes/pipelines.md @@ -481,6 +481,12 @@ Pipelines available for multimodal tasks include the following. - __call__ - all +### MaskGenerationPipeline + +[[autodoc]] MaskGenerationPipeline + - __call__ + - all + ### VisualQuestionAnsweringPipeline [[autodoc]] VisualQuestionAnsweringPipeline diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 8001bb33c81408..90147df41f2694 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -672,6 +672,7 @@ "ImageToImagePipeline", "ImageToTextPipeline", "JsonPipelineDataFormat", + "MaskGenerationPipeline", "NerPipeline", "ObjectDetectionPipeline", "PipedPipelineDataFormat", @@ -4814,6 +4815,7 @@ ImageToImagePipeline, ImageToTextPipeline, JsonPipelineDataFormat, + MaskGenerationPipeline, NerPipeline, ObjectDetectionPipeline, PipedPipelineDataFormat, diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 0552988c7a4cc2..42223e586fe155 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -356,6 +356,7 @@ "M2M100Config", "M2M100Tokenizer", "MarkupLMProcessor", + "MaskGenerationPipeline", "MBart50TokenizerFast", "MBartConfig", "MCTCTFeatureExtractor",